package org.nd4j.linalg.api.ops;

import java.util.Collections;
import java.util.List;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/api/ops/BaseScalarBoolOp.class */
/**
 * Base class for scalar ops that report {@link Op.Type#SCALAR_BOOL} as their op type and
 * {@link DataType#BOOL} as their output data type.
 *
 * <p>The scalar operand is held in the inherited {@code scalarValue} field. When that field is
 * unset, {@link #scalar()} falls back to {@code y()} provided {@code y()} is itself a scalar.
 */
public abstract class BaseScalarBoolOp extends BaseOp implements ScalarOp {
    private static final Logger log = LoggerFactory.getLogger(BaseScalarBoolOp.class);

    /** Creates an op with no arrays and no scalar value set. */
    public BaseScalarBoolOp() {
    }

    /**
     * Creates an op from three arrays and a scalar.
     *
     * @param iNDArray  first array, forwarded to the superclass; its data type is used to build
     *                  the scalar (presumably the {@code x} input - confirm against BaseOp)
     * @param iNDArray2 second array, forwarded to the superclass (presumably {@code y})
     * @param iNDArray3 third array, forwarded to the superclass (presumably {@code z})
     * @param number    scalar value, materialised with the data type of {@code iNDArray}
     */
    public BaseScalarBoolOp(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, Number number) {
        super(iNDArray, iNDArray2, iNDArray3);
        this.scalarValue = Nd4j.scalar(iNDArray.dataType(), number);
    }

    /**
     * Creates an op from a single array and a scalar.
     *
     * @param iNDArray array forwarded to the superclass; its data type is used to build the scalar
     * @param number   scalar value, materialised with the data type of {@code iNDArray}
     */
    public BaseScalarBoolOp(INDArray iNDArray, Number number) {
        super(iNDArray);
        this.scalarValue = Nd4j.scalar(iNDArray.dataType(), number);
    }

    /**
     * Creates an op from two arrays and a scalar; {@code null} is passed to the superclass as the
     * middle array argument.
     *
     * @param iNDArray  first array, forwarded to the superclass; its data type is used to build
     *                  the scalar
     * @param iNDArray2 array forwarded as the third superclass argument (presumably the result
     *                  array {@code z} - confirm against BaseOp)
     * @param number    scalar value, materialised with the data type of {@code iNDArray}
     */
    public BaseScalarBoolOp(INDArray iNDArray, INDArray iNDArray2, Number number) {
        super(iNDArray, (INDArray) null, iNDArray2);
        this.scalarValue = Nd4j.scalar(iNDArray.dataType(), number);
    }

    /**
     * Creates a SameDiff-backed op, passing {@code false} for the boolean flag and {@code null}
     * for the extra arguments.
     *
     * @param sameDiff   SameDiff instance the op is registered with
     * @param sDVariable input variable; must not be {@code null}
     * @param number     scalar value
     * @throws IllegalArgumentException if {@code sDVariable} is {@code null}
     */
    public BaseScalarBoolOp(SameDiff sameDiff, SDVariable sDVariable, Number number) {
        this(sameDiff, sDVariable, number, false, null);
    }

    /**
     * Creates a SameDiff-backed op with {@code null} extra arguments.
     *
     * @param sameDiff   SameDiff instance the op is registered with
     * @param sDVariable input variable; must not be {@code null}
     * @param number     scalar value
     * @param z          flag forwarded to the superclass constructor (presumably "in place" -
     *                   confirm against BaseOp)
     * @throws IllegalArgumentException if {@code sDVariable} is {@code null}
     */
    public BaseScalarBoolOp(SameDiff sameDiff, SDVariable sDVariable, Number number, boolean z) {
        this(sameDiff, sDVariable, number, z, null);
    }

    /**
     * Creates a SameDiff-backed op: builds the scalar with the variable's data type, records the
     * variable name as {@code xVertexId}, registers it as this op's argument on {@code sameDiff}
     * and, when the variable has a placeholder shape, registers a property to resolve.
     *
     * @param sameDiff   SameDiff instance the op is registered with
     * @param sDVariable input variable; must not be {@code null}
     * @param number     scalar value, materialised with the data type of {@code sDVariable}
     * @param z          flag forwarded to the superclass constructor (presumably "in place" -
     *                   confirm against BaseOp)
     * @param objArr     extra arguments forwarded to the superclass constructor; may be
     *                   {@code null} (the delegating constructors pass {@code null})
     * @throws IllegalArgumentException if {@code sDVariable} is {@code null}
     */
    public BaseScalarBoolOp(SameDiff sameDiff, SDVariable sDVariable, Number number, boolean z, Object[] objArr) {
        super(sameDiff, z, objArr);
        // Validate before the first dereference; otherwise a null variable surfaces as a bare
        // NullPointerException and this check can never fire.
        if (sDVariable == null) {
            throw new IllegalArgumentException("Input not null variable.");
        }
        this.scalarValue = Nd4j.scalar(sDVariable.dataType(), number);
        this.xVertexId = sDVariable.getVarName();
        sameDiff.addArgsFor(new String[]{this.xVertexId}, this);
        if (Shape.isPlaceholderShape(sDVariable.getShape())) {
            sameDiff.addPropertyToResolve(this, sDVariable.getVarName());
        }
        f().validateDifferentialFunctionsameDiff(sDVariable);
    }

    /**
     * Creates a SameDiff-backed op, passing {@code false} for the boolean flag.
     *
     * @param sameDiff   SameDiff instance the op is registered with
     * @param sDVariable input variable; must not be {@code null}
     * @param number     scalar value
     * @param objArr     extra arguments forwarded to the superclass constructor
     * @throws IllegalArgumentException if {@code sDVariable} is {@code null}
     */
    public BaseScalarBoolOp(SameDiff sameDiff, SDVariable sDVariable, Number number, Object[] objArr) {
        this(sameDiff, sDVariable, number, false, objArr);
    }

    /**
     * Returns the {@code z} field as-is.
     *
     * @return the current value of {@code this.z}, possibly {@code null}
     */
    @Override // org.nd4j.linalg.api.ops.BaseOp, org.nd4j.linalg.api.ops.Op
    public INDArray z() {
        return this.z;
    }

    /**
     * Computes the output shape from {@code x}.
     *
     * @return an empty list when {@code x} is {@code null}; otherwise a single descriptor with
     *         the shape of {@code x} and data type {@link DataType#BOOL}
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public List<LongShapeDescriptor> calculateOutputShape() {
        return this.x == null ? Collections.emptyList() : Collections.singletonList(LongShapeDescriptor.fromShape(this.x.shape(), DataType.BOOL));
    }

    /**
     * @return always {@link Op.Type#SCALAR_BOOL}
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Op.Type opType() {
        return Op.Type.SCALAR_BOOL;
    }

    /**
     * Replaces the scalar operand with a new scalar array built by {@code Nd4j.scalar(number)}.
     * Unlike the constructors, no explicit data type is passed here.
     *
     * @param number new scalar value
     */
    @Override // org.nd4j.linalg.api.ops.ScalarOp
    public void setScalar(Number number) {
        this.scalarValue = Nd4j.scalar(number);
    }

    /**
     * Replaces the scalar operand with the given array; the reference is stored without copying.
     *
     * @param iNDArray array to use as the scalar operand
     */
    @Override // org.nd4j.linalg.api.ops.ScalarOp
    public void setScalar(INDArray iNDArray) {
        this.scalarValue = iNDArray;
    }

    /**
     * Returns the scalar operand.
     *
     * @return {@code y()} when no scalar value is set and {@code y()} is a non-null scalar;
     *         otherwise the stored scalar value (which may be {@code null})
     */
    @Override // org.nd4j.linalg.api.ops.ScalarOp
    public INDArray scalar() {
        return (this.scalarValue == null && y() != null && y().isScalar()) ? y() : this.scalarValue;
    }

    /**
     * @return the {@code dimensions} field as-is (not copied)
     */
    @Override // org.nd4j.linalg.api.ops.ScalarOp
    public int[] getDimension() {
        return this.dimensions;
    }

    /**
     * Sets the dimensions by delegating to {@code defineDimensions}.
     *
     * @param iArr dimensions to apply
     */
    @Override // org.nd4j.linalg.api.ops.ScalarOp
    public void setDimension(int... iArr) {
        defineDimensions(iArr);
    }

    /**
     * Checks that the result array {@code z()} has a boolean data type.
     *
     * @param z not read by this implementation
     * @return {@code true} when the check passes; a failed check is reported by
     *         {@code Preconditions.checkArgument}
     */
    @Override // org.nd4j.linalg.api.ops.ScalarOp
    public boolean validateDataTypes(boolean z) {
        // The message must describe the condition actually tested (isB), not a float requirement.
        Preconditions.checkArgument(z().isB(), "Op.Z must have bool type, since this is a boolean scalar op. op.z.datatype=" + z().dataType());
        return true;
    }

    /**
     * @return always {@link Op.Type#SCALAR_BOOL}
     */
    @Override // org.nd4j.linalg.api.ops.ScalarOp
    public Op.Type getOpType() {
        return Op.Type.SCALAR_BOOL;
    }

    /**
     * Computes the output data types for the given input data types.
     *
     * @param list input data types; must be non-null and contain exactly one element
     * @return a single-element list containing {@link DataType#BOOL}
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public List<DataType> calculateOutputDataTypes(List<DataType> list) {
        Preconditions.checkState(list != null && list.size() == 1, "Expected exactly 1 input datatype for %s, got input %s", getClass(), list);
        return Collections.singletonList(DataType.BOOL);
    }
}
