package org.nd4j.linalg.dataset.api.preprocessor;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.dataset.api.preprocessor.stats.DistributionStats;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/preprocessor/StandardizeStrategy.class */
/**
 * Normalizer strategy that standardizes an array using the mean and standard deviation
 * held by a {@link DistributionStats}: {@code (x - mean) / std} on {@link #preProcess} and
 * {@code x * std + mean} on {@link #revert}.
 *
 * <p>Both operations write their result back into the array that is passed in. Standard
 * deviation entries equal to zero are treated as one so that the division (and the matching
 * multiplication on revert) leaves those positions finite.
 */
public class StandardizeStrategy implements NormalizerStrategy<DistributionStats> {
    /**
     * Standardizes {@code iNDArray} in place.
     *
     * @param iNDArray          the array to standardize; modified in place
     * @param iNDArray2         optional mask; when non-null it is passed together with the data
     *                          array to {@code DataSetUtil.setMaskedValuesToZero} after scaling
     * @param distributionStats supplies the mean and standard deviation to apply
     */
    @Override // org.nd4j.linalg.dataset.api.preprocessor.NormalizerStrategy
    public void preProcess(INDArray iNDArray, INDArray iNDArray2, DistributionStats distributionStats) {
        INDArray mean = distributionStats.getMean();
        INDArray std = filteredStd(distributionStats);
        if (iNDArray.rank() <= 2) {
            iNDArray.subiRowVector(mean);
            iNDArray.diviRowVector(std);
        } else {
            // Arrays of rank > 2: broadcast the statistics along dimension 1, writing into iNDArray.
            Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastSubOp(iNDArray, mean, iNDArray, 1));
            Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastDivOp(iNDArray, std, iNDArray, 1));
        }
        if (iNDArray2 != null) {
            DataSetUtil.setMaskedValuesToZero(iNDArray, iNDArray2);
        }
    }

    /**
     * Undoes {@link #preProcess} in place by multiplying by the standard deviation and then
     * adding the mean.
     *
     * @param iNDArray          the standardized array to restore; modified in place
     * @param iNDArray2         optional mask; when non-null it is passed together with the data
     *                          array to {@code DataSetUtil.setMaskedValuesToZero} after rescaling
     * @param distributionStats supplies the mean and standard deviation to apply
     */
    @Override // org.nd4j.linalg.dataset.api.preprocessor.NormalizerStrategy
    public void revert(INDArray iNDArray, INDArray iNDArray2, DistributionStats distributionStats) {
        INDArray mean = distributionStats.getMean();
        INDArray std = filteredStd(distributionStats);
        if (iNDArray.rank() <= 2) {
            iNDArray.muliRowVector(std);
            iNDArray.addiRowVector(mean);
        } else {
            // Arrays of rank > 2: broadcast the statistics along dimension 1, writing into iNDArray.
            Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastMulOp(iNDArray, std, iNDArray, 1));
            Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastAddOp(iNDArray, mean, iNDArray, 1));
        }
        if (iNDArray2 != null) {
            DataSetUtil.setMaskedValuesToZero(iNDArray, iNDArray2);
        }
    }

    /**
     * Returns the standard deviation with every zero entry replaced by one, so that dividing
     * by it never divides by zero.
     *
     * <p>The replacement is performed on a copy. NOTE(review): {@code getStd()} presumably
     * returns the stats object's own array; replacing values directly on it would silently
     * alter the caller's statistics, so the copy keeps {@code distributionStats} untouched.
     *
     * @param distributionStats source of the standard deviation
     * @return a new array holding the zero-filtered standard deviation
     */
    private static INDArray filteredStd(DistributionStats distributionStats) {
        INDArray stdCopy = distributionStats.getStd().dup();
        BooleanIndexing.replaceWhere(stdCopy, Double.valueOf(1.0d), Conditions.equals((Number) 0));
        return stdCopy;
    }
}
