package org.nd4j.linalg.api.ops;

import java.util.ArrayList;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.util.SameDiffUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/api/ops/BaseScalarOp.class */
/**
 * Base class for ops of type {@link Op.Type#SCALAR}: ops that combine an input array (or SameDiff
 * variable) with a single scalar operand.
 *
 * <p>The scalar operand is held in {@code scalarValue}. If a {@code y} array is attached and is
 * itself a scalar, {@link #scalar()} returns {@code y} instead of {@code scalarValue}.
 */
public abstract class BaseScalarOp extends BaseOp implements ScalarOp {
    private static final Logger log = LoggerFactory.getLogger(BaseScalarOp.class);

    /**
     * Creates an op whose scalar operand is {@code 0.0f}. Presumably used by reflection /
     * deserialization before the real arguments are configured — TODO confirm.
     */
    public BaseScalarOp() {
        this.scalarValue = Nd4j.scalar(0.0f);
    }

    /**
     * Creates an array-based scalar op.
     *
     * @param x   input array; decompressed in place if compressed
     * @param y   optional second operand (see {@link #scalar()})
     * @param z   output array
     * @param num scalar operand, stored with the data type of {@code x}
     */
    public BaseScalarOp(INDArray x, INDArray y, INDArray z, Number num) {
        super(x, y, z);
        initScalarFromArray(x, num);
    }

    /**
     * Creates an array-based scalar op with only an input array.
     *
     * @param x   input array; decompressed in place if compressed
     * @param num scalar operand, stored with the data type of {@code x}
     */
    public BaseScalarOp(INDArray x, Number num) {
        super(x);
        initScalarFromArray(x, num);
    }

    /**
     * Creates an array-based scalar op with an explicit output array and no {@code y} operand.
     *
     * @param x   input array; decompressed in place if compressed
     * @param z   output array
     * @param num scalar operand, stored with the data type of {@code x}
     */
    public BaseScalarOp(INDArray x, INDArray z, Number num) {
        super(x, (INDArray) null, z);
        initScalarFromArray(x, num);
    }

    public BaseScalarOp(SameDiff sameDiff, SDVariable i_v, Number scalar) {
        this(sameDiff, i_v, scalar, false, null);
    }

    public BaseScalarOp(SameDiff sameDiff, SDVariable i_v, Number scalar, boolean inPlace) {
        this(sameDiff, i_v, scalar, inPlace, null);
    }

    /**
     * Creates a SameDiff-based scalar op and registers {@code i_v} as its single input.
     *
     * @param sameDiff  owning SameDiff instance
     * @param i_v       input variable; must not be null
     * @param scalar    scalar operand, stored with the data type of {@code i_v}
     * @param inPlace   forwarded to the superclass constructor
     * @param extraArgs forwarded to the superclass constructor; may be null
     * @throws NullPointerException if {@code i_v} is null
     */
    public BaseScalarOp(SameDiff sameDiff, @NonNull SDVariable i_v, Number scalar, boolean inPlace, Object[] extraArgs) {
        super(sameDiff, inPlace, extraArgs);
        if (i_v == null) {
            throw new NullPointerException("i_v is marked non-null but is null");
        }
        this.scalarValue = Nd4j.scalar(i_v.dataType(), scalar);
        this.xVertexId = i_v.name();
        sameDiff.addArgsFor(new String[]{this.xVertexId}, this);
        SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this);
    }

    public BaseScalarOp(SameDiff sameDiff, SDVariable i_v, Number scalar, Object[] extraArgs) {
        this(sameDiff, i_v, scalar, false, extraArgs);
    }

    /**
     * Shared array-constructor logic. Decompresses {@code input} in place if needed, then creates
     * {@code scalarValue} with {@code input}'s data type while workspaces are scoped out.
     */
    private void initScalarFromArray(INDArray input, Number num) {
        if (input.isCompressed()) {
            Nd4j.getCompressor().decompressi(input);
        }
        // Allocate the scalar outside any active workspace; presumably so it is not bound to the
        // lifetime of whatever workspace the caller happens to be in.
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            this.scalarValue = Nd4j.scalar(input.dataType(), num);
        }
    }

    @Override
    public INDArray z() {
        return this.z;
    }

    @Override
    public List<LongShapeDescriptor> calculateOutputShape() {
        return calculateOutputShape(null);
    }

    /**
     * Computes the single output shape: same shape as the input, with the data type chosen by
     * {@link Shape#pickPairwiseDataType} from the input type and the scalar's type.
     *
     * <p>When a concrete input array is available (from {@code opContext}, or {@code x()} if no
     * context is given), both its shape and data type are used. Only when it is absent does this
     * fall back to the SameDiff argument, so array-only ops don't need SameDiff args.
     *
     * @param opContext optional execution context; if non-null its input 0 is used
     * @return a one-element list with the output shape descriptor
     */
    @Override
    public List<LongShapeDescriptor> calculateOutputShape(OpContext opContext) {
        INDArray input = opContext != null ? opContext.getInputArray(0) : x();
        long[] shape;
        DataType inputType;
        if (input != null) {
            shape = input.shape();
            inputType = input.dataType();
        } else {
            SDVariable arg = arg();
            shape = arg.getShape();
            inputType = arg.dataType();
        }
        List<LongShapeDescriptor> ret = new ArrayList<>(1);
        ret.add(LongShapeDescriptor.fromShape(shape, Shape.pickPairwiseDataType(inputType, this.scalarValue.dataType())));
        return ret;
    }

    @Override
    public Op.Type opType() {
        return Op.Type.SCALAR;
    }

    /**
     * Replaces the scalar operand with {@code scalar}.
     *
     * <p>The new value uses the input array's data type when an input array is attached. Otherwise
     * it keeps the current scalar's data type; SameDiff-constructed ops never set {@code x}.
     *
     * @throws IllegalStateException if neither an input array nor a current scalar is available
     */
    @Override
    public void setScalar(Number scalar) {
        Preconditions.checkState(this.x != null || this.scalarValue != null,
                "Cannot set scalar: neither input array x nor a current scalar value is available to determine the data type");
        DataType dataType = this.x != null ? this.x.dataType() : this.scalarValue.dataType();
        this.scalarValue = Nd4j.scalar(dataType, scalar);
    }

    /** Replaces the scalar operand with the given array as-is (no type conversion). */
    @Override
    public void setScalar(INDArray scalar) {
        this.scalarValue = scalar;
    }

    /**
     * Returns the effective scalar operand: {@code y()} if it is non-null and a scalar array,
     * otherwise the stored scalar value.
     */
    @Override
    public INDArray scalar() {
        INDArray yArr = y();
        return yArr != null && yArr.isScalar() ? yArr : this.scalarValue;
    }

    @Override
    public int[] getDimension() {
        return this.dimensions;
    }

    /** Delegates to {@code defineDimensions}. */
    @Override
    public void setDimension(int... dimension) {
        defineDimensions(dimension);
    }

    /**
     * Validates operand/output data types.
     *
     * <ul>
     *   <li>If any operand is floating point, {@code z} must be floating point.</li>
     *   <li>Unless {@code experimentalMode} is set, a non-null {@code y} must have the same type as
     *       {@code x} or be {@code BOOL}.</li>
     * </ul>
     *
     * @param experimentalMode when true, skips the x/y same-type check
     * @return always {@code true}; violations are reported by throwing
     * @throws IllegalArgumentException on a data type violation
     */
    @Override
    public boolean validateDataTypes(boolean experimentalMode) {
        if (y() != null) {
            if (y().isR() || x().isR()) {
                Preconditions.checkArgument(z().isR(), "Op.Z must have floating point type, since one of operands is floating point: x.dataType=%s, y.dataType=%s, z.dataType=%s, op=%s", this.x.dataType(), this.y.dataType(), this.z.dataType(), getClass().getName());
            }
            if (!experimentalMode) {
                Preconditions.checkArgument(this.x.dataType() == this.y.dataType() || this.y.dataType() == DataType.BOOL, "Op.X must have same data type as Op.Y");
            }
        } else if (x().isR()) {
            Preconditions.checkArgument(z().isR(), "Op.Z must have floating point type, since one of operands is floating point: x.dataType=%s, z.dataType=%s, op=%s", this.x.dataType(), this.z.dataType(), getClass().getName());
        }
        return true;
    }

    @Override
    public Op.Type getOpType() {
        return Op.Type.SCALAR;
    }

    /**
     * Output data type equals the single input data type. The input list is returned as-is.
     *
     * @throws IllegalStateException if {@code list} is null or does not contain exactly one entry
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> list) {
        Preconditions.checkState(list != null && list.size() == 1, "Expected exactly 1 input datatype %s, got input %s", getClass(), list);
        return list;
    }
}
