package org.nd4j.linalg.dataset.api.preprocessor;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/preprocessor/AbstractMultiDataSetNormalizer.class */
abstract class AbstractMultiDataSetNormalizer<S extends NormalizerStats> extends AbstractNormalizer<S> implements MultiDataNormalization, Serializable {
    private List<S> featureStats;
    private List<S> labelStats;
    private boolean fitLabels;

    protected AbstractMultiDataSetNormalizer() {
        this.fitLabels = false;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public AbstractMultiDataSetNormalizer(NormalizerStrategy<S> normalizerStrategy) {
        super(normalizerStrategy);
        this.fitLabels = false;
    }

    @Override // org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization
    public void fitLabel(boolean z) {
        this.fitLabels = z;
    }

    @Override // org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization
    public boolean isFitLabel() {
        return this.fitLabels;
    }

    @Override // org.nd4j.linalg.dataset.api.preprocessor.AbstractNormalizer
    protected boolean isFit() {
        return this.featureStats != null;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public S getFeatureStats(int i) {
        return getFeatureStats().get(i);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public List<S> getFeatureStats() {
        assertIsFit();
        return this.featureStats;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public S getLabelStats(int i) {
        return getLabelStats().get(i);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public List<S> getLabelStats() {
        assertIsFit();
        return this.labelStats;
    }

    @Override // org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization
    public void fit(@NonNull MultiDataSet multiDataSet) {
        if (multiDataSet == null) {
            throw new NullPointerException("dataSet");
        }
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        fitPartial(multiDataSet, arrayList, arrayList2);
        this.featureStats = buildList(arrayList);
        if (isFitLabel()) {
            this.labelStats = buildList(arrayList2);
        }
    }

    @Override // org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization
    public void fit(@NonNull MultiDataSetIterator multiDataSetIterator) {
        if (multiDataSetIterator == null) {
            throw new NullPointerException("iterator");
        }
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        multiDataSetIterator.reset();
        while (multiDataSetIterator.hasNext()) {
            fitPartial(multiDataSetIterator.next(), arrayList, arrayList2);
        }
        this.featureStats = buildList(arrayList);
        if (isFitLabel()) {
            this.labelStats = buildList(arrayList2);
        }
    }

    private List<S> buildList(@NonNull List<NormalizerStats.Builder> list) {
        if (list == null) {
            throw new NullPointerException("builders");
        }
        ArrayList arrayList = new ArrayList(list.size());
        Iterator<NormalizerStats.Builder> it = list.iterator();
        while (it.hasNext()) {
            arrayList.add(it.next().build());
        }
        return arrayList;
    }

    private void fitPartial(MultiDataSet multiDataSet, List<NormalizerStats.Builder> list, List<NormalizerStats.Builder> list2) {
        int numFeatureArrays = multiDataSet.numFeatureArrays();
        int numLabelsArrays = multiDataSet.numLabelsArrays();
        ensureStatsBuilders(list, numFeatureArrays);
        ensureStatsBuilders(list2, numLabelsArrays);
        for (int i = 0; i < numFeatureArrays; i++) {
            list.get(i).add(multiDataSet.getFeatures(i), multiDataSet.getFeaturesMaskArray(i));
        }
        if (isFitLabel()) {
            for (int i2 = 0; i2 < numLabelsArrays; i2++) {
                list2.get(i2).add(multiDataSet.getLabels(i2), multiDataSet.getLabelsMaskArray(i2));
            }
        }
    }

    private void ensureStatsBuilders(List<NormalizerStats.Builder> list, int i) {
        if (list.isEmpty()) {
            for (int i2 = 0; i2 < i; i2++) {
                list.add(newBuilder());
            }
        }
    }

    protected abstract NormalizerStats.Builder newBuilder();

    @Override // org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization
    public void transform(@NonNull MultiDataSet multiDataSet) {
        if (multiDataSet == null) {
            throw new NullPointerException("toPreProcess");
        }
        preProcess(multiDataSet);
    }

    @Override // org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization, org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor
    public void preProcess(@NonNull MultiDataSet multiDataSet) {
        if (multiDataSet == null) {
            throw new NullPointerException("toPreProcess");
        }
        int numFeatureArrays = multiDataSet.numFeatureArrays();
        int numLabelsArrays = multiDataSet.numLabelsArrays();
        for (int i = 0; i < numFeatureArrays; i++) {
            this.strategy.preProcess(multiDataSet.getFeatures(i), multiDataSet.getFeaturesMaskArray(i), getFeatureStats(i));
        }
        if (isFitLabel()) {
            for (int i2 = 0; i2 < numLabelsArrays; i2++) {
                this.strategy.preProcess(multiDataSet.getLabels(i2), multiDataSet.getLabelsMaskArray(i2), getLabelStats(i2));
            }
        }
    }

    @Override // org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization
    public void revert(@NonNull MultiDataSet multiDataSet) {
        if (multiDataSet == null) {
            throw new NullPointerException("data");
        }
        revertFeatures(multiDataSet.getFeatures(), multiDataSet.getFeaturesMaskArrays());
        revertLabels(multiDataSet.getLabels(), multiDataSet.getLabelsMaskArrays());
    }

    @Override // org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization
    public void revertFeatures(@NonNull INDArray[] iNDArrayArr) {
        if (iNDArrayArr == null) {
            throw new NullPointerException("features");
        }
        revertFeatures(iNDArrayArr, null);
    }

    @Override // org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization
    public void revertFeatures(@NonNull INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2) {
        if (iNDArrayArr == null) {
            throw new NullPointerException("features");
        }
        for (int i = 0; i < iNDArrayArr.length; i++) {
            revertFeatures(iNDArrayArr[i], iNDArrayArr2 == null ? null : iNDArrayArr2[i], i);
        }
    }

    public void revertFeatures(@NonNull INDArray iNDArray, INDArray iNDArray2, int i) {
        if (iNDArray == null) {
            throw new NullPointerException("features");
        }
        this.strategy.revert(iNDArray, iNDArray2, getFeatureStats(i));
    }

    @Override // org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization
    public void revertLabels(INDArray[] iNDArrayArr) {
        revertLabels(iNDArrayArr, null);
    }

    @Override // org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization
    public void revertLabels(@NonNull INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2) {
        if (iNDArrayArr == null) {
            throw new NullPointerException("labels");
        }
        for (int i = 0; i < iNDArrayArr.length; i++) {
            revertLabels(iNDArrayArr[i], iNDArrayArr2 == null ? null : iNDArrayArr2[i], i);
        }
    }

    public void revertLabels(@NonNull INDArray iNDArray, INDArray iNDArray2, int i) {
        if (iNDArray == null) {
            throw new NullPointerException("labels");
        }
        if (isFitLabel()) {
            this.strategy.revert(iNDArray, iNDArray2, getLabelStats(i));
        }
    }

    public int numInputs() {
        return getFeatureStats().size();
    }

    public int numOutputs() {
        return getLabelStats().size();
    }

    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof AbstractMultiDataSetNormalizer)) {
            return false;
        }
        AbstractMultiDataSetNormalizer abstractMultiDataSetNormalizer = (AbstractMultiDataSetNormalizer) obj;
        if (!abstractMultiDataSetNormalizer.canEqual(this)) {
            return false;
        }
        List<S> featureStats = getFeatureStats();
        List<S> featureStats2 = abstractMultiDataSetNormalizer.getFeatureStats();
        if (featureStats == null) {
            if (featureStats2 != null) {
                return false;
            }
        } else if (!featureStats.equals(featureStats2)) {
            return false;
        }
        List<S> labelStats = getLabelStats();
        List<S> labelStats2 = abstractMultiDataSetNormalizer.getLabelStats();
        if (labelStats == null) {
            if (labelStats2 != null) {
                return false;
            }
        } else if (!labelStats.equals(labelStats2)) {
            return false;
        }
        return this.fitLabels == abstractMultiDataSetNormalizer.fitLabels;
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof AbstractMultiDataSetNormalizer;
    }

    public int hashCode() {
        List<S> featureStats = getFeatureStats();
        int hashCode = (1 * 59) + (featureStats == null ? 43 : featureStats.hashCode());
        List<S> labelStats = getLabelStats();
        return (((hashCode * 59) + (labelStats == null ? 43 : labelStats.hashCode())) * 59) + (this.fitLabels ? 79 : 97);
    }

    public void setFeatureStats(List<S> list) {
        this.featureStats = list;
    }

    public void setLabelStats(List<S> list) {
        this.labelStats = list;
    }
}
