package org.nd4j.linalg.api.rng.distribution.impl;

import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.OutOfRangeException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.special.Beta;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BinomialDistributionEx;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.rng.distribution.BaseDistribution;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/api/rng/distribution/impl/BinomialDistribution.class */
/**
 * Binomial distribution: the number of successes in {@code numberOfTrials} independent Bernoulli
 * trials.
 *
 * <p>The success probability is given either as a single scalar (see
 * {@link #BinomialDistribution(int, double)}) or as a per-element array (see
 * {@link #BinomialDistribution(int, INDArray)}). The analytic methods ({@link #probability(int)},
 * the cumulative probabilities, mean, variance and support) only use the scalar probability. With
 * the array constructor that scalar keeps its default value {@code 0.0}. The array is only used by
 * {@link #sample(INDArray)}.
 */
public class BinomialDistribution extends BaseDistribution {
    /** Number of trials {@code n}; validated to be non-negative by every constructor. */
    private final int numberOfTrials;
    /** Scalar success probability {@code p}; stays {@code 0.0} when built from an array. */
    private double probabilityOfSuccess;
    /** Optional per-element success probabilities used when sampling; {@code null} if unset. */
    private INDArray p;

    /**
     * Creates a binomial distribution that samples with the global Nd4j random generator.
     *
     * @param i number of trials
     * @param d probability of success of each trial
     * @throws NotPositiveException if {@code i < 0}
     * @throws OutOfRangeException if {@code d} is not in {@code [0, 1]} (including NaN)
     */
    public BinomialDistribution(int i, double d) {
        this(Nd4j.getRandom(), i, d);
    }

    /**
     * Creates a binomial distribution that samples with the given random generator.
     *
     * @param random random generator used by {@link #sample(INDArray)}
     * @param i number of trials
     * @param d probability of success of each trial
     * @throws NotPositiveException if {@code i < 0}
     * @throws OutOfRangeException if {@code d} is not in {@code [0, 1]} (including NaN)
     */
    public BinomialDistribution(Random random, int i, double d) {
        super(random);
        if (i < 0) {
            throw new NotPositiveException(LocalizedFormats.NUMBER_OF_TRIALS, Integer.valueOf(i));
        }
        // Written as a negated range test so that NaN is rejected as well.
        if (!(d >= 0.0d && d <= 1.0d)) {
            throw new OutOfRangeException(Double.valueOf(d), 0, 1);
        }
        this.probabilityOfSuccess = d;
        this.numberOfTrials = i;
    }

    /**
     * Creates a binomial distribution with per-element success probabilities, sampled with the
     * global Nd4j random generator.
     *
     * <p>The scalar probability used by the analytic methods is left at {@code 0.0}.
     *
     * @param i number of trials
     * @param iNDArray per-element success probabilities; when sampling without native RNG state it is
     *     indexed with the same indices as the output array (assumed to share its shape)
     * @throws NotPositiveException if {@code i < 0}
     */
    public BinomialDistribution(int i, INDArray iNDArray) {
        this.random = Nd4j.getRandom();
        if (i < 0) {
            throw new NotPositiveException(LocalizedFormats.NUMBER_OF_TRIALS, Integer.valueOf(i));
        }
        this.numberOfTrials = i;
        this.p = iNDArray;
    }

    /** @return the number of trials {@code n} */
    public int getNumberOfTrials() {
        return this.numberOfTrials;
    }

    /** @return the scalar success probability {@code p} */
    public double getProbabilityOfSuccess() {
        return this.probabilityOfSuccess;
    }

    /**
     * Returns the probability mass {@code P(X = i)}.
     *
     * @param i number of successes
     * @return the probability of exactly {@code i} successes; 0 outside {@code [0, n]}
     */
    public double probability(int i) {
        if (this.numberOfTrials == 0) {
            // Degenerate case: all mass sits at 0. Apache Commons Math special-cases this too,
            // because the saddle-point expansion is not well defined for n == 0.
            return i == 0 ? 1.0d : 0.0d;
        }
        if (i < 0 || i > this.numberOfTrials) {
            return 0.0d;
        }
        return FastMath.exp(SaddlePointExpansion.logBinomialProbability(
                i, this.numberOfTrials, this.probabilityOfSuccess, 1.0d - this.probabilityOfSuccess));
    }

    /**
     * Returns the cumulative probability {@code P(X <= i)}.
     *
     * @param i number of successes
     * @return 0 for {@code i < 0}, 1 for {@code i >= n}, otherwise the binomial CDF at {@code i}
     */
    public double cumulativeProbability(int i) {
        if (i < 0) {
            return 0.0d;
        }
        if (i >= this.numberOfTrials) {
            return 1.0d;
        }
        // Binomial CDF identity: P(X <= i) = 1 - I_p(i + 1, n - i), where I is the regularized
        // incomplete beta function.
        return 1.0d - Beta.regularizedBeta(this.probabilityOfSuccess, i + 1.0d, this.numberOfTrials - i);
    }

    /**
     * Always returns 0. This is a discrete distribution, so its CDF is a step function with no
     * derivative at the integers. Elsewhere the derivative is 0, and at the jumps 0 is the lower
     * limit of the difference quotient. Use {@link #probability(int)} for point masses.
     */
    @Override // org.nd4j.linalg.api.rng.distribution.Distribution
    public double density(double d) {
        return 0.0d;
    }

    /**
     * Returns {@code P(X <= d)}. Because {@code X} is integer-valued, this equals the CDF at
     * {@code floor(d)}.
     *
     * @param d point at which to evaluate the CDF
     * @return the cumulative probability, or NaN if {@code d} is NaN
     */
    @Override // org.nd4j.linalg.api.rng.distribution.Distribution
    public double cumulativeProbability(double d) {
        if (Double.isNaN(d)) {
            return Double.NaN;
        }
        if (d < 0.0d) {
            return 0.0d;
        }
        if (d >= this.numberOfTrials) {
            return 1.0d;
        }
        // Here 0 <= d < n <= Integer.MAX_VALUE, so the cast cannot overflow.
        return cumulativeProbability((int) FastMath.floor(d));
    }

    /**
     * Returns {@code P(d < X <= d2)}.
     *
     * @param d lower bound (exclusive)
     * @param d2 upper bound (inclusive)
     * @return {@code F(d2) - F(d)}
     * @throws NumberIsTooLargeException if {@code d > d2}
     */
    @Override // org.nd4j.linalg.api.rng.distribution.Distribution
    public double cumulativeProbability(double d, double d2) throws NumberIsTooLargeException {
        if (d > d2) {
            throw new NumberIsTooLargeException(
                    LocalizedFormats.LOWER_ENDPOINT_ABOVE_UPPER_ENDPOINT, Double.valueOf(d), Double.valueOf(d2), true);
        }
        return cumulativeProbability(d2) - cumulativeProbability(d);
    }

    /** @return the mean {@code n * p} */
    @Override // org.nd4j.linalg.api.rng.distribution.Distribution
    public double getNumericalMean() {
        return this.numberOfTrials * this.probabilityOfSuccess;
    }

    /** @return the variance {@code n * p * (1 - p)} */
    @Override // org.nd4j.linalg.api.rng.distribution.Distribution
    public double getNumericalVariance() {
        double prob = this.probabilityOfSuccess;
        return this.numberOfTrials * prob * (1.0d - prob);
    }

    /** @return 0, or {@code n} when {@code p == 1} (all trials succeed) */
    @Override // org.nd4j.linalg.api.rng.distribution.Distribution
    public double getSupportLowerBound() {
        if (this.probabilityOfSuccess < 1.0d) {
            return 0.0d;
        }
        return this.numberOfTrials;
    }

    /** @return {@code n}, or 0 when {@code p == 0} (no trial succeeds) */
    @Override // org.nd4j.linalg.api.rng.distribution.Distribution
    public double getSupportUpperBound() {
        if (this.probabilityOfSuccess > 0.0d) {
            return this.numberOfTrials;
        }
        return 0.0d;
    }

    /** The lower support bound has positive probability mass, so it is inclusive. */
    @Override // org.nd4j.linalg.api.rng.distribution.Distribution
    public boolean isSupportLowerBoundInclusive() {
        return true;
    }

    /** The upper support bound has positive probability mass, so it is inclusive. */
    @Override // org.nd4j.linalg.api.rng.distribution.Distribution
    public boolean isSupportUpperBoundInclusive() {
        return true;
    }

    /** Every integer between the support bounds is attainable. */
    @Override // org.nd4j.linalg.api.rng.distribution.Distribution
    public boolean isSupportConnected() {
        return true;
    }

    /**
     * Copies the {@code i}-th element of the flattened probability array into the scalar
     * probability. Currently not called anywhere in this class.
     */
    private void ensureConsistent(int i) {
        this.probabilityOfSuccess = this.p.reshape(-1).getDouble(i);
    }

    /**
     * Allocates an uninitialized array of the given shape in the default Nd4j order and fills it
     * with samples.
     *
     * @param iArr shape of the output
     * @return the filled array
     */
    @Override // org.nd4j.linalg.api.rng.distribution.BaseDistribution, org.nd4j.linalg.api.rng.distribution.Distribution
    public INDArray sample(int[] iArr) {
        return sample(Nd4j.createUninitialized(iArr, Nd4j.order().charValue()));
    }

    /**
     * Fills {@code iNDArray} in place with binomial samples.
     *
     * <p>When the configured random generator has native state, the native
     * {@link BinomialDistributionEx} op is used. Otherwise each element is drawn on the JVM with
     * Apache Commons Math, driven by the same configured generator.
     *
     * @param iNDArray output array to fill
     * @return the filled array
     */
    @Override // org.nd4j.linalg.api.rng.distribution.BaseDistribution, org.nd4j.linalg.api.rng.distribution.Distribution
    public INDArray sample(INDArray iNDArray) {
        if (this.random.getStatePointer() != null) {
            BinomialDistributionEx op = this.p != null
                    ? new BinomialDistributionEx(iNDArray, this.numberOfTrials, this.p)
                    : new BinomialDistributionEx(iNDArray, this.numberOfTrials, this.probabilityOfSuccess);
            return Nd4j.getExecutioner().exec(op, this.random);
        }
        NdIndexIterator indices = new NdIndexIterator(iNDArray.shape());
        long length = iNDArray.length();
        if (this.p != null) {
            // Per-element probability: build a distribution for each position.
            for (long k = 0; k < length; k++) {
                long[] index = indices.next();
                org.apache.commons.math3.distribution.BinomialDistribution elementDist =
                        new org.apache.commons.math3.distribution.BinomialDistribution(
                                this.random, this.numberOfTrials, this.p.getDouble(index));
                iNDArray.putScalar(index, elementDist.sample());
            }
        } else {
            org.apache.commons.math3.distribution.BinomialDistribution sharedDist =
                    new org.apache.commons.math3.distribution.BinomialDistribution(
                            this.random, this.numberOfTrials, this.probabilityOfSuccess);
            for (long k = 0; k < length; k++) {
                iNDArray.putScalar(indices.next(), sharedDist.sample());
            }
        }
        return iNDArray;
    }
}
