package org.nd4j.linalg.lossfunctions.impl;

import org.nd4j.base.Preconditions;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossUtil;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/nd4j/linalg/lossfunctions/impl/LossWasserstein.class */
/**
 * Wasserstein loss.
 *
 * <p>The unreduced, element-wise score is {@code labels * activationFn(preOutput)}. Each row's
 * score is the mean of that product across dimension 1. The total score is the sum of the row
 * scores, optionally divided by {@code size(0)}.
 *
 * <p>The gradient with respect to the activations is {@code labels / size(1)}, which is then
 * back-propagated through {@code activationFn}.
 */
public class LossWasserstein implements ILossFunction {

    /**
     * Fails (via {@link Preconditions#throwEx}) when {@code labels} and {@code preOutput} differ in
     * shape. Every method here combines the two arrays element-wise.
     */
    private static void checkShapes(INDArray labels, INDArray preOutput) {
        if (!labels.equalShapes(preOutput)) {
            Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s",
                    new Object[]{labels.shape(), preOutput.shape()});
        }
    }

    /**
     * Computes the unreduced score {@code labels * activationFn(preOutput)}.
     *
     * @param labels       target values; cast to {@code preOutput}'s data type before multiplying
     * @param preOutput    pre-activation values; duplicated before the activation is applied
     * @param activationFn activation applied to {@code preOutput}
     * @param mask         optional mask (may be {@code null}); applied in place to the result
     * @return a new array with the same shape as {@code labels}
     */
    private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        checkShapes(labels, preOutput);
        // dup() so the activation cannot touch the caller's preOutput
        // (presumably getActivation may work in place - TODO confirm).
        INDArray activations = activationFn.getActivation(preOutput.dup(), true);
        INDArray scoreArr = labels.castTo(preOutput.dataType()).mul(activations);
        if (mask != null) {
            LossUtil.applyMask(scoreArr, mask);
        }
        return scoreArr;
    }

    /**
     * Computes the scalar score.
     *
     * @param average if {@code true}, the summed score is divided by {@code size(0)}
     * @return the sum over rows of the per-row mean (across dimension 1) of the element-wise score
     */
    @Override
    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask,
                               boolean average) {
        INDArray scoreArr = scoreArray(labels, preOutput, activationFn, mask);
        double score = scoreArr.mean(1).sumNumber().doubleValue();
        if (average) {
            // Normalise by the number of rows along dimension 0 of the score array.
            score /= scoreArr.size(0);
        }
        return score;
    }

    /**
     * Computes one score per row: the mean across dimension 1 of the element-wise score. The
     * result gets a trailing singleton dimension via {@code expandDims(..., 1)}.
     */
    @Override
    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        INDArray perRow = scoreArray(labels, preOutput, activationFn, mask).mean(1);
        return Nd4j.expandDims(perRow, 1);
    }

    /**
     * Computes the gradient of the score with respect to {@code preOutput}.
     *
     * <p>The score is {@code mean_dim1(labels * a)}, so its derivative with respect to the
     * activations {@code a} is {@code labels / size(1)}. That derivative is back-propagated through
     * {@code activationFn}, and the mask (if any) is applied to the result.
     */
    @Override
    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        checkShapes(labels, preOutput);
        INDArray castLabels = labels.castTo(preOutput.dataType());
        INDArray dLda = castLabels.div(castLabels.size(1));

        if (mask != null && LossUtil.isPerOutputMasking(dLda, mask)) {
            // Mask dL/da itself, not the labels. dLda is already a separate array, so masking the
            // labels would not affect the gradient. It could also modify the caller's labels in place
            // if castTo returns the same array (presumably when the dtype already matches).
            // Pre-masking dL/da matters for activations (e.g. softmax) where dL/dz_i depends on dL/da_j.
            LossUtil.applyMask(dLda, mask);
        }

        INDArray grad = activationFn.backprop(preOutput, dLda).getFirst();
        if (mask != null) {
            LossUtil.applyMask(grad, mask);
        }
        return grad;
    }

    /**
     * Computes the score and the gradient together. The score is computed first, then the
     * gradient.
     */
    @Override
    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput,
                                                          IActivation activationFn, INDArray mask,
                                                          boolean average) {
        double score = computeScore(labels, preOutput, activationFn, mask, average);
        INDArray gradient = computeGradient(labels, preOutput, activationFn, mask);
        return new Pair<>(score, gradient);
    }

    /** Returns the loss function's name, identical to {@link #toString()}. */
    @Override
    public String name() {
        return toString();
    }

    @Override
    public String toString() {
        return "LossWasserstein()";
    }

    /**
     * All instances are equal. There is no state to compare, but {@link #canEqual} still lets a
     * subclass opt out.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        return (obj instanceof LossWasserstein) && ((LossWasserstein) obj).canEqual(this);
    }

    /** Symmetry hook for {@link #equals}: subclasses may override to refuse equality. */
    protected boolean canEqual(Object obj) {
        return obj instanceof LossWasserstein;
    }

    /** Constant hash, consistent with all instances being equal. */
    @Override
    public int hashCode() {
        return 1;
    }
}
