package org.nd4j.linalg.dataset.api.preprocessor;

import java.io.File;
import java.io.IOException;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DistributionStats;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/preprocessor/NormalizerStandardize.class */
/**
 * Standardizing normalizer backed by {@link DistributionStats} (mean / standard deviation).
 *
 * <p>Feature statistics are always fitted. Label statistics are fitted, applied and persisted only when
 * label fitting is enabled via {@link #fitLabel(boolean)} or the four-argument constructor. The actual
 * arithmetic is delegated to the {@code preProcess(INDArray, DistributionStats)} and
 * {@code revert(INDArray, DistributionStats)} helpers inherited from {@link AbstractNormalizerStandardize}.
 */
public class NormalizerStandardize extends AbstractNormalizerStandardize implements DataNormalization {
    private DistributionStats featureStats;
    private DistributionStats labelStats;
    private boolean fitLabels = false;

    /** Creates an unfitted normalizer; labels are not normalized unless {@link #fitLabel(boolean)} is enabled. */
    public NormalizerStandardize() {
    }

    /**
     * Creates a normalizer from precomputed feature statistics; label fitting stays disabled.
     *
     * @param featureMean feature mean, passed as the first argument of {@link DistributionStats}
     * @param featureStd  feature standard deviation, passed as the second argument of {@link DistributionStats}
     */
    public NormalizerStandardize(INDArray featureMean, INDArray featureStd) {
        this.featureStats = new DistributionStats(featureMean, featureStd);
    }

    /**
     * Creates a normalizer from precomputed feature and label statistics, with label fitting enabled.
     *
     * @param featureMean feature mean
     * @param featureStd  feature standard deviation
     * @param labelMean   label mean
     * @param labelStd    label standard deviation
     */
    public NormalizerStandardize(INDArray featureMean, INDArray featureStd, INDArray labelMean, INDArray labelStd) {
        this(featureMean, featureStd);
        this.labelStats = new DistributionStats(labelMean, labelStd);
        this.fitLabels = true;
    }

    /**
     * Enables or disables label normalization. Affects subsequent calls to fit, preProcess, revert,
     * revertLabels, load and save.
     *
     * @param fitLabels {@code true} to also fit/apply/persist label statistics
     */
    @Override
    public void fitLabel(boolean fitLabels) {
        this.fitLabels = fitLabels;
    }

    /** @return whether label normalization is enabled */
    @Override
    public boolean isFitLabel() {
        return this.fitLabels;
    }

    /**
     * Computes feature statistics (and label statistics when label fitting is enabled) from one data set.
     *
     * @param dataSet data to fit on; must not be null
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override
    public void fit(@NonNull DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet");
        }
        this.featureStats = new DistributionStats.Builder().addFeatures(dataSet).build();
        if (this.fitLabels) {
            this.labelStats = new DistributionStats.Builder().addLabels(dataSet).build();
        }
    }

    /**
     * Computes statistics over every batch of an iterator. The iterator is reset before the pass and
     * again afterwards, so the caller gets it back rewound.
     *
     * @param iterator data to fit on; must not be null
     * @throws NullPointerException if {@code iterator} is null
     */
    @Override
    public void fit(@NonNull DataSetIterator iterator) {
        if (iterator == null) {
            throw new NullPointerException("iterator");
        }
        DistributionStats.Builder featureBuilder = new DistributionStats.Builder();
        // Only allocate a label accumulator when labels are actually being fitted.
        DistributionStats.Builder labelBuilder = this.fitLabels ? new DistributionStats.Builder() : null;
        iterator.reset();
        while (iterator.hasNext()) {
            org.nd4j.linalg.dataset.DataSet batch = iterator.next();
            featureBuilder.addFeatures(batch);
            if (labelBuilder != null) {
                labelBuilder.addLabels(batch);
            }
        }
        this.featureStats = featureBuilder.build();
        if (labelBuilder != null) {
            this.labelStats = labelBuilder.build();
        }
        iterator.reset();
    }

    /**
     * Standardizes the features (and labels, when label fitting is enabled) of a data set. The helper's
     * result is not used, so the arrays are presumably modified in place.
     *
     * @param toPreProcess data set to standardize; must not be null
     * @throws NullPointerException  if {@code toPreProcess} is null
     * @throws IllegalStateException if label fitting is enabled but no label statistics are available
     */
    @Override
    public void preProcess(@NonNull DataSet toPreProcess) {
        if (toPreProcess == null) {
            throw new NullPointerException("toPreProcess");
        }
        assertIsFit();
        // Validate label stats before touching the features so a failure cannot leave a half-processed set.
        if (this.fitLabels) {
            checkLabelStatsAvailable();
        }
        preProcess(toPreProcess.getFeatures(), this.featureStats);
        if (this.fitLabels) {
            preProcess(toPreProcess.getLabels(), this.labelStats);
        }
    }

    /** Same as {@link #preProcess(DataSet)}. */
    @Override
    public void transform(DataSet dataSet) {
        preProcess(dataSet);
    }

    /**
     * Standardizes a feature array using the fitted feature statistics.
     *
     * @param features feature array to standardize
     */
    @Override
    public void transform(INDArray features) {
        transform(features, true);
    }

    /**
     * Standardizes a label array using the label statistics.
     *
     * @param labels label array to standardize
     * @throws IllegalStateException if no label statistics are available
     */
    @Override
    public void transformLabel(INDArray labels) {
        transform(labels, false);
    }

    /** Applies feature or label statistics to an array, after checking that the statistics exist. */
    private void transform(INDArray array, boolean isFeatures) {
        assertIsFit();
        if (isFeatures) {
            preProcess(array, this.featureStats);
        } else {
            checkLabelStatsAvailable();
            preProcess(array, this.labelStats);
        }
    }

    /**
     * Undoes standardization on the features (and labels, when label fitting is enabled) of a data set.
     *
     * @param dataSet data set to revert
     * @throws IllegalStateException if label fitting is enabled but no label statistics are available
     */
    @Override
    public void revert(DataSet dataSet) {
        assertIsFit();
        // Validate label stats up front so a failure cannot leave the data set half-reverted.
        if (this.fitLabels) {
            checkLabelStatsAvailable();
        }
        revert(dataSet.getFeatures(), this.featureStats);
        if (this.fitLabels) {
            revert(dataSet.getLabels(), this.labelStats);
        }
    }

    /**
     * Undoes standardization on a feature array.
     *
     * @param features feature array to revert
     */
    @Override
    public void revertFeatures(INDArray features) {
        assertIsFit();
        revert(features, this.featureStats);
    }

    /**
     * Undoes standardization on a label array; does nothing when label fitting is disabled.
     *
     * @param labels label array to revert
     * @throws IllegalStateException if label fitting is enabled but no label statistics are available
     */
    @Override
    public void revertLabels(INDArray labels) {
        if (this.fitLabels) {
            checkLabelStatsAvailable();
            revert(labels, this.labelStats);
        }
    }

    /** @return the fitted feature mean */
    public INDArray getMean() {
        assertIsFit();
        return this.featureStats.getMean();
    }

    /**
     * @return the label mean
     * @throws IllegalStateException if no label statistics are available
     */
    public INDArray getLabelMean() {
        assertIsFit();
        checkLabelStatsAvailable();
        return this.labelStats.getMean();
    }

    /** @return the fitted feature standard deviation */
    public INDArray getStd() {
        assertIsFit();
        return this.featureStats.getStd();
    }

    /**
     * @return the label standard deviation
     * @throws IllegalStateException if no label statistics are available
     */
    public INDArray getLabelStd() {
        assertIsFit();
        checkLabelStatsAvailable();
        return this.labelStats.getStd();
    }

    /** Fitted state is defined by the presence of feature statistics only. */
    @Override
    protected boolean isFit() {
        return this.featureStats != null;
    }

    /**
     * Loads statistics from files: {@code files[0..1]} for features and, when label fitting is enabled,
     * {@code files[2..3]} for labels. State is only replaced after every load has succeeded.
     *
     * @param files two files, or four when label fitting is enabled
     * @throws IllegalArgumentException if too few files are supplied
     * @throws IOException              if reading any file fails
     */
    @Override
    public void load(File... files) throws IOException {
        checkFileCount(files);
        DistributionStats loadedFeatures = DistributionStats.load(files[0], files[1]);
        DistributionStats loadedLabels = this.fitLabels ? DistributionStats.load(files[2], files[3]) : null;
        this.featureStats = loadedFeatures;
        if (this.fitLabels) {
            this.labelStats = loadedLabels;
        }
    }

    /**
     * Saves statistics to files using the same layout as {@link #load(File...)}.
     *
     * @param files two files, or four when label fitting is enabled
     * @throws IllegalArgumentException if too few files are supplied
     * @throws IllegalStateException    if label fitting is enabled but no label statistics are available
     * @throws IOException              if writing any file fails
     */
    @Override
    public void save(File... files) throws IOException {
        checkFileCount(files);
        assertIsFit();
        if (this.fitLabels) {
            checkLabelStatsAvailable();
        }
        this.featureStats.save(files[0], files[1]);
        if (this.fitLabels) {
            this.labelStats.save(files[2], files[3]);
        }
    }

    /** Throws if label statistics have never been fitted, loaded or supplied. */
    private void checkLabelStatsAvailable() {
        if (this.labelStats == null) {
            throw new IllegalStateException("Label statistics are not available: enable fitLabel(true) before"
                    + " fitting/loading, or construct with label mean and std");
        }
    }

    /** Throws if {@code files} has fewer entries than the current feature/label layout requires. */
    private void checkFileCount(File[] files) {
        int required = this.fitLabels ? 4 : 2;
        int actual = files == null ? 0 : files.length;
        if (actual < required) {
            throw new IllegalArgumentException("Expected at least " + required + " files (fitLabels="
                    + this.fitLabels + "), got " + actual);
        }
    }
}
