package org.tensorflow;

import java.util.function.Consumer;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.internal.buffer.TensorBuffers;
import org.tensorflow.internal.c_api.TF_Tensor;
import org.tensorflow.internal.c_api.global.tensorflow;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.ByteDataBuffer;
import org.tensorflow.types.family.TType;

/* loaded from: input_file:org/tensorflow/Tensor.class */
/**
 * A typed, shaped value backed by a native {@code TF_Tensor} handle.
 *
 * <p>Each instance keeps its native handle attached to a {@link PointerScope}. {@link #close()}
 * closes that scope. Accessors that need the native handle go through {@link #nativeHandle()},
 * which throws {@link IllegalStateException} once the handle is null or has been released.
 *
 * @param <T> the element type family of this tensor
 */
public final class Tensor<T extends TType> implements AutoCloseable {
    /** Scope that owns {@link #tensorHandle}; closed by {@link #close()}. */
    private PointerScope tensorScope;
    /** Native tensor handle, assigned once by the factory methods. */
    private TF_Tensor tensorHandle;
    private final DataType<T> dtype;
    private final Shape shape;
    /** Typed view over the native memory, created lazily by {@link #data()}. */
    private T data = null;
    /** Native byte size, queried lazily by {@link #numBytes()}. */
    private Long numBytes = null;

    /**
     * Allocates a tensor sized as {@code shape.size() * dataType.byteSize()} bytes.
     *
     * <p>NOTE(review): for variable-length data types this computed size is presumably not
     * meaningful; callers likely need {@link #of(DataType, Shape, long)} — confirm.
     *
     * @param dataType element data type
     * @param shape tensor shape
     * @return a new tensor that the caller must close
     * @throws IllegalStateException if native allocation fails
     */
    public static <T extends TType> Tensor<T> of(DataType<T> dataType, Shape shape) {
        return of(dataType, shape, shape.size() * dataType.byteSize());
    }

    /**
     * Allocates a tensor with an explicit native buffer size.
     *
     * @param dataType element data type
     * @param shape tensor shape
     * @param size number of bytes to allocate for the native buffer
     * @return a new tensor that the caller must close
     * @throws IllegalArgumentException if {@code dataType} is fixed-length and {@code size} is
     *     smaller than {@code shape.size() * dataType.byteSize()}
     * @throws IllegalStateException if native allocation fails
     */
    public static <T extends TType> Tensor<T> of(DataType<T> dataType, Shape shape, long size) {
        // Only fixed-length types can be validated here: the byte requirement of variable-length
        // data cannot be derived from the shape alone.
        if (!dataType.isVariableLength() && shape.size() * dataType.byteSize() > size) {
            throw new IllegalArgumentException("Tensor size is not large enough to contain all scalar values");
        }
        Tensor<T> tensor = new Tensor<>(dataType, shape);
        TF_Tensor handle = allocate(dataType.nativeCode(), shape.asArray(), size);
        tensor.adopt(handle);
        return tensor;
    }

    /**
     * Allocates a tensor (sized as in {@link #of(DataType, Shape)}) and initializes its data.
     *
     * @param dataType element data type
     * @param shape tensor shape
     * @param dataInitializer receives {@link #data()} to fill in the values
     * @return a new tensor that the caller must close
     */
    public static <T extends TType> Tensor<T> of(DataType<T> dataType, Shape shape, Consumer<T> dataInitializer) {
        return of(dataType, shape, shape.size() * dataType.byteSize(), dataInitializer);
    }

    /**
     * Allocates a tensor with an explicit byte size and initializes its data.
     *
     * <p>If {@code dataInitializer} throws, the tensor is closed before the exception propagates.
     *
     * @param dataType element data type
     * @param shape tensor shape
     * @param size number of bytes to allocate for the native buffer
     * @param dataInitializer receives {@link #data()} to fill in the values
     * @return a new tensor that the caller must close
     */
    public static <T extends TType> Tensor<T> of(DataType<T> dataType, Shape shape, long size, Consumer<T> dataInitializer) {
        Tensor<T> tensor = of(dataType, shape, size);
        try {
            dataInitializer.accept(tensor.data());
            return tensor;
        } catch (Throwable t) {
            tensor.close();
            throw t;
        }
    }

    /**
     * Allocates a tensor of exactly {@code rawData.size()} bytes and copies {@code rawData} into it.
     *
     * <p>If the copy fails, the tensor is closed before the exception propagates, so native memory
     * is not leaked.
     *
     * @param dataType element data type
     * @param shape tensor shape
     * @param rawData bytes to copy into the tensor's native buffer
     * @return a new tensor that the caller must close
     */
    public static <T extends TType> Tensor<T> of(DataType<T> dataType, Shape shape, ByteDataBuffer rawData) {
        Tensor<T> tensor = of(dataType, shape, rawData.size());
        try {
            rawData.copyTo(TensorBuffers.toBytes(tensor.nativeHandle()), rawData.size());
            return tensor;
        } catch (Throwable t) {
            tensor.close();
            throw t;
        }
    }

    /**
     * Returns this tensor retyped to {@code dataType} if it matches this tensor's data type.
     *
     * @param dataType expected data type
     * @return this same instance, viewed as {@code Tensor<U>}
     * @throws IllegalArgumentException if {@code dataType} is not equal to this tensor's data type
     */
    @SuppressWarnings("unchecked")
    public <U extends TType> Tensor<U> expect(DataType<U> dataType) {
        if (!dataType.equals(this.dtype)) {
            throw new IllegalArgumentException("Cannot cast from tensor of " + this.dtype + " to tensor of " + dataType);
        }
        // Safe: the data types are equal, so T and U describe the same element type.
        return (Tensor<U>) this;
    }

    /** Closes the scope that owns this tensor's native handle. */
    @Override // java.lang.AutoCloseable
    public void close() {
        this.tensorScope.close();
    }

    /** Returns the data type of this tensor's elements. */
    public DataType<T> dataType() {
        return this.dtype;
    }

    /**
     * Returns the size in bytes of the native tensor buffer. The value is queried once and cached.
     *
     * @throws IllegalStateException if the size has not been cached yet and the tensor is closed
     */
    public long numBytes() {
        if (this.numBytes == null) {
            // Go through nativeHandle() so a closed tensor fails with IllegalStateException
            // instead of passing a released handle to native code.
            this.numBytes = tensorflow.TF_TensorByteSize(nativeHandle());
        }
        return this.numBytes;
    }

    /** Returns the shape of this tensor. */
    public Shape shape() {
        return this.shape;
    }

    /**
     * Returns a typed view over this tensor's data.
     *
     * <p>The view is created on the first call via {@code dtype.map(this)} and cached. Later calls
     * only verify, through {@link #nativeHandle()}, that the tensor has not been closed.
     *
     * @throws IllegalStateException on a later call if the tensor has been closed
     */
    public T data() {
        if (this.data == null) {
            this.data = this.dtype.map(this);
        } else {
            nativeHandle();
        }
        return this.data;
    }

    /**
     * Returns the tensor's raw native memory as a byte buffer.
     *
     * @throws IllegalStateException if the tensor has been closed
     */
    public ByteDataBuffer rawData() {
        return TensorBuffers.toBytes(nativeHandle(), true);
    }

    @Override
    public String toString() {
        return String.format("%s tensor with shape %s", this.dtype.toString(), this.shape);
    }

    /**
     * Wraps an existing native handle. The data type and shape are read from the handle, and the
     * handle becomes owned by the returned tensor's scope.
     *
     * <p>Intended for package-internal use (the decompiler notes it was originally package-private);
     * kept public for compatibility.
     *
     * @param handle native tensor handle
     * @return a tensor owning {@code handle}
     * @throws IllegalStateException if {@code handle} is null or released
     */
    public static Tensor<?> fromHandle(TF_Tensor handle) {
        Tensor<?> tensor = new Tensor<>(DataTypes.fromNativeCode(dtype(handle)), Shape.of(shape(handle)));
        tensor.adopt(handle);
        return tensor;
    }

    /**
     * Wraps an existing native handle whose ownership is handed to {@code session}.
     *
     * <p>The handle is attached to the session and then detached from the tensor's own scope.
     * Presumably this means the session, not {@link #close()}, releases it — confirm against
     * {@code EagerSession}.
     *
     * @param handle native tensor handle
     * @param session eager session that takes ownership of {@code handle}
     * @return a tensor wrapping {@code handle}
     */
    public static Tensor<?> fromHandle(TF_Tensor handle, EagerSession session) {
        Tensor<?> tensor = fromHandle(handle);
        session.attach(handle);
        tensor.tensorScope.detach(handle);
        return tensor;
    }

    /**
     * Returns the native handle of this tensor.
     *
     * <p>Intended for package-internal use (the decompiler notes it was originally
     * package-private).
     *
     * @throws IllegalStateException if the handle is null or has been released
     */
    public TF_Tensor nativeHandle() {
        return requireHandle(this.tensorHandle);
    }

    /**
     * Attaches {@code handle} to a new scope and stores both on this tensor.
     *
     * <p>Per JavaCPP {@code PointerScope.extend()}, the extended scope survives the
     * try-with-resources close below and is only released by {@link #close()}.
     */
    private void adopt(TF_Tensor handle) {
        try (PointerScope scope = new PointerScope()) {
            scope.attach(handle);
            this.tensorHandle = handle;
            this.tensorScope = scope.extend();
        }
    }

    /** Returns {@code handle}, or throws if it is null or points to released memory. */
    private static TF_Tensor requireHandle(TF_Tensor handle) {
        if (handle == null || handle.isNull()) {
            throw new IllegalStateException("close() was called on the Tensor");
        }
        return handle;
    }

    /** Allocates a native tensor, throwing if the allocation yields no memory. */
    private static TF_Tensor allocate(int nativeDataType, long[] dimensions, long size) {
        TF_Tensor handle = TF_Tensor.allocateTensor(nativeDataType, dimensions, size);
        if (handle == null || handle.isNull()) {
            throw new IllegalStateException("unable to allocate memory for the Tensor");
        }
        return handle;
    }

    /** Reads the native data type code of {@code handle}. */
    private static int dtype(TF_Tensor handle) {
        requireHandle(handle);
        return tensorflow.TF_TensorType(handle);
    }

    /** Reads the dimensions of {@code handle} into an array, one entry per dimension. */
    private static long[] shape(TF_Tensor handle) {
        requireHandle(handle);
        int rank = tensorflow.TF_NumDims(handle);
        long[] dimensions = new long[rank];
        for (int i = 0; i < rank; i++) {
            dimensions[i] = tensorflow.TF_Dim(handle, i);
        }
        return dimensions;
    }

    private Tensor(DataType<T> dataType, Shape shape) {
        this.dtype = dataType;
        this.shape = shape;
    }

    static {
        TensorFlow.init();
    }
}
