package org.nd4j.evaluation.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.curves.Histogram;
import org.nd4j.evaluation.curves.ReliabilityDiagram;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.lossfunctions.LossUtil;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Triple;
import org.nd4j.serde.jackson.shaded.NDArrayDeSerializer;
import org.nd4j.serde.jackson.shaded.NDArraySerializer;
import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

/* loaded from: input_file:org/nd4j/evaluation/classification/EvaluationCalibration.class */
/**
 * Accumulates calibration statistics for a classifier across one or more calls to
 * {@link #eval(INDArray, INDArray, INDArray)}:
 * <ul>
 *   <li>reliability-diagram bins (per class: number of predictions in the bin, number of those with a
 *       positive label, and sum of predicted values),</li>
 *   <li>per-class label counts and per-class "is max" prediction counts,</li>
 *   <li>histograms of the absolute residual {@code |label - prediction|} and of the predicted values,
 *       both overall and weighted by label class.</li>
 * </ul>
 * All accumulators are allocated lazily on the first {@code eval} call, once the number of classes
 * (size of dimension 1 of the extracted labels) is known.
 */
public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration> {
    public static final int DEFAULT_RELIABILITY_DIAG_NUM_BINS = 10;
    public static final int DEFAULT_HISTOGRAM_NUM_BINS = 50;

    /** Number of equal-width bins over [0, 1] used for the reliability diagram. */
    private final int reliabilityDiagNumBins;
    /** Number of equal-width bins over [0, 1] used for the residual and probability histograms. */
    private final int histogramNumBins;
    /** When true, bins that received no predictions are dropped from reliability diagrams. */
    private final boolean excludeEmptyBins;
    /** Axis argument forwarded to {@code BaseEvaluation.reshapeAndExtractNotMasked}; defaults to 1. */
    protected int axis;

    /** Shape [reliabilityDiagNumBins, numClasses]: per bin/class sum of labels for predictions in the bin. */
    @JsonDeserialize(using = NDArrayDeSerializer.class)
    @JsonSerialize(using = NDArraySerializer.class)
    private INDArray rDiagBinPosCount;

    /** Shape [reliabilityDiagNumBins, numClasses]: per bin/class number of predictions in the bin. */
    @JsonDeserialize(using = NDArrayDeSerializer.class)
    @JsonSerialize(using = NDArraySerializer.class)
    private INDArray rDiagBinTotalCount;

    /** Shape [reliabilityDiagNumBins, numClasses]: per bin/class sum of the predicted values in the bin. */
    @JsonDeserialize(using = NDArrayDeSerializer.class)
    @JsonSerialize(using = NDArraySerializer.class)
    private INDArray rDiagBinSumPredictions;

    /** Shape [1, numClasses]: column sums of the labels. */
    @JsonDeserialize(using = NDArrayTextDeSerializer.class)
    @JsonSerialize(using = NDArrayTextSerializer.class)
    private INDArray labelCountsEachClass;

    /** Shape [1, numClasses]: column sums of the IsMax indicator of the predictions. */
    @JsonDeserialize(using = NDArrayTextDeSerializer.class)
    @JsonSerialize(using = NDArrayTextSerializer.class)
    private INDArray predictionCountsEachClass;

    /** Shape [1, histogramNumBins]: histogram of |label - prediction| over all entries. */
    @JsonDeserialize(using = NDArrayTextDeSerializer.class)
    @JsonSerialize(using = NDArrayTextSerializer.class)
    private INDArray residualPlotOverall;

    /** Shape [histogramNumBins, numClasses]: residual histogram weighted by the labels. */
    @JsonDeserialize(using = NDArrayDeSerializer.class)
    @JsonSerialize(using = NDArraySerializer.class)
    private INDArray residualPlotByLabelClass;

    /** Shape [1, histogramNumBins]: histogram of predicted values over all entries. */
    @JsonDeserialize(using = NDArrayTextDeSerializer.class)
    @JsonSerialize(using = NDArrayTextSerializer.class)
    private INDArray probHistogramOverall;

    /** Shape [histogramNumBins, numClasses]: predicted-value histogram weighted by the labels. */
    @JsonDeserialize(using = NDArrayDeSerializer.class)
    @JsonSerialize(using = NDArraySerializer.class)
    private INDArray probHistogramByLabelClass;

    /** Creates an instance with the default bin counts and empty reliability bins excluded. */
    public EvaluationCalibration() {
        this(DEFAULT_RELIABILITY_DIAG_NUM_BINS, DEFAULT_HISTOGRAM_NUM_BINS, true);
    }

    /**
     * Creates an instance with empty reliability bins excluded.
     *
     * @param i  number of reliability-diagram bins
     * @param i2 number of histogram bins
     */
    public EvaluationCalibration(int i, int i2) {
        this(i, i2, true);
    }

    /**
     * @param i  number of reliability-diagram bins
     * @param i2 number of histogram bins
     * @param z  whether reliability diagrams should omit bins that received no predictions
     */
    public EvaluationCalibration(@JsonProperty("reliabilityDiagNumBins") int i, @JsonProperty("histogramNumBins") int i2, @JsonProperty("excludeEmptyBins") boolean z) {
        this.axis = 1;
        this.reliabilityDiagNumBins = i;
        this.histogramNumBins = i2;
        this.excludeEmptyBins = z;
    }

    /** Sets the axis forwarded to {@code reshapeAndExtractNotMasked} on subsequent evaluations. */
    public void setAxis(int i) {
        this.axis = i;
    }

    /** @return the axis forwarded to {@code reshapeAndExtractNotMasked} */
    public int getAxis() {
        return this.axis;
    }

    /**
     * Accumulates statistics for one batch.
     *
     * @param iNDArray  first array handed to {@code reshapeAndExtractNotMasked}; its extracted form is
     *                  used as the labels
     * @param iNDArray2 second array; its extracted form is used as the predictions that get binned
     * @param iNDArray3 optional mask, consumed by {@code reshapeAndExtractNotMasked}; a mask that
     *                  survives extraction (per-output mask) is rejected
     */
    @Override // org.nd4j.evaluation.BaseEvaluation, org.nd4j.evaluation.IEvaluation
    public void eval(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3) {
        Triple<INDArray, INDArray, INDArray> extracted = BaseEvaluation.reshapeAndExtractNotMasked(iNDArray, iNDArray2, iNDArray3, this.axis);
        if (extracted == null) {
            // Nothing to evaluate for this batch.
            return;
        }
        INDArray labels = (INDArray) extracted.getFirst();
        INDArray predictions = (INDArray) extracted.getSecond();
        INDArray remainingMask = (INDArray) extracted.getThird();
        // From here on the mask is known to be null, so no mask handling is needed below.
        Preconditions.checkState(remainingMask == null, "Per-output masking for EvaluationCalibration is not supported");

        long size = labels.size(1);
        if (this.rDiagBinPosCount == null) {
            // First batch (or first after reset): allocate every accumulator together.
            DataType dataType = DataType.DOUBLE;
            this.rDiagBinPosCount = Nd4j.create(DataType.LONG, this.reliabilityDiagNumBins, size);
            this.rDiagBinTotalCount = Nd4j.create(DataType.LONG, this.reliabilityDiagNumBins, size);
            this.rDiagBinSumPredictions = Nd4j.create(dataType, this.reliabilityDiagNumBins, size);
            this.labelCountsEachClass = Nd4j.create(DataType.LONG, 1, size);
            this.predictionCountsEachClass = Nd4j.create(DataType.LONG, 1, size);
            this.residualPlotOverall = Nd4j.create(dataType, 1, this.histogramNumBins);
            this.residualPlotByLabelClass = Nd4j.create(dataType, this.histogramNumBins, size);
            this.probHistogramOverall = Nd4j.create(dataType, 1, this.histogramNumBins);
            this.probHistogramByLabelClass = Nd4j.create(dataType, this.histogramNumBins, size);
        }

        double histogramBinSize = 1.0d / this.histogramNumBins;
        double reliabilityBinSize = 1.0d / this.reliabilityDiagNumBins;
        DataType predictionType = predictions.dataType();

        // Reliability diagram: bin b covers [b * size, (b + 1) * size); the last bin also includes 1.0.
        for (int bin = 0; bin < this.reliabilityDiagNumBins; bin++) {
            INDArray atLeastLower = predictions.gte(Double.valueOf(bin * reliabilityBinSize)).castTo(predictionType);
            INDArray belowUpper;
            if (bin == this.reliabilityDiagNumBins - 1) {
                belowUpper = predictions.lte(Double.valueOf(1.0d)).castTo(predictionType);
            } else {
                belowUpper = predictions.lt(Double.valueOf((bin + 1) * reliabilityBinSize)).castTo(predictionType);
            }
            // 1 where the prediction falls in this bin, 0 elsewhere.
            INDArray inBin = atLeastLower.muli(belowUpper);

            INDArray labelsInBin = labels.mul(inBin);
            INDArray predictionsInBin = predictions.mul(inBin);
            INDArray countInBin = inBin.sum(0);
            this.rDiagBinSumPredictions.getRow(bin).addi(predictionsInBin.sum(0).castTo(this.rDiagBinSumPredictions.dataType()));
            this.rDiagBinPosCount.getRow(bin).addi(labelsInBin.sum(0).castTo(this.rDiagBinPosCount.dataType()));
            this.rDiagBinTotalCount.getRow(bin).addi(countInBin.castTo(this.rDiagBinTotalCount.dataType()));
        }

        // Per-class counts of labels and of arg-max predictions.
        this.labelCountsEachClass.addi(labels.sum(0).castTo(this.labelCountsEachClass.dataType()));
        INDArray isMax = Nd4j.getExecutioner().exec(new IsMax(predictions.dup(), 1));
        this.predictionCountsEachClass.addi(isMax.sum(0).castTo(this.predictionCountsEachClass.dataType()));

        // Histograms of |label - prediction| and of the raw predicted values.
        INDArray absResidual = labels.sub(predictions);
        INDArray probabilities = predictions.dup();
        // Result is discarded, so this relies on the 'false' (no-dup) flag making abs operate in place.
        Transforms.abs(absResidual, false);

        for (int bin = 0; bin < this.histogramNumBins; bin++) {
            INDArray residualAtLeastLower = absResidual.gte(Double.valueOf(bin * histogramBinSize)).castTo(predictionType);
            INDArray probAtLeastLower = probabilities.gte(Double.valueOf(bin * histogramBinSize)).castTo(predictionType);
            INDArray residualBelowUpper;
            INDArray probBelowUpper;
            if (bin == this.histogramNumBins - 1) {
                // Last bin is closed on the right so that a value of exactly 1.0 is counted.
                residualBelowUpper = absResidual.lte(Double.valueOf(1.0d)).castTo(predictionType);
                probBelowUpper = probabilities.lte(Double.valueOf(1.0d)).castTo(predictionType);
            } else {
                residualBelowUpper = absResidual.lt(Double.valueOf((bin + 1) * histogramBinSize)).castTo(predictionType);
                probBelowUpper = probabilities.lt(Double.valueOf((bin + 1) * histogramBinSize)).castTo(predictionType);
            }
            INDArray residualInBin = residualAtLeastLower.muli(residualBelowUpper);
            INDArray probInBin = probAtLeastLower.muli(probBelowUpper);

            this.residualPlotOverall.putScalar(0L, bin, this.residualPlotOverall.getInt(0, bin) + residualInBin.sumNumber().intValue());
            this.residualPlotByLabelClass.getRow(bin).addi(labels.mul(residualInBin).sum(0).castTo(this.residualPlotByLabelClass.dataType()));
            this.probHistogramOverall.putScalar(0L, bin, this.probHistogramOverall.getInt(0, bin) + probInBin.sumNumber().intValue());
            this.probHistogramByLabelClass.getRow(bin).addi(labels.mul(probInBin).sum(0).castTo(this.probHistogramByLabelClass.dataType()));
        }
    }

    /** Evaluates a batch without a mask; delegates to the three-argument overload. */
    @Override // org.nd4j.evaluation.BaseEvaluation, org.nd4j.evaluation.IEvaluation
    public void eval(INDArray iNDArray, INDArray iNDArray2) {
        eval(iNDArray, iNDArray2, (INDArray) null);
    }

    /**
     * Evaluation with record metadata is not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.nd4j.evaluation.IEvaluation
    public void eval(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, List<? extends Serializable> list) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Adds the statistics accumulated by another instance into this one. The other instance is never
     * modified: when this instance holds no data yet, the other's arrays are copied rather than shared.
     *
     * @param evaluationCalibration instance whose accumulated statistics are added to this one
     * @throws UnsupportedOperationException if the reliability-diagram bin counts differ, or if the
     *                                       other instance holds data and the histogram bin counts differ
     */
    @Override // org.nd4j.evaluation.IEvaluation
    public void merge(EvaluationCalibration evaluationCalibration) {
        if (this.reliabilityDiagNumBins != evaluationCalibration.reliabilityDiagNumBins) {
            throw new UnsupportedOperationException("Cannot merge EvaluationCalibration instances with different numbers of bins");
        }
        if (evaluationCalibration.rDiagBinPosCount == null) {
            // The other instance has no data: nothing to merge.
            return;
        }
        if (this.histogramNumBins != evaluationCalibration.histogramNumBins) {
            throw new UnsupportedOperationException("Cannot merge EvaluationCalibration instances with different numbers of histogram bins");
        }

        this.rDiagBinPosCount = accumulate(this.rDiagBinPosCount, evaluationCalibration.rDiagBinPosCount);
        this.rDiagBinTotalCount = accumulate(this.rDiagBinTotalCount, evaluationCalibration.rDiagBinTotalCount);
        this.rDiagBinSumPredictions = accumulate(this.rDiagBinSumPredictions, evaluationCalibration.rDiagBinSumPredictions);
        this.labelCountsEachClass = accumulate(this.labelCountsEachClass, evaluationCalibration.labelCountsEachClass);
        this.predictionCountsEachClass = accumulate(this.predictionCountsEachClass, evaluationCalibration.predictionCountsEachClass);
        this.residualPlotOverall = accumulate(this.residualPlotOverall, evaluationCalibration.residualPlotOverall);
        this.residualPlotByLabelClass = accumulate(this.residualPlotByLabelClass, evaluationCalibration.residualPlotByLabelClass);
        this.probHistogramOverall = accumulate(this.probHistogramOverall, evaluationCalibration.probHistogramOverall);
        this.probHistogramByLabelClass = accumulate(this.probHistogramByLabelClass, evaluationCalibration.probHistogramByLabelClass);
    }

    /**
     * Returns the accumulator that results from adding {@code addend} to {@code target}: the target
     * itself (updated in place) when both are present, a copy of the addend when there is no target,
     * or the unchanged target when there is nothing to add.
     */
    private static INDArray accumulate(INDArray target, INDArray addend) {
        if (addend == null) {
            return target;
        }
        if (target == null) {
            // Copy so that later in-place updates do not leak into the source instance.
            return addend.dup();
        }
        target.addi(addend);
        return target;
    }

    /** Discards all accumulated statistics; accumulators are re-allocated on the next {@code eval}. */
    @Override // org.nd4j.evaluation.IEvaluation
    public void reset() {
        this.rDiagBinPosCount = null;
        this.rDiagBinTotalCount = null;
        this.rDiagBinSumPredictions = null;
        this.labelCountsEachClass = null;
        this.predictionCountsEachClass = null;
        this.residualPlotOverall = null;
        this.residualPlotByLabelClass = null;
        this.probHistogramOverall = null;
        this.probHistogramByLabelClass = null;
    }

    /** @return a short description containing the number of reliability-diagram bins */
    @Override // org.nd4j.evaluation.IEvaluation
    public String stats() {
        return "EvaluationCalibration(nBins=" + this.reliabilityDiagNumBins + ")";
    }

    /** @return the number of classes seen so far, or -1 if no data has been accumulated */
    public int numClasses() {
        if (this.rDiagBinTotalCount == null) {
            return -1;
        }
        return (int) this.rDiagBinTotalCount.size(1);
    }

    /**
     * Builds the reliability diagram for one class: for each bin, the mean predicted value and the
     * fraction of positives among the predictions that fell into the bin. When empty bins are excluded,
     * bins with a total count of zero are omitted from both series.
     *
     * @param i class index (column of the per-class accumulators)
     * @return the reliability diagram for the class
     * @throws IllegalStateException if no data has been accumulated
     */
    public ReliabilityDiagram getReliabilityDiagram(int i) {
        Preconditions.checkState(this.rDiagBinTotalCount != null, "Unable to get reliability diagram: no evaluation has been performed (no data)");
        INDArray totalCounts = this.rDiagBinTotalCount.getColumn(i);
        INDArray positiveCounts = this.rDiagBinPosCount.getColumn(i);
        double[] meanPredictions = this.rDiagBinSumPredictions.getColumn(i).castTo(DataType.DOUBLE).div(totalCounts.castTo(DataType.DOUBLE)).data().asDouble();
        double[] fractionPositives = positiveCounts.castTo(DataType.DOUBLE).div(totalCounts.castTo(DataType.DOUBLE)).data().asDouble();

        if (this.excludeEmptyBins) {
            int emptyBins = Nd4j.getExecutioner().exec((ReduceOp) new MatchCondition(totalCounts, Conditions.equals((Number) 0), new int[0])).getInt(0);
            if (emptyBins != 0) {
                // Keep the unfiltered values: they are the source for the compacted arrays below.
                double[] allMeanPredictions = meanPredictions;
                double[] allFractionPositives = fractionPositives;
                meanPredictions = new double[(int) (totalCounts.length() - emptyBins)];
                fractionPositives = new double[meanPredictions.length];
                int next = 0;
                // Iterate over every original bin, not just the shortened output length.
                for (int bin = 0; bin < allMeanPredictions.length; bin++) {
                    if (totalCounts.getDouble(bin) != 0.0d) {
                        meanPredictions[next] = allMeanPredictions[bin];
                        fractionPositives[next] = allFractionPositives[bin];
                        next++;
                    }
                }
            }
        }
        return new ReliabilityDiagram("Reliability Diagram: Class " + i, meanPredictions, fractionPositives);
    }

    /** @return per-class label counts, or {@code null} if no data has been accumulated */
    public int[] getLabelCountsEachClass() {
        if (this.labelCountsEachClass == null) {
            return null;
        }
        return this.labelCountsEachClass.data().asInt();
    }

    /** @return per-class arg-max prediction counts, or {@code null} if no data has been accumulated */
    public int[] getPredictionCountsEachClass() {
        if (this.predictionCountsEachClass == null) {
            return null;
        }
        return this.predictionCountsEachClass.data().asInt();
    }

    /**
     * @return histogram over [0, 1] of |label - prediction| for all entries
     * @throws IllegalStateException if no data has been accumulated
     */
    public Histogram getResidualPlotAllClasses() {
        Preconditions.checkState(this.residualPlotOverall != null, "Unable to get residual plot: no evaluation has been performed (no data)");
        return new Histogram("Residual Plot - All Predictions and Classes", 0.0d, 1.0d, this.residualPlotOverall.data().asInt());
    }

    /**
     * @param i label class index
     * @return histogram over [0, 1] of |label - prediction|, weighted by the labels of class {@code i}
     * @throws IllegalStateException if no data has been accumulated
     */
    public Histogram getResidualPlot(int i) {
        Preconditions.checkState(this.residualPlotByLabelClass != null, "Unable to get residual plot: no evaluation has been performed (no data)");
        return new Histogram("Residual Plot - Predictions for Label Class " + i, 0.0d, 1.0d, this.residualPlotByLabelClass.getColumn(i).dup().data().asInt());
    }

    /**
     * @return histogram over [0, 1] of the predicted values for all entries
     * @throws IllegalStateException if no data has been accumulated
     */
    public Histogram getProbabilityHistogramAllClasses() {
        Preconditions.checkState(this.probHistogramOverall != null, "Unable to get probability histogram: no evaluation has been performed (no data)");
        return new Histogram("Network Probabilities Histogram - All Predictions and Classes", 0.0d, 1.0d, this.probHistogramOverall.data().asInt());
    }

    /**
     * @param i label class index
     * @return histogram over [0, 1] of the predicted values, weighted by the labels of class {@code i}
     * @throws IllegalStateException if no data has been accumulated
     */
    public Histogram getProbabilityHistogram(int i) {
        Preconditions.checkState(this.probHistogramByLabelClass != null, "Unable to get probability histogram: no evaluation has been performed (no data)");
        return new Histogram("Network Probabilities Histogram - P(class " + i + ") - Data Labelled Class " + i + " Only", 0.0d, 1.0d, this.probHistogramByLabelClass.getColumn(i).dup().data().asInt());
    }

    /** Deserializes an instance from JSON via the inherited two-argument {@code fromJson}. */
    public static EvaluationCalibration fromJson(String str) {
        return (EvaluationCalibration) fromJson(str, EvaluationCalibration.class);
    }

    public int getReliabilityDiagNumBins() {
        return this.reliabilityDiagNumBins;
    }

    public int getHistogramNumBins() {
        return this.histogramNumBins;
    }

    public boolean isExcludeEmptyBins() {
        return this.excludeEmptyBins;
    }

    /** @return the live accumulator (not a copy); {@code null} if no data has been accumulated */
    public INDArray getRDiagBinPosCount() {
        return this.rDiagBinPosCount;
    }

    /** @return the live accumulator (not a copy); {@code null} if no data has been accumulated */
    public INDArray getRDiagBinTotalCount() {
        return this.rDiagBinTotalCount;
    }

    /** @return the live accumulator (not a copy); {@code null} if no data has been accumulated */
    public INDArray getRDiagBinSumPredictions() {
        return this.rDiagBinSumPredictions;
    }

    /** @return the live accumulator (not a copy); {@code null} if no data has been accumulated */
    public INDArray getResidualPlotOverall() {
        return this.residualPlotOverall;
    }

    /** @return the live accumulator (not a copy); {@code null} if no data has been accumulated */
    public INDArray getResidualPlotByLabelClass() {
        return this.residualPlotByLabelClass;
    }

    /** @return the live accumulator (not a copy); {@code null} if no data has been accumulated */
    public INDArray getProbHistogramOverall() {
        return this.probHistogramOverall;
    }

    /** @return the live accumulator (not a copy); {@code null} if no data has been accumulated */
    public INDArray getProbHistogramByLabelClass() {
        return this.probHistogramByLabelClass;
    }

    /**
     * Two instances are equal when their configuration (bin counts, empty-bin flag) and all
     * accumulated statistics are equal. The {@code axis} setting is not part of the comparison.
     */
    @Override // org.nd4j.evaluation.BaseEvaluation
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof EvaluationCalibration)) {
            return false;
        }
        EvaluationCalibration other = (EvaluationCalibration) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        if (getReliabilityDiagNumBins() != other.getReliabilityDiagNumBins()
                || getHistogramNumBins() != other.getHistogramNumBins()
                || isExcludeEmptyBins() != other.isExcludeEmptyBins()) {
            return false;
        }
        return sameArray(getRDiagBinPosCount(), other.getRDiagBinPosCount())
                && sameArray(getRDiagBinTotalCount(), other.getRDiagBinTotalCount())
                && sameArray(getRDiagBinSumPredictions(), other.getRDiagBinSumPredictions())
                && Arrays.equals(getLabelCountsEachClass(), other.getLabelCountsEachClass())
                && Arrays.equals(getPredictionCountsEachClass(), other.getPredictionCountsEachClass())
                && sameArray(getResidualPlotOverall(), other.getResidualPlotOverall())
                && sameArray(getResidualPlotByLabelClass(), other.getResidualPlotByLabelClass())
                && sameArray(getProbHistogramOverall(), other.getProbHistogramOverall())
                && sameArray(getProbHistogramByLabelClass(), other.getProbHistogramByLabelClass());
    }

    /** Null-safe equality: two nulls are equal, otherwise defers to {@code first.equals(second)}. */
    private static boolean sameArray(INDArray first, INDArray second) {
        return first == null ? second == null : first.equals(second);
    }

    @Override // org.nd4j.evaluation.BaseEvaluation
    protected boolean canEqual(Object obj) {
        return obj instanceof EvaluationCalibration;
    }

    /** Hash over the same configuration values and accumulators that {@link #equals(Object)} compares. */
    @Override // org.nd4j.evaluation.BaseEvaluation
    public int hashCode() {
        int result = 1;
        result = (result * 59) + getReliabilityDiagNumBins();
        result = (result * 59) + getHistogramNumBins();
        result = (result * 59) + (isExcludeEmptyBins() ? 79 : 97);
        result = (result * 59) + arrayHash(getRDiagBinPosCount());
        result = (result * 59) + arrayHash(getRDiagBinTotalCount());
        result = (result * 59) + arrayHash(getRDiagBinSumPredictions());
        result = (result * 59) + Arrays.hashCode(getLabelCountsEachClass());
        result = (result * 59) + Arrays.hashCode(getPredictionCountsEachClass());
        result = (result * 59) + arrayHash(getResidualPlotOverall());
        result = (result * 59) + arrayHash(getResidualPlotByLabelClass());
        result = (result * 59) + arrayHash(getProbHistogramOverall());
        result = (result * 59) + arrayHash(getProbHistogramByLabelClass());
        return result;
    }

    /** Null-safe hash: 43 for a missing array, otherwise the array's own hash code. */
    private static int arrayHash(INDArray array) {
        return array == null ? 43 : array.hashCode();
    }
}
