package ai.djl.mxnet.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.index.full.NDIndexFullSlice;
import ai.djl.ndarray.types.Shape;
import java.util.Iterator;
import java.util.Stack;

/* loaded from: input_file:ai/djl/mxnet/engine/MxNDArrayIndexer.class */
public class MxNDArrayIndexer extends NDArrayIndexer {
    /** Manager used to adopt input arrays and to invoke MXNet operators. */
    private final MxNDManager manager;

    /**
     * Creates an indexer that issues all operator calls through the given manager.
     *
     * <p>NOTE: decompiler metadata indicates this constructor was package-private in the original
     * source; it is left public here so existing callers keep compiling.
     *
     * @param mxNDManager the manager used to adopt arrays and invoke MXNet operators
     */
    public MxNDArrayIndexer(MxNDManager mxNDManager) {
        this.manager = mxNDManager;
    }

    /**
     * Gathers elements along one axis with the MXNet {@code pick} operator.
     *
     * <p>The operator is called with {@code keepdims=true} and {@code mode=wrap}.
     *
     * @param nDArray the array to index; adopted into this indexer's manager first
     * @param nDIndexFullPick supplies the axis and the index array to pick with
     * @return the single output array produced by {@code pick}
     */
    public NDArray get(NDArray nDArray, NDIndexFullPick nDIndexFullPick) {
        NDArray array = this.manager.mo10from(nDArray);
        MxOpParams params = new MxOpParams();
        params.addParam("axis", nDIndexFullPick.getAxis());
        params.addParam("keepdims", true);
        params.add("mode", "wrap");
        NDList inputs = new NDList(new NDArray[] {array, nDIndexFullPick.getIndices()});
        return this.manager.invoke("pick", inputs, params).singletonOrThrow();
    }

    /**
     * Extracts a strided slice with {@code _npi_slice}, then squeezes any axes the index marks for
     * removal.
     *
     * @param nDArray the array to slice; adopted into this indexer's manager first
     * @param nDIndexFullSlice the begin/end/step bounds and the axes to squeeze afterwards
     * @return the sliced (and, if requested, squeezed) array
     */
    public NDArray get(NDArray nDArray, NDIndexFullSlice nDIndexFullSlice) {
        NDArray array = this.manager.mo10from(nDArray);
        NDArray sliced = this.manager.invoke("_npi_slice", array, sliceParams(nDIndexFullSlice));
        int[] toSqueeze = nDIndexFullSlice.getToSqueeze();
        if (toSqueeze.length == 0) {
            return sliced;
        }
        // squeeze() yields a new array: release the intermediate slice, never the returned
        // result. The finally block also frees the slice if squeeze() throws.
        try {
            return sliced.squeeze(toSqueeze);
        } finally {
            sliced.close();
        }
    }

    /**
     * Assigns {@code nDArray2} into the slice of {@code nDArray} described by the index.
     *
     * <p>The value goes through several steps before assignment:
     * <ul>
     *   <li>it is moved to the target array's device;
     *   <li>it is reshaped to a trailing sub-shape of the slice shape;
     *   <li>it is broadcast to the full slice shape.
     * </ul>
     * {@code _npi_slice_assign} is then invoked with the target array as both input and output.
     * Every intermediate array created along the way is closed afterwards, even on failure.
     * The caller's {@code nDArray2} itself is never closed.
     *
     * @param nDArray the array to write into; adopted into this indexer's manager first
     * @param nDIndexFullSlice the begin/end/step bounds and the shape of the target slice
     * @param nDArray2 the value to assign
     */
    public void set(NDArray nDArray, NDIndexFullSlice nDIndexFullSlice, NDArray nDArray2) {
        NDArray array = this.manager.mo10from(nDArray);
        MxOpParams params = sliceParams(nDIndexFullSlice);
        Shape sliceShape = nDIndexFullSlice.getShape();

        // Each preparation step may allocate a new array; remember them all for cleanup.
        Stack<NDArray> prepared = new Stack<>();
        prepared.push(nDArray2);
        try {
            prepared.push(prepared.peek().toDevice(array.getDevice(), false));

            // Drop leading dimensions of the slice shape while it holds more elements than the
            // value, e.g. a (10) value assigned into a (3, 10) slice is reshaped to (10) and
            // then broadcast to (3, 10).
            Shape targetShape = sliceShape;
            while (targetShape.size() > nDArray2.size()) {
                targetShape = targetShape.slice(1);
            }
            prepared.push(prepared.peek().reshape(targetShape));
            prepared.push(prepared.peek().broadcast(sliceShape));

            this.manager.invoke(
                    "_npi_slice_assign",
                    new NDArray[] {array, prepared.peek()},
                    new NDArray[] {array},
                    params);
        } finally {
            // toDevice(..., false) may hand back the caller's array itself; never close that.
            for (NDArray intermediate : prepared) {
                if (intermediate != nDArray2) {
                    intermediate.close();
                }
            }
        }
    }

    /**
     * Fills the slice of {@code nDArray} described by the index with a scalar.
     *
     * <p>Invokes {@code _npi_slice_assign_scalar} with the target array as both input and output.
     *
     * @param nDArray the array to write into; adopted into this indexer's manager first
     * @param nDIndexFullSlice the begin/end/step bounds of the target slice
     * @param number the scalar value to write
     */
    public void set(NDArray nDArray, NDIndexFullSlice nDIndexFullSlice, Number number) {
        NDArray array = this.manager.mo10from(nDArray);
        MxOpParams params = sliceParams(nDIndexFullSlice);
        params.addParam("scalar", number);
        this.manager.invoke(
                "_npi_slice_assign_scalar",
                new NDArray[] {array},
                new NDArray[] {array},
                params);
    }

    /** Builds the begin/end/step operator parameters shared by every slice-based operation. */
    private static MxOpParams sliceParams(NDIndexFullSlice fullSlice) {
        MxOpParams params = new MxOpParams();
        params.addTupleParam("begin", fullSlice.getMin());
        params.addTupleParam("end", fullSlice.getMax());
        params.addTupleParam("step", fullSlice.getStep());
        return params;
    }
}
