package org.nd4j.linalg.lossfunctions.impl;

import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/nd4j/linalg/lossfunctions/impl/LossPoisson.class */
/**
 * Poisson loss function.
 *
 * <p>With {@code out = activation(preOutput)}, the element-wise score computed here is
 * {@code out - labels * log(out)}, and the element-wise gradient is
 * {@code (1 - labels / out) * activation'(preOutput)}. When a non-null mask is supplied,
 * both are multiplied by it as a column vector.
 *
 * <p>The class holds no state, so all instances compare equal.
 */
public class LossPoisson implements ILossFunction {

    /**
     * Computes the element-wise Poisson score {@code out - labels * log(out)}.
     *
     * @param labels       target values; multiplied element-wise into {@code log(out)}
     * @param preOutput    input to the activation function; not modified (a copy is transformed)
     * @param activationFn name of the transform op used as the activation function
     * @param mask         optional column vector multiplied into the result; may be {@code null}
     * @return a newly allocated array of element-wise scores
     */
    public INDArray scoreArray(INDArray labels, INDArray preOutput, String activationFn, INDArray mask) {
        INDArray output = applyActivation(activationFn, preOutput);
        // NOTE(review): assumes Transforms.log returns a fresh array rather than writing into
        // 'output' (the original code relies on the same thing) -- confirm.
        INDArray labelsTimesLogOutput = Transforms.log(output);
        labelsTimesLogOutput.muli(labels);
        INDArray scores = output.sub(labelsTimesLogOutput);
        if (mask != null) {
            scores.muliColumnVector(mask);
        }
        return scores;
    }

    /**
     * Computes the total score as the sum of all elements of {@link #scoreArray}.
     *
     * @param labels       target values
     * @param preOutput    input to the activation function
     * @param activationFn name of the transform op used as the activation function
     * @param mask         optional column-vector multiplier; may be {@code null}
     * @param average      if {@code true}, the sum is divided by the size of dimension 0
     *                     of the score array
     * @return the summed (optionally averaged) score
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public double computeScore(INDArray labels, INDArray preOutput, String activationFn, INDArray mask, boolean average) {
        // Keep a reference to the score array: its first dimension is needed for averaging.
        INDArray scores = scoreArray(labels, preOutput, activationFn, mask);
        double score = scores.sumNumber().doubleValue();
        if (average) {
            score /= scores.size(0);
        }
        return score;
    }

    /**
     * Computes the element-wise scores and sums them along dimension 1.
     *
     * @param labels       target values
     * @param preOutput    input to the activation function
     * @param activationFn name of the transform op used as the activation function
     * @param mask         optional column-vector multiplier; may be {@code null}
     * @return the score array reduced by summation along dimension 1
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, String activationFn, INDArray mask) {
        return scoreArray(labels, preOutput, activationFn, mask).sum(1);
    }

    /**
     * Computes the element-wise gradient {@code (1 - labels / out) * activation'(preOutput)}.
     *
     * @param labels       target values
     * @param preOutput    input to the activation function; not modified (copies are transformed)
     * @param activationFn name of the transform op used as the activation function
     * @param mask         optional column vector multiplied into the result; may be {@code null}
     * @return a newly allocated gradient array
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public INDArray computeGradient(INDArray labels, INDArray preOutput, String activationFn, INDArray mask) {
        // Derivative op of the activation, executed on its own copy of preOutput.
        INDArray activationDerivative = Nd4j.getExecutioner().execAndReturn(
                Nd4j.getOpFactory().createTransform(activationFn, preOutput.dup()).derivative());
        INDArray output = applyActivation(activationFn, preOutput);

        // div() allocates the result, so the in-place operations below never touch 'labels'.
        INDArray gradient = labels.div(output).muli(-1);
        gradient.addi(1);
        gradient.muli(activationDerivative);
        if (mask != null) {
            gradient.muliColumnVector(mask);
        }
        return gradient;
    }

    /**
     * Computes both the score and the gradient by delegating to {@link #computeScore} and
     * {@link #computeGradient}.
     *
     * @param labels       target values
     * @param preOutput    input to the activation function
     * @param activationFn name of the transform op used as the activation function
     * @param mask         optional column-vector multiplier; may be {@code null}
     * @param average      whether the score is divided by the size of dimension 0
     * @return a pair of (score, gradient)
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, String activationFn, INDArray mask, boolean average) {
        return new Pair<>(
                Double.valueOf(computeScore(labels, preOutput, activationFn, mask, average)),
                computeGradient(labels, preOutput, activationFn, mask));
    }

    /** Applies the named transform op to a copy of {@code preOutput} and returns the result. */
    private static INDArray applyActivation(String activationFn, INDArray preOutput) {
        return Nd4j.getExecutioner().execAndReturn(
                Nd4j.getOpFactory().createTransform(activationFn, preOutput.dup()));
    }

    @Override
    public String toString() {
        return "LossPoisson()";
    }

    /**
     * Two objects are equal when both are {@code LossPoisson} instances and the other
     * object's {@link #canEqual} accepts this one; there is no state to compare.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        return (obj instanceof LossPoisson) && ((LossPoisson) obj).canEqual(this);
    }

    /** Hook that lets subclasses refuse equality with instances of other types. */
    protected boolean canEqual(Object obj) {
        return obj instanceof LossPoisson;
    }

    /** Constant hash, consistent with {@link #equals}: every instance is equal to every other. */
    @Override
    public int hashCode() {
        return 1;
    }
}
