package org.nd4j.linalg.lossfunctions.impl;

import java.util.List;
import java.util.Map;
import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossUtil;
import org.nd4j.linalg.lossfunctions.serde.RowVectorDeserializer;
import org.nd4j.linalg.lossfunctions.serde.RowVectorSerializer;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * Multi-class cross entropy loss function.
 *
 * <p>The per-element score computed by this class is {@code labels * log(activation(preOutput))},
 * optionally multiplied by a per-output weight row vector and by a mask; the reported score is the
 * negated sum of those values.
 *
 * <p>When the activation function is {@link ActivationSoftmax} and the clipping epsilon is positive,
 * activations are clipped to {@code [eps, 1 - eps]} before the logarithm is taken (score calculation
 * only; the gradient calculation does not clip).
 */
@JsonInclude(JsonInclude.Include.NON_NULL)
public class LossMCXENT extends DifferentialFunction implements ILossFunction {
    private static final double DEFAULT_SOFTMAX_CLIPPING_EPSILON = 1.0E-10d;

    /** Optional per-output weights (row vector); {@code null} means unweighted. */
    @JsonDeserialize(using = RowVectorDeserializer.class)
    @JsonSerialize(using = RowVectorSerializer.class)
    private INDArray weights;

    /** Clipping epsilon applied to softmax activations during score calculation; 0 disables clipping. */
    private double softmaxClipEps;

    /** Creates an unweighted loss function with the default softmax clipping epsilon. */
    public LossMCXENT() {
        this(null);
    }

    /**
     * Creates a loss function with the default softmax clipping epsilon.
     *
     * @param weights per-output weights as a row vector, or {@code null} for no weighting
     * @throws IllegalArgumentException if {@code weights} is non-null and not a row vector
     */
    public LossMCXENT(INDArray weights) {
        this(DEFAULT_SOFTMAX_CLIPPING_EPSILON, weights);
    }

    /**
     * Creates a loss function with an explicit clipping epsilon and optional weights.
     *
     * @param softmaxClipEps clipping epsilon; must lie in {@code [0, 0.5]}
     * @param weights        per-output weights as a row vector, or {@code null} for no weighting
     * @throws IllegalArgumentException if {@code weights} is non-null and not a row vector, or if
     *                                  {@code softmaxClipEps} is outside {@code [0, 0.5]}
     */
    public LossMCXENT(@JsonProperty("softmaxClipEps") double softmaxClipEps, @JsonProperty("weights") INDArray weights) {
        if (weights != null && !weights.isRowVector()) {
            throw new IllegalArgumentException("Weights array must be a row vector");
        }
        if (softmaxClipEps < 0.0d || softmaxClipEps > 0.5d) {
            throw new IllegalArgumentException("Invalid clipping epsilon: epsilon should be >= 0 (but near zero). Got: " + softmaxClipEps);
        }
        this.weights = weights;
        this.softmaxClipEps = softmaxClipEps;
    }

    /** Fails (via {@code Preconditions.throwEx}) when labels and pre-output differ in shape. */
    private static void checkEqualShapes(INDArray labels, INDArray preOutput) {
        if (!labels.equalShapes(preOutput)) {
            Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape());
        }
    }

    /** Throws if the weights vector length differs from the given size of dimension 1 of the output. */
    private void checkWeightsLength(long outputSize) {
        if (this.weights.length() != outputSize) {
            throw new IllegalStateException("Weights vector (length " + this.weights.length() + ") does not match output.size(1)=" + outputSize);
        }
    }

    /**
     * Computes the (non-negated) element-wise score array {@code labels * log(activation)}, with
     * weights and mask applied when present.
     */
    private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        checkEqualShapes(labels, preOutput);

        // Work on a copy so the caller's pre-output array is left untouched.
        INDArray output = activationFn.getActivation(preOutput.dup(), true);
        if ((activationFn instanceof ActivationSoftmax) && this.softmaxClipEps > 0.0d) {
            // Clip to [eps, 1 - eps] so that the logarithm below stays finite.
            BooleanIndexing.replaceWhere(output, this.softmaxClipEps, Conditions.lessThan(this.softmaxClipEps));
            BooleanIndexing.replaceWhere(output, 1.0d - this.softmaxClipEps, Conditions.greaterThan(1.0d - this.softmaxClipEps));
        }

        INDArray scoreArr = Transforms.log(output, false).muli(labels);
        if (this.weights != null) {
            checkWeightsLength(scoreArr.size(1));
            scoreArr.muliRowVector(this.weights);
        }
        if (mask != null) {
            LossUtil.applyMask(scoreArr, mask);
        }
        return scoreArr;
    }

    /**
     * Computes the total score: the negated sum of the element-wise score array.
     *
     * @param labels       label array; must have the same shape as {@code preOutput}
     * @param preOutput    pre-activation output array
     * @param activationFn activation function applied to {@code preOutput}
     * @param mask         optional mask array; may be {@code null}
     * @param average      if {@code true}, the score is divided by the size of dimension 0
     * @return the (optionally averaged) score
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        // Keep a reference to the score array: its first dimension is needed for averaging.
        INDArray scoreArr = scoreArray(labels, preOutput, activationFn, mask);
        double score = -scoreArr.sumNumber().doubleValue();
        if (average) {
            score /= scoreArr.size(0);
        }
        return score;
    }

    /**
     * Computes the score per row: the element-wise score array summed along dimension 1 and negated.
     *
     * @param labels       label array; must have the same shape as {@code preOutput}
     * @param preOutput    pre-activation output array
     * @param activationFn activation function applied to {@code preOutput}
     * @param mask         optional mask array; may be {@code null}
     * @return array of negated sums along dimension 1
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        INDArray scoreArr = scoreArray(labels, preOutput, activationFn, mask);
        return scoreArr.sum(1).muli(-1);
    }

    /**
     * Computes the gradient of the loss with respect to the pre-activation output.
     *
     * <p>For {@link ActivationSoftmax} the gradient is computed directly as
     * {@code activation - labels} (or its weighted counterpart); for any other activation function it
     * is obtained by back-propagating {@code -labels / activation} through the activation function.
     *
     * @param labels       label array; must have the same shape as {@code preOutput}
     * @param preOutput    pre-activation output array
     * @param activationFn activation function applied to {@code preOutput}
     * @param mask         optional mask array; may be {@code null}
     * @return the gradient array, with the mask applied when one is given
     * @throws UnsupportedOperationException if per-output masking is used together with softmax
     * @throws IllegalStateException         if the weights length does not match the output size
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        checkEqualShapes(labels, preOutput);

        INDArray output = activationFn.getActivation(preOutput.dup(), true);
        INDArray grad;
        if (activationFn instanceof ActivationSoftmax) {
            if (mask != null && LossUtil.isPerOutputMasking(output, mask)) {
                throw new UnsupportedOperationException("Per output masking for MCXENT + softmax: not supported");
            }
            if (this.weights == null) {
                grad = output.subi(labels);
            } else {
                checkWeightsLength(output.size(1));
                INDArray weightedLabels = labels.mulRowVector(this.weights);
                grad = output.mulColumnVector(weightedLabels.sum(1)).sub(weightedLabels);
            }
        } else {
            // dL/da = -labels / activation, computed in place on the activation copy.
            INDArray dLda = output.rdivi(labels).negi();
            grad = activationFn.backprop(preOutput, dLda).getFirst();
            if (this.weights != null) {
                checkWeightsLength(output.size(1));
                grad.muliRowVector(this.weights);
            }
        }

        if (mask != null) {
            LossUtil.applyMask(grad, mask);
        }
        return grad;
    }

    /**
     * Computes both the score and the gradient by delegating to
     * {@link #computeScore} and {@link #computeGradient}.
     *
     * @return a pair of (score, gradient)
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        double score = computeScore(labels, preOutput, activationFn, mask, average);
        INDArray gradient = computeGradient(labels, preOutput, activationFn, mask);
        return new Pair<>(score, gradient);
    }

    /** Returns the name of this loss function, which is the same as {@link #toString()}. */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public String name() {
        return toString();
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public String toString() {
        if (this.weights == null) {
            return "LossMCXENT()";
        }
        return "LossMCXENT(weights=" + this.weights + ")";
    }

    /** Returns an empty array; this class defines no output variables. */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public SDVariable[] outputVariables() {
        return new SDVariable[0];
    }

    /** Returns an empty array regardless of {@code baseName}. */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public SDVariable[] outputVariables(String baseName) {
        return new SDVariable[0];
    }

    /** Not implemented: always returns {@code null}. */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public List<SDVariable> doDiff(List<SDVariable> gradients) {
        return null;
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public String opName() {
        return "lossmcxent";
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Op.Type opType() {
        return Op.Type.CUSTOM;
    }

    /** No-op: nothing is read from the TensorFlow node definition. */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    }

    /** No-op: nothing is read from the ONNX node definition. */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public String onnxName() {
        return "SoftmaxCrossEntropyWithLogits";
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public String tensorflowName() {
        return "SoftmaxCrossEntropyWithLogits";
    }

    /** Two instances are equal when their weights and softmax clipping epsilons are equal. */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof LossMCXENT)) {
            return false;
        }
        LossMCXENT other = (LossMCXENT) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        INDArray thisWeights = getWeights();
        INDArray otherWeights = other.getWeights();
        boolean weightsEqual = (thisWeights == null) ? (otherWeights == null) : thisWeights.equals(otherWeights);
        if (!weightsEqual) {
            return false;
        }
        return Double.compare(getSoftmaxClipEps(), other.getSoftmaxClipEps()) == 0;
    }

    /** Hook that lets subclasses refuse equality with instances of other types. */
    protected boolean canEqual(Object obj) {
        return obj instanceof LossMCXENT;
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public int hashCode() {
        INDArray w = getWeights();
        int result = 59 + (w == null ? 43 : w.hashCode());
        // Fold the 64-bit representation of the epsilon into 32 bits.
        long epsBits = Double.doubleToLongBits(getSoftmaxClipEps());
        return (result * 59) + ((int) ((epsBits >>> 32) ^ epsBits));
    }

    /** Returns the per-output weights, or {@code null} if none are set. */
    public INDArray getWeights() {
        return this.weights;
    }

    /** Returns the softmax clipping epsilon. */
    public double getSoftmaxClipEps() {
        return this.softmaxClipEps;
    }

    /** Sets the per-output weights; no validation is performed here. */
    public void setWeights(INDArray weights) {
        this.weights = weights;
    }

    /** Sets the softmax clipping epsilon; no validation is performed here. */
    public void setSoftmaxClipEps(double softmaxClipEps) {
        this.softmaxClipEps = softmaxClipEps;
    }
}
