package org.nd4j.linalg.dataset.api.preprocessor;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.dataset.DistributionStats;
import org.nd4j.linalg.factory.Nd4j;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:org/nd4j/linalg/dataset/api/preprocessor/AbstractNormalizerStandardize.class */
/**
 * Shared standardization logic for normalizers.
 *
 * <p>{@link #preProcess} subtracts the mean and divides by the standard deviation held in a
 * {@link DistributionStats}. {@link #revert} applies the inverse: it multiplies by the standard
 * deviation and adds the mean. Subclasses report through {@link #isFit()} whether they have
 * collected statistics.
 */
public abstract class AbstractNormalizerStandardize {

    /**
     * Verifies that this normalizer has been fit before it is used.
     *
     * <p>NOTE(review): the decompiler reports this method was originally package-private; it is
     * kept public here to preserve the current interface.
     *
     * @throws IllegalStateException if {@link #isFit()} returns {@code false}
     */
    public void assertIsFit() {
        if (!isFit()) {
            // IllegalStateException extends RuntimeException, so callers catching
            // RuntimeException keep working; the specific type names the misuse precisely.
            throw new IllegalStateException(
                    "API_USE_ERROR: Preprocessors have to be explicitly fit before use. Usage: .fit(dataset) or .fit(datasetiterator)");
        }
    }

    /**
     * Standardizes {@code iNDArray} in place.
     *
     * <p>The method subtracts {@code distributionStats.getMean()} and then divides by
     * {@code distributionStats.getStd()}:
     * <ul>
     *   <li>Rank-2 input: both statistics are applied as row vectors to every row.</li>
     *   <li>Any other rank: broadcast ops along dimension 1 are used, writing the result back
     *       into {@code iNDArray}. Dimension 1 is presumably the feature axis of time-series
     *       data; confirm against callers.</li>
     * </ul>
     *
     * <p>NOTE(review): the decompiler reports this method was originally protected; it is kept
     * public here to preserve the current interface.
     *
     * @param iNDArray the array to standardize; it is modified in place
     * @param distributionStats supplies the mean and standard deviation to apply
     */
    public void preProcess(INDArray iNDArray, DistributionStats distributionStats) {
        // No guard against zero entries in the std vector: dividing by them yields inf/NaN.
        if (iNDArray.rank() == 2) {
            iNDArray.subiRowVector(distributionStats.getMean());
            iNDArray.diviRowVector(distributionStats.getStd());
        } else {
            // The casts pin the execAndReturn(BroadcastOp) overload; input and output are the
            // same array, so the op runs in place.
            Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastSubOp(iNDArray, distributionStats.getMean(), iNDArray, 1));
            Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastDivOp(iNDArray, distributionStats.getStd(), iNDArray, 1));
        }
    }

    /**
     * Reports whether statistics have been collected for this normalizer.
     *
     * @return {@code true} once the normalizer has been fit
     */
    protected abstract boolean isFit();

    /**
     * Undoes {@link #preProcess} in place.
     *
     * <p>The method multiplies by {@code distributionStats.getStd()} and then adds
     * {@code distributionStats.getMean()}, in the reverse order of the forward transform.
     * Rank-2 input is handled row-wise; any other rank is handled with broadcast ops along
     * dimension 1.
     *
     * <p>NOTE(review): the decompiler reports this method was originally package-private; it is
     * kept public here to preserve the current interface.
     *
     * @param iNDArray the standardized array to restore; it is modified in place
     * @param distributionStats the mean and standard deviation used for standardization
     */
    public void revert(INDArray iNDArray, DistributionStats distributionStats) {
        if (iNDArray.rank() == 2) {
            iNDArray.muliRowVector(distributionStats.getStd());
            iNDArray.addiRowVector(distributionStats.getMean());
        } else {
            // Same in-place broadcast pattern as preProcess, with the inverse operations.
            Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastMulOp(iNDArray, distributionStats.getStd(), iNDArray, 1));
            Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastAddOp(iNDArray, distributionStats.getMean(), iNDArray, 1));
        }
    }
}
