package org.nd4j.linalg.api.ops.impl.transforms;

import java.util.HashSet;
import java.util.Set;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.Op;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/transforms/LinearIndex.class */
/**
 * Pass-through transform op that records, in processing order, the linear index of each
 * element of {@code x} it visits.
 *
 * <p>Recording only happens when {@code wholeArray} is {@code true}. In that case an
 * {@code int[]} sized to {@code x.length()} is allocated at construction. Each visit appends
 * {@code x.linearIndex(numProcessed)} to it and advances {@code numProcessed}. Seeing the same
 * linear index twice raises {@link IllegalStateException}. When {@code wholeArray} is
 * {@code false}, visits are no-ops and {@link #getIndices()} returns {@code null}.
 *
 * <p>Element values are never modified: every {@code op(...)} overload returns its first
 * argument unchanged.
 *
 * <p>NOTE(review): instances hold mutable, unsynchronized state (a counter, a seen-set and the
 * index buffer). Running {@link #exec()} a second time on the same instance would run past the
 * end of the buffer, so treat each instance as single-use.
 */
public class LinearIndex extends BaseTransformOp {
    /** Number of linear indices written into {@code indices} so far. */
    private int internalCount = 0;
    /** Recorded linear indices; allocated only when {@code wholeArray} is true, otherwise null. */
    private int[] indices;
    /** Whether visited linear indices should be recorded at all. */
    private boolean wholeArray = false;
    /** Linear indices recorded so far, used to reject a repeated index. */
    private final Set<Integer> encountered = new HashSet<Integer>();

    /**
     * Creates an op with no operands. {@code wholeArray} stays false, so nothing is recorded.
     * NOTE(review): presumably exists for reflective/no-arg instantiation; confirm with callers.
     */
    public LinearIndex() {
    }

    /**
     * Creates an op over {@code x} that records every visited linear index.
     *
     * @param x the array whose linear indices are recorded
     */
    public LinearIndex(INDArray x) {
        this(x, true);
    }

    /**
     * @param x          input array
     * @param z          output array, passed to the base op
     * @param wholeArray whether to record linear indices
     */
    public LinearIndex(INDArray x, INDArray z, boolean wholeArray) {
        super(x, z);
        this.wholeArray = wholeArray;
        initIndexesIfNecessary();
    }

    /**
     * @param x          input array
     * @param z          output array, passed to the base op
     * @param n          element count, passed to the base op
     * @param wholeArray whether to record linear indices
     */
    public LinearIndex(INDArray x, INDArray z, int n, boolean wholeArray) {
        super(x, z, n);
        this.wholeArray = wholeArray;
        initIndexesIfNecessary();
    }

    /**
     * @param x          input array
     * @param y          second operand, passed to the base op
     * @param z          output array, passed to the base op
     * @param n          element count, passed to the base op
     * @param wholeArray whether to record linear indices
     */
    public LinearIndex(INDArray x, INDArray y, INDArray z, int n, boolean wholeArray) {
        super(x, y, z, n);
        this.wholeArray = wholeArray;
        initIndexesIfNecessary();
    }

    /**
     * @param x          input array
     * @param wholeArray whether to record linear indices
     */
    public LinearIndex(INDArray x, boolean wholeArray) {
        super(x);
        this.wholeArray = wholeArray;
        initIndexesIfNecessary();
    }

    @Override // org.nd4j.linalg.api.ops.Op
    public String name() {
        return "linearindex";
    }

    // All op(...) overloads record the current position (if enabled) and return the first
    // argument unchanged: the op is an observer, not a transform.

    @Override // org.nd4j.linalg.api.ops.Op
    public IComplexNumber op(IComplexNumber origin, double other) {
        addToIndex();
        return origin;
    }

    @Override // org.nd4j.linalg.api.ops.Op
    public IComplexNumber op(IComplexNumber origin, float other) {
        addToIndex();
        return origin;
    }

    @Override // org.nd4j.linalg.api.ops.Op
    public IComplexNumber op(IComplexNumber origin, IComplexNumber other) {
        addToIndex();
        return origin;
    }

    @Override // org.nd4j.linalg.api.ops.Op
    public float op(float origin, float other) {
        addToIndex();
        return origin;
    }

    @Override // org.nd4j.linalg.api.ops.Op
    public double op(double origin, double other) {
        addToIndex();
        return origin;
    }

    @Override // org.nd4j.linalg.api.ops.Op
    public double op(double origin) {
        addToIndex();
        return origin;
    }

    @Override // org.nd4j.linalg.api.ops.Op
    public float op(float origin) {
        addToIndex();
        return origin;
    }

    @Override // org.nd4j.linalg.api.ops.Op
    public IComplexNumber op(IComplexNumber origin) {
        addToIndex();
        return origin;
    }

    /**
     * Records the linear index of the element at position {@code numProcessed} and advances
     * the counters. Does nothing when {@code wholeArray} is false.
     *
     * @throws IllegalStateException if the same linear index has already been recorded
     */
    private void addToIndex() {
        if (!this.wholeArray) {
            return;
        }
        int linearIndex = getLinearIndex();
        // Set.add returns false when the value is already present. A repeated linear index
        // presumably means the strides map two positions onto one element (the message asks
        // the caller to check striding).
        if (!this.encountered.add(Integer.valueOf(linearIndex))) {
            throw new IllegalStateException("Please checking striding. Index: " + linearIndex + " already encountered ");
        }
        this.indices[this.internalCount] = linearIndex;
        this.internalCount++;
        this.numProcessed++;
    }

    /** Linear index in {@code x} of the element at position {@code numProcessed}. */
    private int getLinearIndex() {
        return this.x.linearIndex(this.numProcessed);
    }

    /** Allocates the index buffer, sized to {@code x.length()}, when recording is enabled. */
    private void initIndexesIfNecessary() {
        if (this.wholeArray) {
            this.indices = new int[this.x.length()];
        }
    }

    /**
     * Returns the recorded linear indices in the order they were visited.
     *
     * <p>The returned array is the internal buffer, not a copy. Slots past the number of
     * visits made so far hold their default value (0).
     *
     * @return the recorded indices, or {@code null} when {@code wholeArray} is false
     */
    public int[] getIndices() {
        return this.indices;
    }

    /**
     * Visits every element of {@code x} once, recording its linear index when
     * {@code wholeArray} is true.
     */
    @Override // org.nd4j.linalg.api.ops.BaseOp, org.nd4j.linalg.api.ops.Op
    public void exec() {
        int length = this.x.length();
        for (int i = 0; i < length; i++) {
            addToIndex();
        }
    }

    /**
     * Always returns {@code true}.
     * NOTE(review): presumably tells the executioner to invoke {@link #exec()} directly
     * instead of applying the op element-wise; confirm against BaseOp/OpExecutioner.
     */
    @Override // org.nd4j.linalg.api.ops.BaseOp, org.nd4j.linalg.api.ops.Op
    public boolean isPassThrough() {
        return true;
    }

    /**
     * Builds a non-recording ({@code wholeArray == false}) LinearIndex over the vector at
     * {@code index} along {@code dimension}.
     *
     * @param index     which vector along the dimension
     * @param dimension the dimension to slice along
     * @return the per-vector op
     */
    @Override // org.nd4j.linalg.api.ops.Op
    public Op opForDimension(int index, int dimension) {
        INDArray xAlongDimension = this.x.vectorAlongDimension(index, dimension);
        // Both branches must use the sliced vector's own length, not x.length(), as the
        // element count.
        int n = xAlongDimension.length();
        if (y() != null) {
            return new LinearIndex(xAlongDimension, this.y.vectorAlongDimension(index, dimension),
                    this.z.vectorAlongDimension(index, dimension), n, false);
        }
        return new LinearIndex(xAlongDimension, this.z.vectorAlongDimension(index, dimension), n, false);
    }
}
