package org.nd4j.linalg.lossfunctions;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/lossfunctions/LossUtil.class */
/**
 * Static helpers for back-propagating a loss gradient through a softmax activation.
 *
 * <p>All helpers implement {@code dL/dz = a * (dL/da - rowSum(a * dL/da))}, where {@code a} is the
 * softmax output. The sum is taken over dimension 1 and subtracted as a column vector. The arrays
 * are presumably laid out as {@code [examples, outputs]}; TODO confirm against callers.
 */
public class LossUtil {

    /**
     * Computes the gradient with respect to the pre-softmax values, starting from those values.
     *
     * <p>The softmax op is executed on a {@code dup()} of {@code preOut}, so the caller's
     * pre-activation array is not the one handed to the op. The resulting activations are then
     * passed to {@link #dLdZsoftmax(INDArray, INDArray)}.
     *
     * @param dlda   gradient of the loss with respect to the softmax output
     * @param preOut pre-softmax values (z)
     * @return dL/dz, as produced by {@link #dLdZsoftmax(INDArray, INDArray)}
     */
    public static INDArray dLdZsoftmaxPreOut(INDArray dlda, INDArray preOut) {
        // The cast pins the execAndReturn(TransformOp) overload.
        INDArray softmaxOut =
                Nd4j.getExecutioner().execAndReturn((TransformOp) new SoftMax(preOut.dup()));
        return dLdZsoftmax(dlda, softmaxOut);
    }

    /**
     * Computes the gradient with respect to the pre-softmax values, given the softmax output.
     *
     * <p>Uses the non-{@code i} ND4J operations ({@code mul}, {@code subColumnVector}), which
     * return new arrays rather than writing into {@code dlda} or {@code a}.
     *
     * @param dlda gradient of the loss with respect to the softmax output
     * @param a    softmax output (activations)
     * @return {@code a * (dlda - rowSum(a * dlda))}
     */
    public static INDArray dLdZsoftmax(INDArray dlda, INDArray a) {
        INDArray weightedSum = a.mul(dlda).sum(1);
        INDArray centered = dlda.subColumnVector(weightedSum);
        return a.mul(centered);
    }

    /**
     * In-place variant of {@link #dLdZsoftmax(INDArray, INDArray)}.
     *
     * <p>The row sum {@code a * dlda} is still computed into a temporary. After that,
     * {@code subiColumnVector} and {@code muli} (ND4J's {@code i}-suffixed, in-place operations)
     * write into {@code dlda} and {@code a} respectively. Callers should therefore not rely on
     * either input being unchanged afterwards.
     *
     * @param dlda gradient of the loss with respect to the softmax output; overwritten
     * @param a    softmax output (activations); overwritten with the result
     * @return the array returned by {@code a.muli(...)}, holding {@code a * (dlda - rowSum(a * dlda))}
     */
    public static INDArray dLdZsoftmaxi(INDArray dlda, INDArray a) {
        INDArray weightedSum = a.mul(dlda).sum(1);
        INDArray centered = dlda.subiColumnVector(weightedSum);
        return a.muli(centered);
    }
}
