package org.nd4j.linalg.compression;

import java.util.Collections;
import java.util.Iterator;
import java.util.Locale;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import lombok.NonNull;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/compression/BasicNDArrayCompressor.class */
/**
 * Process-wide registry and facade for the {@link NDArrayCompressor} implementations discovered via
 * {@link ServiceLoader}.
 *
 * <p>Codecs are registered under their descriptor upper-cased with {@link Locale#ROOT}, and every
 * algorithm name passed to this class is normalized the same way, so names are matched
 * case-insensitively regardless of the JVM default locale. Overloads that take no algorithm name
 * use the default algorithm (initially {@code "FLOAT16"}, changeable via
 * {@link #setDefaultCompression(String)}).
 */
public class BasicNDArrayCompressor {
    private static final Logger log = LoggerFactory.getLogger(BasicNDArrayCompressor.class);
    // Must be declared after `log`: the constructor (via loadCompressors) may log.
    private static final BasicNDArrayCompressor INSTANCE = new BasicNDArrayCompressor();
    /** Registered codecs keyed by normalized (upper-case, Locale.ROOT) descriptor. */
    protected Map<String, NDArrayCompressor> codecs;
    /** Normalized default algorithm name; guarded by {@code this}. */
    protected String defaultCompression = "FLOAT16";

    private BasicNDArrayCompressor() {
        loadCompressors();
    }

    /**
     * Discovers all {@link NDArrayCompressor} service providers and registers them by normalized
     * descriptor. If two providers share a descriptor, the later one wins and a warning is logged.
     *
     * @throws RuntimeException if no compressor could be found on the classpath
     */
    protected void loadCompressors() {
        this.codecs = new ConcurrentHashMap<>();
        for (NDArrayCompressor compressor : ServiceLoader.load(NDArrayCompressor.class)) {
            String key = normalize(compressor.getDescriptor());
            NDArrayCompressor previous = this.codecs.put(key, compressor);
            if (previous != null) {
                log.warn("Duplicate ND4J compressor descriptor [{}]: {} replaced by {}",
                        key, previous.getClass().getName(), compressor.getClass().getName());
            }
        }
        if (this.codecs.isEmpty()) {
            String message = "Error loading ND4J Compressors via service loader: No compressors were found. This usually occurs when running ND4J UI from an uber-jar, which was built incorrectly (without services resource files being included)";
            log.error(message);
            throw new RuntimeException(message);
        }
    }

    /**
     * Returns the names of all registered compressors.
     *
     * @return a read-only live view of the registered (normalized) algorithm names; callers cannot
     *     unregister codecs through it
     */
    public Set<String> getAvailableCompressors() {
        return Collections.unmodifiableSet(this.codecs.keySet());
    }

    /** Prints the registered compressor names to standard output as {@code [NAME] } entries. */
    public void printAvailableCompressors() {
        StringBuilder sb = new StringBuilder("Available compressors: ");
        for (String name : this.codecs.keySet()) {
            sb.append("[").append(name).append("] ");
        }
        System.out.println(sb.toString());
    }

    /** Returns the process-wide singleton instance. */
    public static BasicNDArrayCompressor getInstance() {
        return INSTANCE;
    }

    /**
     * Sets the algorithm used by the overloads that take no algorithm name.
     *
     * @param algorithm algorithm name, matched case-insensitively
     * @return this instance, for chaining
     * @throws NullPointerException if {@code algorithm} is null
     */
    public BasicNDArrayCompressor setDefaultCompression(@NonNull String algorithm) {
        if (algorithm == null) {
            throw new NullPointerException("algorithm is marked @NonNull but is null");
        }
        String normalized = normalize(algorithm);
        synchronized (this) {
            this.defaultCompression = normalized;
        }
        return this;
    }

    /** Returns the current (normalized) default algorithm name. */
    public synchronized String getDefaultCompression() {
        return this.defaultCompression;
    }

    /** Compresses {@code dataBuffer} with the default algorithm. */
    public DataBuffer compress(DataBuffer dataBuffer) {
        return compress(dataBuffer, getDefaultCompression());
    }

    /**
     * Compresses {@code dataBuffer} with the named algorithm.
     *
     * @throws RuntimeException if no codec is registered under {@code algorithm}
     */
    public DataBuffer compress(DataBuffer dataBuffer, String algorithm) {
        return requireCodec(algorithm).compress(dataBuffer);
    }

    /**
     * Commits pending executioner work, then compresses {@code array} with the default algorithm.
     */
    public INDArray compress(INDArray array) {
        Nd4j.getExecutioner().commit();
        return compress(array, getDefaultCompression());
    }

    /** In-place compression of {@code array} with the default algorithm. */
    public void compressi(INDArray array) {
        compressi(array, getDefaultCompression());
    }

    /**
     * Compresses {@code array} with the named algorithm.
     *
     * @throws RuntimeException if no codec is registered under {@code algorithm}
     */
    public INDArray compress(INDArray array, String algorithm) {
        return requireCodec(algorithm).compress(array);
    }

    /**
     * In-place compression of {@code array} with the named algorithm.
     *
     * @throws RuntimeException if no codec is registered under {@code algorithm}
     */
    public void compressi(INDArray array, String algorithm) {
        requireCodec(algorithm).compressi(array);
    }

    /**
     * Decompresses a compressed buffer using the codec recorded in its compression descriptor.
     *
     * @param dataBuffer buffer whose data type must be {@link DataType#COMPRESSED}
     * @param dataType   target data type passed through to the codec
     * @throws IllegalStateException if {@code dataBuffer} is not compressed
     * @throws RuntimeException      if the recorded algorithm has no registered codec
     */
    public DataBuffer decompress(DataBuffer dataBuffer, DataType dataType) {
        if (dataBuffer.dataType() != DataType.COMPRESSED) {
            throw new IllegalStateException("You can't decompress DataBuffer with dataType of: " + dataBuffer.dataType());
        }
        CompressionDescriptor descriptor = ((CompressedDataBuffer) dataBuffer).getCompressionDescriptor();
        return requireCodec(descriptor.getCompressionAlgorithm()).decompress(dataBuffer, dataType);
    }

    /**
     * Looks up a registered compressor by name (case-insensitive).
     *
     * @param name algorithm name
     * @return the compressor, or {@code null} if none is registered under {@code name}
     * @throws NullPointerException if {@code name} is null
     */
    public NDArrayCompressor getCompressor(@NonNull String name) {
        if (name == null) {
            throw new NullPointerException("name is marked @NonNull but is null");
        }
        return this.codecs.get(normalize(name));
    }

    /**
     * Returns a decompressed copy of {@code array}, or {@code array} itself if it is not compressed.
     *
     * @throws RuntimeException if the recorded algorithm has no registered codec
     */
    public INDArray decompress(INDArray array) {
        if (array.data().dataType() != DataType.COMPRESSED) {
            return array;
        }
        CompressionDescriptor descriptor = array.data().getCompressionDescriptor();
        return requireCodec(descriptor.getCompressionAlgorithm()).decompress(array);
    }

    /**
     * Decompresses {@code array} in place; does nothing if it is not compressed.
     *
     * @throws RuntimeException if the recorded algorithm has no registered codec
     */
    public void decompressi(INDArray array) {
        if (array.data().dataType() != DataType.COMPRESSED) {
            return;
        }
        CompressionDescriptor descriptor = array.data().getCompressionDescriptor();
        requireCodec(descriptor.getCompressionAlgorithm()).decompressi(array);
    }

    /** Decompresses, in place, every array in {@code arrays} that reports itself as compressed. */
    public void autoDecompress(INDArray... arrays) {
        for (INDArray array : arrays) {
            autoDecompress(array);
        }
    }

    /** Decompresses {@code array} in place if it reports itself as compressed. */
    public void autoDecompress(INDArray array) {
        if (array.isCompressed()) {
            decompressi(array);
        }
    }

    /**
     * Compresses a float array with the default algorithm.
     *
     * @throws RuntimeException if the default algorithm has no registered codec
     */
    public INDArray compress(float[] data) {
        return requireCodec(getDefaultCompression()).compress(data);
    }

    /**
     * Compresses a double array with the default algorithm.
     *
     * @throws RuntimeException if the default algorithm has no registered codec
     */
    public INDArray compress(double[] data) {
        return requireCodec(getDefaultCompression()).compress(data);
    }

    /** Upper-cases an algorithm name with {@link Locale#ROOT} so it matches the registry keys. */
    private static String normalize(String algorithm) {
        return algorithm.toUpperCase(Locale.ROOT);
    }

    /**
     * Returns the codec registered under the normalized form of {@code algorithm}.
     *
     * @throws RuntimeException if no such codec is registered
     */
    private NDArrayCompressor requireCodec(String algorithm) {
        String key = normalize(algorithm);
        NDArrayCompressor codec = this.codecs.get(key);
        if (codec == null) {
            throw new RuntimeException("Non-existent compression algorithm requested: [" + key + "]");
        }
        return codec;
    }
}
