package org.nd4j.nativeblas;

import java.io.File;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.util.HashMap;
import java.util.Map;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.indexer.ByteIndexer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.BaseNDArrayFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.memory.MemcpyDirection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/nativeblas/BaseNativeNDArrayFactory.class */
public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory {
    private static final Logger log = LoggerFactory.getLogger(BaseNativeNDArrayFactory.class);

    /** Bytes per shape-info entry; shape info is copied into a {@code LongPointer}. */
    private static final int SHAPE_INFO_ENTRY_BYTES = 8;

    /** Native operations handle, taken from {@link NativeOpsHolder} at construction time. */
    protected NativeOps nativeOps;

    public BaseNativeNDArrayFactory(DataType dataType, Character ch) {
        super(dataType, ch);
        this.nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    }

    public BaseNativeNDArrayFactory(DataType dataType, char c) {
        super(dataType, c);
        this.nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    }

    public BaseNativeNDArrayFactory() {
        this.nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    }

    /**
     * Serializes {@code iNDArray} into one native buffer: a numpy header produced by the native
     * {@code numpyHeaderForNd4j} call, immediately followed by the raw bytes of
     * {@code iNDArray.data()}.
     *
     * <p>The native call reports the header size through an out-parameter; the last reported byte
     * is not copied (presumably a NUL terminator — confirm against the native implementation).
     * The array is moved to host memory via the affinity manager before its data is copied.
     *
     * <p>NOTE(review): the payload is the entire data buffer. If {@code iNDArray} is a view over a
     * larger buffer, header and payload may disagree — verify callers pass non-view arrays.
     *
     * @param iNDArray array to serialize
     * @return pointer positioned at 0 holding the header bytes followed by the data bytes
     */
    public Pointer convertToNumpy(INDArray iNDArray) {
        LongPointer reportedHeaderSize = new LongPointer(1L);
        Pointer header = NativeOpsHolder.getInstance().getDeviceNativeOps().numpyHeaderForNd4j(
                iNDArray.data().pointer(),
                iNDArray.shapeInfoDataBuffer().pointer(),
                iNDArray.data().getElementSize(),
                reportedHeaderSize);
        long headerBytes = reportedHeaderSize.get() - 1;
        header.capacity(headerBytes);
        header.position(0L);

        // long arithmetic: an int cast here would overflow for buffers >= 2 GiB.
        long dataBytes = (long) iNDArray.data().getElementSize() * iNDArray.data().length();
        BytePointer output = new BytePointer(headerBytes + dataBytes);
        output.position(0L);
        Pointer.memcpy(output, header, headerBytes);

        // Data goes right after the header; memcpy writes at the destination's position.
        output.position(headerBytes);
        Nd4j.getAffinityManager().ensureLocation(iNDArray, AffinityManager.Location.HOST);
        Pointer.memcpy(output, iNDArray.data().pointer(), dataBytes);
        output.position(0L);
        return output;
    }

    /**
     * Builds an array from a native numpy structure pointer. Shape info and data are copied into
     * Nd4j-owned buffers, so {@code pointer} may be released afterwards. The result is tagged as
     * located on {@code DEVICE}.
     *
     * @param pointer native numpy structure
     * @return a new FLOAT (4-byte elements) or DOUBLE (8-byte elements) array
     * @throws UnsupportedOperationException if the element size is neither 4 nor 8 bytes
     */
    public INDArray createFromNpyPointer(Pointer pointer) {
        Pointer nativeData = this.nativeOps.dataPointForNumpy(pointer);
        int elementSize = this.nativeOps.elementSizeForNpyArray(pointer);
        DataBuffer shapeInfo = copyShapeInfo(this.nativeOps.shapeBufferForNumpy(pointer));
        DataBuffer data = copyNumpyData(nativeData, elementSize, Shape.length(shapeInfo));
        INDArray result = Nd4j.create(data, Shape.shape(shapeInfo), Shape.strideArr(shapeInfo), 0L, Shape.order(shapeInfo));
        Nd4j.getAffinityManager().tagLocation(result, AffinityManager.Location.DEVICE);
        return result;
    }

    /**
     * Builds an array from a pointer to a numpy header followed by its data, using the
     * {@code *ForNumpyHeader} native accessors. Shape info and data are copied into Nd4j-owned
     * buffers. Unlike {@link #createFromNpyPointer(Pointer)}, no location tag is applied.
     *
     * @param pointer native pointer to the numpy header
     * @return a new FLOAT (4-byte elements) or DOUBLE (8-byte elements) array
     * @throws UnsupportedOperationException if the element size is neither 4 nor 8 bytes
     */
    public INDArray createFromNpyHeaderPointer(Pointer pointer) {
        Pointer nativeData = this.nativeOps.dataPointForNumpyHeader(pointer);
        int elementSize = this.nativeOps.elementSizeForNpyArrayHeader(pointer);
        DataBuffer shapeInfo = copyShapeInfo(this.nativeOps.shapeBufferForNumpyHeader(pointer));
        DataBuffer data = copyNumpyData(nativeData, elementSize, Shape.length(shapeInfo));
        return Nd4j.create(data, Shape.shape(shapeInfo), Shape.strideArr(shapeInfo), 0L, Shape.order(shapeInfo));
    }

    /**
     * Loads a single {@code .npy} file through the native reader.
     *
     * <p>The native numpy structure is always released, even if conversion fails.
     *
     * @param file the {@code .npy} file to read
     * @return the loaded array (see {@link #createFromNpyPointer(Pointer)})
     * @throws IllegalArgumentException if the native reader returns a null pointer
     */
    public INDArray createFromNpyFile(File file) {
        Pointer numpy = this.nativeOps.numpyFromFile(pathPointer(file));
        if (numpy == null || numpy.isNull()) {
            throw new IllegalArgumentException("Native numpy reader returned null for file: " + file.getAbsolutePath());
        }
        try {
            return createFromNpyPointer(numpy);
        } finally {
            // createFromNpyPointer copies everything it needs, so releasing here is safe.
            this.nativeOps.releaseNumpy(numpy);
        }
    }

    /**
     * Loads every array stored in an {@code .npz} archive.
     *
     * <p>NOTE(review): this class was recovered by decompilation and the body of this method could
     * not be reconstructed, so it currently always throws {@link UnsupportedOperationException}.
     * {@link #_createFromNpzFile(File)} is a native-backed alternative with the same signature;
     * restore the original implementation from upstream sources before relying on this method.
     *
     * @param file the {@code .npz} archive
     * @return never returns normally
     * @throws Exception always, as {@link UnsupportedOperationException}
     */
    public Map<String, INDArray> createFromNpzFile(File file) throws Exception {
        throw new UnsupportedOperationException("Method not decompiled: org.nd4j.nativeblas.BaseNativeNDArrayFactory.createFromNpzFile(java.io.File):java.util.Map");
    }

    /**
     * Loads every array of an {@code .npz} archive through the native npz map accessors.
     *
     * <p>Each entry's data is copied into a newly allocated Nd4j buffer. Element size decides the
     * type: 4 bytes become FLOAT, 8 bytes become DOUBLE. NOTE(review): integer arrays of the same
     * width would be reinterpreted as floating point — confirm whether the native layer exposes
     * the dtype. NOTE(review): the native map returned by {@code mapFromNpzFile} is never
     * released here; check whether NativeOps offers a matching release call.
     *
     * @param file the {@code .npz} archive
     * @return map from entry name to array
     * @throws Exception if an entry's element size is not 32 or 64 bits
     */
    public Map<String, INDArray> _createFromNpzFile(File file) throws Exception {
        Pointer archive = this.nativeOps.mapFromNpzFile(pathPointer(file));
        int entryCount = this.nativeOps.getNumNpyArraysInMap(archive);
        Map<String, INDArray> arrays = new HashMap<>();
        for (int entryIndex = 0; entryIndex < entryCount; entryIndex++) {
            String name = this.nativeOps.getNpyArrayNameFromMap(archive, entryIndex);
            Pointer entry = this.nativeOps.getNpyArrayFromMap(archive, entryIndex);

            int rank = this.nativeOps.getNpyArrayRank(entry);
            LongPointer nativeShape = this.nativeOps.getNpyArrayShape(entry);
            long[] shape = new long[rank];
            long length = 1;
            for (int dim = 0; dim < rank; dim++) {
                shape[dim] = nativeShape.get(dim);
                length *= shape[dim];
            }

            int elementBytes = this.nativeOps.getNpyArrayElemSize(entry);
            int elementBits = elementBytes * 8;
            char order = this.nativeOps.getNpyArrayOrder(entry);
            Pointer nativeData = this.nativeOps.dataPointForNumpyStruct(entry);
            if (elementBits != 32 && elementBits != 64) {
                throw new Exception("Unsupported data type: " + String.valueOf(elementBits));
            }

            DataType type = elementBits == 32 ? DataType.FLOAT : DataType.DOUBLE;
            // Previously the target buffer was allocated but never filled from nativeData.
            DataBuffer data = copyNumpyData(nativeData, elementBytes, length);
            arrays.put(name, Nd4j.create(data, shape, Nd4j.getStrides(shape, order), 0L, order, type));
        }
        return arrays;
    }

    /**
     * Copies a native shape-info buffer into a new Nd4j LONG buffer. Sets position, limit and
     * capacity of {@code nativeShapeInfo} to the byte span reported by
     * {@code lengthForShapeBufferPointer}.
     */
    private DataBuffer copyShapeInfo(Pointer nativeShapeInfo) {
        int entries = this.nativeOps.lengthForShapeBufferPointer(nativeShapeInfo);
        long bytes = (long) SHAPE_INFO_ENTRY_BYTES * entries;
        nativeShapeInfo.capacity(bytes);
        nativeShapeInfo.limit(bytes);
        nativeShapeInfo.position(0L);
        LongPointer source = new LongPointer(nativeShapeInfo);
        LongPointer copy = new LongPointer(entries);
        copyHostToHost(copy, source, bytes);
        return Nd4j.createBuffer(copy, DataType.LONG, entries, LongIndexer.create(copy));
    }

    /**
     * Copies {@code length} elements of {@code elementSize} bytes from {@code nativeData} into a new
     * FLOAT (4 bytes) or DOUBLE (8 bytes) Nd4j buffer.
     *
     * @throws UnsupportedOperationException for any other element size
     */
    private static DataBuffer copyNumpyData(Pointer nativeData, int elementSize, long length) {
        long bytes = (long) elementSize * length;
        nativeData.position(0L);
        nativeData.limit(bytes);
        nativeData.capacity(bytes);
        if (elementSize == 4) {
            FloatPointer floats = new FloatPointer(length);
            copyHostToHost(floats, nativeData, bytes);
            return Nd4j.createBuffer(floats, DataType.FLOAT, length, FloatIndexer.create(floats));
        }
        if (elementSize == 8) {
            DoublePointer doubles = new DoublePointer(length);
            copyHostToHost(doubles, nativeData, bytes);
            return Nd4j.createBuffer(doubles, DataType.DOUBLE, length, DoubleIndexer.create(doubles));
        }
        throw new UnsupportedOperationException("Unsupported numpy element size: " + elementSize
                + " bytes (only 4-byte float and 8-byte double are supported)");
    }

    /** Host-to-host memcpy of {@code bytes} bytes, recorded with the PerformanceTracker (device 0). */
    private static void copyHostToHost(Pointer destination, Pointer source, long bytes) {
        long started = PerformanceTracker.getInstance().helperStartTransaction();
        Pointer.memcpy(destination, source, bytes);
        PerformanceTracker.getInstance().helperRegisterTransaction(0, started, bytes, MemcpyDirection.HOST_TO_HOST);
    }

    /**
     * Encodes the absolute path of {@code file} as UTF-8 in a direct buffer for native calls.
     *
     * <p>The returned pointer's limit covers exactly the path bytes. One extra zero byte follows
     * them, so native code that reads a C string also finds a terminator.
     */
    private static BytePointer pathPointer(File file) {
        byte[] path = file.getAbsolutePath().getBytes(Charset.forName("UTF-8"));
        ByteBuffer direct = ByteBuffer.allocateDirect(path.length + 1).order(ByteOrder.nativeOrder());
        direct.put(path);
        direct.put((byte) 0);
        // Calling through Buffer keeps the bytecode valid on Java 8 (ByteBuffer overrides these in 9+).
        ((java.nio.Buffer) direct).position(0);
        ((java.nio.Buffer) direct).limit(path.length);
        return new BytePointer(direct);
    }
}
