package org.nd4j.linalg.dataset.api.preprocessor;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DistributionStats;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/preprocessor/MultiNormalizerStandardize.class */
/**
 * Normalizer for {@link MultiDataSet}s that keeps one {@link DistributionStats} (mean / std) per feature
 * array and, optionally, one per label array.
 *
 * <p>The per-array transform is the one inherited from {@code AbstractNormalizerStandardize}
 * ({@code preProcess(INDArray, DistributionStats)} / {@code revert(INDArray, DistributionStats)}).
 * Label handling is opt-in via {@link #fitLabel(boolean)}. When it is disabled, label arrays are never
 * read, so datasets without labels can be fit and pre-processed.
 */
public class MultiNormalizerStandardize extends AbstractNormalizerStandardize implements MultiDataSetPreProcessor {
    /** One entry per feature array; {@code null} until {@code fit} or {@code load} has run. */
    private List<DistributionStats> featureStats;
    /** One entry per label array; only populated when label fitting is enabled. */
    private List<DistributionStats> labelStats;
    /** Whether label arrays are fit, normalized and reverted in addition to feature arrays. */
    private boolean fitLabels = false;

    /**
     * Enables or disables label normalization.
     *
     * @param z {@code true} to also fit, pre-process, revert, load and save label statistics
     */
    public void fitLabel(boolean z) {
        this.fitLabels = z;
    }

    /**
     * @return {@code true} if label arrays are handled in addition to feature arrays
     */
    public boolean isFitLabel() {
        return this.fitLabels;
    }

    /**
     * Computes statistics from a single dataset. This replaces the feature statistics, and also the
     * label statistics when label fitting is enabled.
     *
     * @param dataSet the data to compute statistics from
     */
    public void fit(@NonNull MultiDataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet");
        }
        List<DistributionStats.Builder> featureBuilders = new ArrayList<>();
        List<DistributionStats.Builder> labelBuilders = new ArrayList<>();
        fitPartial(dataSet, featureBuilders, labelBuilders);
        buildStats(featureBuilders, labelBuilders);
    }

    /**
     * Computes statistics over every dataset produced by the iterator. The iterator is reset first.
     *
     * @param iterator the source of datasets; it is reset before and consumed by this call
     */
    public void fit(@NonNull MultiDataSetIterator iterator) {
        if (iterator == null) {
            throw new NullPointerException("iterator");
        }
        List<DistributionStats.Builder> featureBuilders = new ArrayList<>();
        List<DistributionStats.Builder> labelBuilders = new ArrayList<>();
        // Rewind so the statistics cover the whole iterator even if it was partially consumed before.
        iterator.reset();
        while (iterator.hasNext()) {
            fitPartial(iterator.next(), featureBuilders, labelBuilders);
        }
        buildStats(featureBuilders, labelBuilders);
    }

    /** Converts the accumulated builders into the statistics held by this normalizer. */
    private void buildStats(List<DistributionStats.Builder> featureBuilders,
                            List<DistributionStats.Builder> labelBuilders) {
        this.featureStats = DistributionStats.Builder.buildList(featureBuilders);
        if (this.fitLabels) {
            this.labelStats = DistributionStats.Builder.buildList(labelBuilders);
        }
    }

    /**
     * Accumulates one dataset (including its mask arrays) into the running per-array builders.
     * Label arrays are only touched when label fitting is enabled.
     */
    private void fitPartial(MultiDataSet dataSet, List<DistributionStats.Builder> featureBuilders,
                            List<DistributionStats.Builder> labelBuilders) {
        int numFeatures = dataSet.getFeatures().length;
        ensureStatsBuilders(featureBuilders, numFeatures);
        for (int i = 0; i < numFeatures; i++) {
            featureBuilders.get(i).add(dataSet.getFeatures(i), dataSet.getFeaturesMaskArray(i));
        }
        if (this.fitLabels) {
            int numLabels = dataSet.getLabels().length;
            ensureStatsBuilders(labelBuilders, numLabels);
            for (int i = 0; i < numLabels; i++) {
                labelBuilders.get(i).add(dataSet.getLabels(i), dataSet.getLabelsMaskArray(i));
            }
        }
    }

    /**
     * Fills an empty builder list with {@code count} fresh builders. A non-empty list is left untouched,
     * so the first dataset seen determines the number of arrays.
     */
    private void ensureStatsBuilders(List<DistributionStats.Builder> builders, int count) {
        if (builders.isEmpty()) {
            for (int i = 0; i < count; i++) {
                builders.add(new DistributionStats.Builder());
            }
        }
    }

    /**
     * Applies the inherited per-array standardization to every feature array. When label fitting is
     * enabled and the dataset has labels, it is also applied to every label array.
     *
     * @param toPreProcess the dataset to normalize
     */
    @Override
    public void preProcess(@NonNull MultiDataSet toPreProcess) {
        if (toPreProcess == null) {
            throw new NullPointerException("toPreProcess");
        }
        assertIsFit();
        int numFeatures = toPreProcess.getFeatures().length;
        for (int i = 0; i < numFeatures; i++) {
            preProcess(toPreProcess.getFeatures(i), this.featureStats.get(i));
        }
        if (this.fitLabels) {
            INDArray[] labels = toPreProcess.getLabels();
            // A dataset without labels (e.g. at inference time) has no labels to normalize.
            if (labels != null) {
                List<DistributionStats> stats = requireLabelStats();
                for (int i = 0; i < labels.length; i++) {
                    preProcess(toPreProcess.getLabels(i), stats.get(i));
                }
            }
        }
    }

    /**
     * Undoes {@link #preProcess(MultiDataSet)} using the inherited per-array revert. Labels are
     * reverted only when label fitting is enabled and the dataset has labels.
     *
     * @param data the dataset to denormalize
     */
    public void revert(@NonNull MultiDataSet data) {
        if (data == null) {
            throw new NullPointerException("data");
        }
        assertIsFit();
        INDArray[] features = data.getFeatures();
        for (int i = 0; i < features.length; i++) {
            revert(features[i], this.featureStats.get(i));
        }
        if (this.fitLabels) {
            INDArray[] labels = data.getLabels();
            if (labels != null) {
                List<DistributionStats> stats = requireLabelStats();
                for (int i = 0; i < labels.length; i++) {
                    revert(labels[i], stats.get(i));
                }
            }
        }
    }

    /** @return the mean computed for feature array {@code i} */
    public INDArray getFeatureMean(int i) {
        assertIsFit();
        return this.featureStats.get(i).getMean();
    }

    /** @return the mean computed for label array {@code i} */
    public INDArray getLabelMean(int i) {
        assertIsFit();
        return requireLabelStats().get(i).getMean();
    }

    /** @return the standard deviation computed for feature array {@code i} */
    public INDArray getFeatureStd(int i) {
        assertIsFit();
        return this.featureStats.get(i).getStd();
    }

    /** @return the standard deviation computed for label array {@code i} */
    public INDArray getLabelStd(int i) {
        assertIsFit();
        return requireLabelStats().get(i).getStd();
    }

    /** The normalizer counts as fit once feature statistics exist (from {@code fit} or {@code load}). */
    @Override
    protected boolean isFit() {
        return this.featureStats != null;
    }

    /**
     * Returns the label statistics.
     *
     * @throws IllegalStateException if labels were never fit or loaded, e.g. because label fitting was
     *     enabled only after {@code fit} ran
     */
    private List<DistributionStats> requireLabelStats() {
        if (this.labelStats == null) {
            throw new IllegalStateException(
                    "Label statistics are not available: labels were not fit or loaded (see fitLabel(boolean))");
        }
        return this.labelStats;
    }

    /**
     * Loads statistics previously written by {@link #save(List, List)}. Each list holds two files per
     * array, consecutive pairs read via {@code DistributionStats.load}. The label files are only read
     * when label fitting is enabled. State is updated only after every file has loaded successfully.
     *
     * @param featureFiles two files per feature array
     * @param labelFiles two files per label array
     * @throws IllegalArgumentException if a list holds an odd number of files
     */
    public void load(@NonNull List<File> featureFiles, @NonNull List<File> labelFiles) throws IOException {
        if (featureFiles == null) {
            throw new NullPointerException("featureFiles");
        }
        if (labelFiles == null) {
            throw new NullPointerException("labelFiles");
        }
        List<DistributionStats> loadedFeatureStats = loadStats(featureFiles);
        if (this.fitLabels) {
            this.labelStats = loadStats(labelFiles);
        }
        this.featureStats = loadedFeatureStats;
    }

    /** Reads one {@code DistributionStats} per consecutive pair of files. */
    private List<DistributionStats> loadStats(List<File> files) throws IOException {
        if (files.size() % 2 != 0) {
            throw new IllegalArgumentException(String.format(
                    "Need an even number of files (two per input / output), got %d", files.size()));
        }
        List<DistributionStats> stats = new ArrayList<>(files.size() / 2);
        for (int i = 0; i < files.size(); i += 2) {
            stats.add(DistributionStats.load(files.get(i), files.get(i + 1)));
        }
        return stats;
    }

    /**
     * Writes the fitted statistics, two files per array, in the layout {@link #load(List, List)} reads.
     * The label statistics are only written when label fitting is enabled.
     *
     * @param featureFiles exactly two files per feature array
     * @param labelFiles exactly two files per label array
     */
    public void save(@NonNull List<File> featureFiles, @NonNull List<File> labelFiles) throws IOException {
        if (featureFiles == null) {
            throw new NullPointerException("featureFiles");
        }
        if (labelFiles == null) {
            throw new NullPointerException("labelFiles");
        }
        assertIsFit();
        saveStats(this.featureStats, featureFiles);
        if (this.fitLabels) {
            saveStats(requireLabelStats(), labelFiles);
        }
    }

    /** Writes each {@code DistributionStats} to its consecutive pair of files. */
    private void saveStats(List<DistributionStats> stats, List<File> files) throws IOException {
        int requiredFiles = stats.size() * 2;
        if (requiredFiles != files.size()) {
            throw new RuntimeException(String.format(
                    "Need twice as many files as inputs / outputs (%d), got %d", requiredFiles, files.size()));
        }
        for (int i = 0; i < stats.size(); i++) {
            stats.get(i).save(files.get(i * 2), files.get(i * 2 + 1));
        }
    }
}
