package org.nd4j.linalg.dataset;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/dataset/DistributionStats.class */
/**
 * Holds a mean array and a standard-deviation array describing a data distribution, as used for
 * standardizing features or labels.
 *
 * <p>Instances are either created directly from precomputed arrays, loaded from disk with
 * {@link #load(File, File)}, or accumulated batch by batch through {@link Builder}.
 */
public class DistributionStats {
    private static final Logger logger = LoggerFactory.getLogger(DistributionStats.class);
    private final INDArray mean;
    private final INDArray std;

    /**
     * Incrementally accumulates the mean and variance (reduced along dimension 0) of the data it is
     * fed, so statistics can be gathered over many batches without holding them all in memory.
     *
     * <p>Batches are merged with the pairwise update
     * {@code var = (nA*varA + nB*varB + delta^2 * nA*nB/(nA+nB)) / (nA+nB)}, where {@code delta} is
     * the difference between the batch mean and the running mean.
     */
    public static class Builder {
        private int runningCount = 0;
        private INDArray runningMean;
        private INDArray runningVariance;

        /**
         * Adds the features of the given data set (with its features mask) to the statistics.
         *
         * @param dataSet the data set whose features are accumulated; must not be null
         * @return this builder
         * @throws NullPointerException if {@code dataSet} is null
         */
        public Builder addFeatures(@NonNull org.nd4j.linalg.dataset.api.DataSet dataSet) {
            if (dataSet == null) {
                throw new NullPointerException("dataSet");
            }
            return add(dataSet.getFeatures(), dataSet.getFeaturesMaskArray());
        }

        /**
         * Adds the labels of the given data set (with its labels mask) to the statistics.
         *
         * @param dataSet the data set whose labels are accumulated; must not be null
         * @return this builder
         * @throws NullPointerException if {@code dataSet} is null
         */
        public Builder addLabels(@NonNull org.nd4j.linalg.dataset.api.DataSet dataSet) {
            if (dataSet == null) {
                throw new NullPointerException("dataSet");
            }
            return add(dataSet.getLabels(), dataSet.getLabelsMaskArray());
        }

        /**
         * Folds one batch of data into the running mean and variance.
         *
         * @param data the batch to accumulate; must not be null
         * @param mask mask passed to {@code DataSetUtil.tailor2d} together with the data; the
         *             other callers in this class pass the data set's mask array, which is
         *             presumably allowed to be null (confirm against {@code tailor2d})
         * @return this builder
         * @throws NullPointerException if {@code data} is null
         */
        public Builder add(@NonNull INDArray data, INDArray mask) {
            if (data == null) {
                throw new NullPointerException("data");
            }
            // NOTE(review): tailor2d presumably reshapes to 2-D and drops masked rows - confirm.
            INDArray batch = DataSetUtil.tailor2d(data, mask);
            INDArray batchMean = batch.mean(0);
            // NOTE(review): the 'false' flag presumably disables bias correction - confirm.
            INDArray batchVariance = batch.var(false, 0);
            int batchCount = batch.size(0);

            if (this.runningMean == null) {
                // First batch: its statistics are the running statistics.
                this.runningMean = batchMean;
                this.runningVariance = batchVariance;
                this.runningCount = batchCount;
            } else {
                int totalCount = this.runningCount + batchCount;
                // Weight of the squared mean difference. Computed in double: integer arithmetic
                // would truncate the quotient and could overflow the product.
                double crossWeight = ((double) this.runningCount * batchCount) / totalCount;
                INDArray deltaSquared = Transforms.pow(batchMean.subRowVector(this.runningMean), 2);

                this.runningVariance.muli(this.runningCount)
                        .addiRowVector(batchVariance.muli(batchCount))
                        .addiRowVector(deltaSquared.muli(crossWeight))
                        .divi(totalCount);

                this.runningCount = totalCount;

                // The mean must be updated after the variance, which needs the previous mean.
                INDArray deviationFromOldMean = batch.subRowVector(this.runningMean);
                this.runningMean.addi(deviationFromOldMean.sum(0).divi(this.runningCount));
            }
            return this;
        }

        /**
         * Creates a {@link DistributionStats} from a copy of the running mean and the square root
         * of the running variance.
         *
         * <p>At least one batch must have been added first; otherwise the running arrays are still
         * null and this call fails with a {@link NullPointerException}.
         *
         * @return the accumulated statistics
         */
        public DistributionStats build() {
            return new DistributionStats(this.runningMean.dup(), Transforms.sqrt(this.runningVariance, true));
        }

        /**
         * Builds every builder in the list, preserving order.
         *
         * @param builders the builders to build; must not be null
         * @return a new list with one {@link DistributionStats} per builder
         * @throws NullPointerException if {@code builders} is null
         */
        public static List<DistributionStats> buildList(@NonNull List<Builder> builders) {
            if (builders == null) {
                throw new NullPointerException("builders");
            }
            List<DistributionStats> stats = new ArrayList<>(builders.size());
            for (Builder builder : builders) {
                stats.add(builder.build());
            }
            return stats;
        }
    }

    /**
     * Creates statistics from the given arrays. Both arrays are stored by reference, not copied.
     *
     * <p>Standard deviations are clamped from below to {@code Nd4j.EPS_THRESHOLD}; a message is
     * logged when the smallest value ends up at that threshold.
     *
     * @param mean the mean array; must not be null
     * @param std  the standard-deviation array; must not be null. NOTE(review): the {@code false}
     *             passed to {@code Transforms.max} presumably means "no copy", so this array is
     *             modified in place - confirm.
     * @throws NullPointerException if either argument is null
     */
    public DistributionStats(@NonNull INDArray mean, @NonNull INDArray std) {
        if (mean == null) {
            throw new NullPointerException("mean");
        }
        if (std == null) {
            throw new NullPointerException("std");
        }
        Transforms.max(std, Nd4j.EPS_THRESHOLD, false);
        // Compare values, not INDArray references (reference equality is never true here).
        // Comparing as float with <= also matches when the array stores the threshold in single
        // precision.
        if (std.minNumber().floatValue() <= (float) Nd4j.EPS_THRESHOLD) {
            logger.info("API_INFO: Std deviation found to be zero. Transform will round up to epsilon to avoid nans.");
        }
        this.mean = mean;
        this.std = std;
    }

    /**
     * Loads statistics from two binary files, read with {@code Nd4j.readBinary}.
     *
     * @param meanFile file containing the mean array; must not be null
     * @param stdFile  file containing the standard-deviation array; must not be null
     * @return the loaded statistics
     * @throws IOException          if reading either file fails
     * @throws NullPointerException if either argument is null
     */
    public static DistributionStats load(@NonNull File meanFile, @NonNull File stdFile) throws IOException {
        if (meanFile == null) {
            throw new NullPointerException("meanFile");
        }
        if (stdFile == null) {
            throw new NullPointerException("stdFile");
        }
        return new DistributionStats(Nd4j.readBinary(meanFile), Nd4j.readBinary(stdFile));
    }

    /**
     * Saves the statistics to two binary files, written with {@code Nd4j.saveBinary}.
     *
     * @param meanFile destination for the mean array; must not be null
     * @param stdFile  destination for the standard-deviation array; must not be null
     * @throws IOException          if writing either file fails
     * @throws NullPointerException if either argument is null
     */
    public void save(@NonNull File meanFile, @NonNull File stdFile) throws IOException {
        if (meanFile == null) {
            throw new NullPointerException("meanFile");
        }
        if (stdFile == null) {
            throw new NullPointerException("stdFile");
        }
        Nd4j.saveBinary(getMean(), meanFile);
        Nd4j.saveBinary(getStd(), stdFile);
    }

    /**
     * @return the mean array held by this instance (the internal array, not a copy)
     */
    public INDArray getMean() {
        return this.mean;
    }

    /**
     * @return the standard-deviation array held by this instance (the internal array, not a copy)
     */
    public INDArray getStd() {
        return this.std;
    }
}
