package org.nd4j.linalg.dataset.api.preprocessor.classimbalance;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/preprocessor/classimbalance/BaseUnderSamplingPreProcessor.class */
/**
 * Shared logic for preprocessors that under-sample the majority class of a binary, time-series
 * classification dataset by randomly masking out label time steps.
 */
public abstract class BaseUnderSamplingPreProcessor {

    /**
     * Number of time steps per truncated-BPTT window. The time axis is processed in windows of this
     * length, starting from the last time step. Must be positive when
     * {@link #adjustMasks(INDArray, INDArray, int, double)} is called.
     */
    protected int tbpttWindowSize;

    /** If true (default), windows with no unmasked minority-class steps get majority keep probability 0. */
    private boolean maskAllMajorityWindows = true;

    /** If true, windows containing any unmasked minority-class step keep their original mask. */
    private boolean donotMaskMinorityWindows = false;

    /**
     * Stops fully masking windows that contain no unmasked minority-class steps. Instead, the mask
     * of such windows is scaled by {@code 1 - targetDist} before sampling.
     */
    public void donotMaskAllMajorityWindows() {
        this.maskAllMajorityWindows = false;
    }

    /** Leaves the mask of any window that contains at least one unmasked minority-class step unchanged. */
    public void donotMaskMinorityWindows() {
        this.donotMaskMinorityWindows = true;
    }

    /**
     * Builds an under-sampled label mask for a binary time-series classification problem.
     *
     * <p>The time axis is split into windows of {@link #tbpttWindowSize} steps, starting from the last
     * time step, so the earliest window may be shorter. A per-step keep probability is computed for
     * each window. The returned mask is a Bernoulli sample of those probabilities, drawn with
     * {@code Nd4j.getRandom()}.
     *
     * @param label         labels of rank 3, indexed as [example, class, timeStep]; the class dimension
     *                      must have size 1 or 2 (one-hot when 2)
     * @param labelMask     label mask indexed as [example, timeStep], or {@code null} to treat every
     *                      step as unmasked; this array is not modified
     * @param minorityLabel index of the minority class, 0 or 1
     * @param targetDist    target proportion of minority-class steps (presumably in (0, 1))
     * @return a new mask array with the same shape as the (possibly defaulted) label mask
     * @throws IllegalArgumentException if {@code minorityLabel} is not 0 or 1, or the labels fail validation
     * @throws IllegalStateException    if {@link #tbpttWindowSize} is not positive
     */
    public INDArray adjustMasks(INDArray label, INDArray labelMask, int minorityLabel, double targetDist) {
        if (tbpttWindowSize <= 0) {
            // A non-positive window never advances windowEnd below, which would loop forever.
            throw new IllegalStateException("tbpttWindowSize must be positive, but was " + tbpttWindowSize);
        }
        if (minorityLabel != 0 && minorityLabel != 1) {
            throw new IllegalArgumentException("minorityLabel must be 0 or 1, but was " + minorityLabel);
        }
        if (labelMask == null) {
            labelMask = Nd4j.ones(label.size(0), label.size(2));
        }
        validateData(label, labelMask);

        INDArray keepProbabilities = Nd4j.zeros(labelMask.shape());
        long windowEnd = label.size(2);
        // Walk backwards so every window except possibly the earliest one is full length.
        while (windowEnd > 0) {
            long windowStart = Math.max(windowEnd - tbpttWindowSize, 0L);
            INDArray windowProbabilities =
                    keepProbabilities.get(NDArrayIndex.all(), NDArrayIndex.interval(windowStart, windowEnd));
            INDArray windowMask = labelMask.get(NDArrayIndex.all(), NDArrayIndex.interval(windowStart, windowEnd));
            INDArray windowMinority = minorityIndicator(label, minorityLabel, windowStart, windowEnd);
            // windowProbabilities is a view, so assign() fills this slice of keepProbabilities.
            windowProbabilities.assign(calculateBernoulli(windowMinority, windowMask, targetDist));
            windowEnd = windowStart;
        }

        return Nd4j.getExecutioner().exec(
                new BernoulliDistribution(Nd4j.createUninitialized(keepProbabilities.shape()), keepProbabilities),
                Nd4j.getRandom());
    }

    /**
     * Returns an [example, timeStep] array covering steps {@code windowStart} (inclusive) to
     * {@code windowEnd} (exclusive) that is non-zero where the label is the minority class.
     * The input labels are not modified.
     */
    private INDArray minorityIndicator(INDArray label, int minorityLabel, long windowStart, long windowEnd) {
        if (label.size(1) == 2) {
            // One-hot labels: the minority class has its own column.
            return label.get(NDArrayIndex.all(), NDArrayIndex.point(minorityLabel),
                    NDArrayIndex.interval(windowStart, windowEnd));
        }
        // Single-column labels; assumes 1 marks class 1 and 0 marks class 0, so invert when class 0 is the minority.
        INDArray column = label.get(NDArrayIndex.all(), NDArrayIndex.point(0L),
                NDArrayIndex.interval(windowStart, windowEnd));
        return minorityLabel == 0 ? Transforms.not(column) : column;
    }

    /**
     * Computes per-step keep probabilities for one time window.
     *
     * <p>The result is one of:
     * <ul>
     *   <li>the mask unchanged, if the window has no unmasked majority steps, or if it has unmasked
     *       minority steps and {@code donotMaskMinorityWindows} is set;</li>
     *   <li>{@code mask * (1 - targetMinorityDist)}, if the window has no unmasked minority steps and
     *       {@code maskAllMajorityWindows} is false;</li>
     *   <li>otherwise, minority steps keep {@code label * mask}, and the majority steps of each example
     *       get {@code min(1, (minorityCount / majorityCount) * (1 - target) / target)}.</li>
     * </ul>
     *
     * @param minorityLabels     non-zero where a step belongs to the minority class, [example, timeStep]
     * @param labelMask          mask for the same window; not modified
     * @param targetMinorityDist target proportion of minority-class steps
     * @return an array of keep probabilities with the window's shape
     */
    private INDArray calculateBernoulli(INDArray minorityLabels, INDArray labelMask, double targetMinorityDist) {
        INDArray minorityClass = minorityLabels.dup().muli(labelMask);
        INDArray majorityClass = Transforms.not(minorityLabels).muli(labelMask);
        // Compare exact sums; intValue() would truncate fractional mask totals to 0.
        double minorityTotal = minorityClass.sumNumber().doubleValue();
        double majorityTotal = majorityClass.sumNumber().doubleValue();

        if (majorityTotal == 0.0 || (minorityTotal > 0.0 && donotMaskMinorityWindows)) {
            return labelMask;
        }
        if (minorityTotal == 0.0 && !maskAllMajorityWindows) {
            // mul (not muli): labelMask is a view of the caller's mask and must not be scaled in place.
            return labelMask.mul(1.0 - targetMinorityDist);
        }

        INDArray majorityKeepP = minorityClass.sum(1).div(majorityClass.sum(1))
                .muli(1.0 - targetMinorityDist)
                .divi(targetMinorityDist);
        // 0/0 arises for examples with neither unmasked minority nor unmasked majority steps in this window.
        // Their majority row is all zeros, so any finite value is correct; avoid feeding NaN downstream.
        BooleanIndexing.replaceWhere(majorityKeepP, 1.0, Conditions.isNan());
        // Target ratio already met (including x/0 = +Inf): keep every majority step.
        BooleanIndexing.replaceWhere(majorityKeepP, 1.0, Conditions.greaterThan(1.0));
        return majorityClass.muliColumnVector(majorityKeepP).addi(minorityClass);
    }

    /**
     * Checks that the labels form a binary time series consistent with the mask.
     *
     * @throws IllegalArgumentException if the labels are not rank 3, have more than two class columns,
     *                                  or have two columns that do not sum to the mask value on each step
     */
    private void validateData(INDArray label, INDArray labelMask) {
        if (label.rank() != 3) {
            throw new IllegalArgumentException("UnderSamplingByMaskingPreProcessor can only be applied to a time series dataset");
        }
        if (label.size(1) > 2) {
            throw new IllegalArgumentException("UnderSamplingByMaskingPreProcessor can only be applied to labels that represent binary classes. Label size was found to be " + label.size(1) + ".Expecting size=1 or size=2.");
        }
        // Masked steps are multiplied out, so only unmasked steps must be one-hot.
        if (label.size(1) == 2 && !label.sum(1).mul(labelMask).equals(labelMask)) {
            throw new IllegalArgumentException("Labels of size minibatchx2xtimesteps are expected to be one hot." + label.toString() + "\n is not one-hot");
        }
    }
}
