package org.nd4j.linalg.api.ops.impl.shape;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/shape/Create.class */
public class Create extends DynamicCustomOp {
    private static final Logger log = LoggerFactory.getLogger(Create.class);
    protected boolean initialize;
    protected char order;
    protected DataType outputType;

    public Create() {
        this.initialize = false;
        this.order = 'c';
        this.outputType = DataType.FLOAT;
    }

    public Create(String str, SameDiff sameDiff, SDVariable sDVariable, boolean z) {
        this(str, sameDiff, sDVariable, 'c', z, sDVariable.dataType());
    }

    public Create(String str, SameDiff sameDiff, SDVariable sDVariable, char c, boolean z, DataType dataType) {
        super(str, sameDiff, new SDVariable[]{sDVariable}, false);
        this.initialize = false;
        this.order = 'c';
        this.outputType = DataType.FLOAT;
        this.outputType = dataType;
        this.initialize = z;
        this.order = c;
        addArgs();
    }

    public Create(INDArray iNDArray, DataType dataType) {
        this(iNDArray, 'c', false, dataType);
    }

    public Create(INDArray iNDArray, boolean z, DataType dataType) {
        this(iNDArray, 'c', z, dataType);
    }

    public Create(@NonNull INDArray iNDArray, char c, boolean z, DataType dataType) {
        super(new INDArray[]{iNDArray}, new INDArray[0]);
        this.initialize = false;
        this.order = 'c';
        this.outputType = DataType.FLOAT;
        if (iNDArray == null) {
            throw new NullPointerException("shape is marked non-null but is null");
        }
        this.order = c;
        this.initialize = z;
        this.outputType = dataType;
        addArgs();
    }

    protected void addArgs() {
        addBArgument(this.initialize);
        addIArgument(this.order, this.outputType.toInt());
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction, org.nd4j.linalg.api.ops.CustomOp
    public String opName() {
        return "create";
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String onnxName() {
        throw new NoOpNameFoundException("No op found for " + opName());
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String tensorflowName() {
        return "Empty";
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff sameDiff, Map<String, AttrValue> map, GraphDef graphDef) {
        if (map.containsKey(Nd4j.DTYPE)) {
            this.outputType = TFGraphMapper.convertType(map.get(Nd4j.DTYPE).getType());
        }
        if (map.containsKey("init")) {
            this.initialize = map.get("init").getB();
        }
        this.order = 'c';
        addArgs();
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public List<SDVariable> doDiff(List<SDVariable> list) {
        return Arrays.asList(this.sameDiff.zerosLike(outputVariables()[0]));
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public List<DataType> calculateOutputDataTypes(List<DataType> list) {
        Preconditions.checkState(list.size() == 1, "Expected list with exactly 1 datatype for %s, got %s", getClass(), list);
        return this.outputType != null ? Collections.singletonList(this.outputType) : list;
    }
}
