package org.nd4j.linalg.api.ops.impl.layers.convolution;

import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.NonNull;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter;
import org.nd4j.imports.descriptors.properties.adapters.NDArrayShapeAdapter;
import org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater;
import org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.class */
/**
 * Depthwise 2D convolution, executed by the native {@code depthwise_conv2d} custom op and
 * imported from TensorFlow's {@code DepthwiseConv2dNative} node.
 *
 * <p>The convolution parameters live in a {@link Conv2DConfig}; {@link #addArgs()} flattens
 * them into the integer-argument array consumed by the native op.
 */
public class DepthwiseConv2D extends DynamicCustomOp {
    private static final Logger log = LoggerFactory.getLogger(DepthwiseConv2D.class);

    /** Convolution configuration; flattened into integer arguments by {@link #addArgs()}. */
    protected Conv2DConfig config;

    /** Builder for the SameDiff-array constructor; obtain via {@link #sameDiffBuilder()}. */
    public static class DepthwiseConv2DBuilder {
        private SameDiff sameDiff;
        private SDVariable[] inputFunctions;
        private Conv2DConfig config;

        DepthwiseConv2DBuilder() {
        }

        public DepthwiseConv2DBuilder sameDiff(SameDiff sameDiff) {
            this.sameDiff = sameDiff;
            return this;
        }

        public DepthwiseConv2DBuilder inputFunctions(SDVariable[] inputFunctions) {
            this.inputFunctions = inputFunctions;
            return this;
        }

        public DepthwiseConv2DBuilder config(Conv2DConfig config) {
            this.config = config;
            return this;
        }

        /** Creates the op via {@link DepthwiseConv2D#DepthwiseConv2D(SameDiff, SDVariable[], Conv2DConfig)}. */
        public DepthwiseConv2D build() {
            return new DepthwiseConv2D(this.sameDiff, this.inputFunctions, this.config);
        }

        @Override
        public String toString() {
            return "DepthwiseConv2D.DepthwiseConv2DBuilder(sameDiff=" + this.sameDiff + ", inputFunctions=" + Arrays.deepToString(this.inputFunctions) + ", config=" + this.config + ")";
        }
    }

    /**
     * Creates the op inside a SameDiff graph.
     *
     * @param sameDiff     owning graph
     * @param input        input variable
     * @param weights      depthwise kernel variable
     * @param bias         optional bias; may be null, in which case only input and weights are
     *                     passed (see the two-argument check in {@link #doDiff(List)})
     * @param conv2DConfig convolution configuration
     * @throws NullPointerException if any non-null parameter is null
     */
    public DepthwiseConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) {
        // Validate inside the delegating call: Java forbids statements before this(...), and the
        // delegate runs super(sameDiff, ...) and addArgs() (which dereferences the config) before
        // any check placed after it could execute.
        this(Objects.requireNonNull(sameDiff, "sameDiff is marked non-null but is null"),
                (SDVariable[]) wrapFilterNull(
                        Objects.requireNonNull(input, "input is marked non-null but is null"),
                        Objects.requireNonNull(weights, "weights is marked non-null but is null"),
                        bias),
                Objects.requireNonNull(conv2DConfig, "conv2DConfig is marked non-null but is null"));
    }

    /**
     * Creates the op inside a SameDiff graph from an explicit input array.
     *
     * @param sameDiff       owning graph
     * @param inputFunctions op inputs
     * @param conv2DConfig   convolution configuration; used immediately by {@link #addArgs()}
     */
    public DepthwiseConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig conv2DConfig) {
        super(sameDiff, inputFunctions);
        this.config = conv2DConfig;
        addArgs();
    }

    /**
     * Creates the op for eager execution on concrete arrays.
     *
     * @param inputs       input arrays
     * @param outputs      output arrays
     * @param conv2DConfig convolution configuration; used immediately by {@link #addArgs()}
     */
    public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig conv2DConfig) {
        super(inputs, outputs);
        this.config = conv2DConfig;
        addArgs();
    }

    /**
     * Creates the op for eager execution on concrete arrays.
     *
     * @param input   input array
     * @param weights depthwise kernel array
     * @param bias    optional bias; may be null
     * @param output  optional output array; may be null
     * @param config  convolution configuration
     * @throws NullPointerException if input, weights or config is null
     */
    public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config) {
        // Same rationale as the SameDiff constructor: fail fast before delegating.
        this((INDArray[]) wrapFilterNull(
                        Objects.requireNonNull(input, "input is marked non-null but is null"),
                        Objects.requireNonNull(weights, "weights is marked non-null but is null"),
                        bias),
                wrapOrNull(output),
                Objects.requireNonNull(config, "config is marked non-null but is null"));
    }

    /** Convenience overload without an explicit output array. */
    public DepthwiseConv2D(INDArray input, INDArray weights, INDArray bias, Conv2DConfig config) {
        this(input, weights, bias, (INDArray) null, config);
    }

    /** No-arg constructor for reflective creation / graph import; {@link #config} stays null. */
    public DepthwiseConv2D() {
    }

    /** Lazily populates the integer arguments from {@link #config} if none were added yet. */
    @Override
    public long[] iArgs() {
        if (this.iArguments.isEmpty()) {
            addArgs();
        }
        return super.iArgs();
    }

    /**
     * Appends the 10 integer arguments of the native op, in order:
     * kH, kW, sH, sW, pH, pW, dH, dW, same-mode flag, data-format flag
     * (0 for "NCHW", compared case-insensitively; 1 for anything else).
     */
    protected void addArgs() {
        addIArgument(new long[] {
                this.config.getKH(),
                this.config.getKW(),
                this.config.getSH(),
                this.config.getSW(),
                this.config.getPH(),
                this.config.getPW(),
                this.config.getDH(),
                this.config.getDW(),
                ArrayUtil.fromBoolean(this.config.isSameMode()),
                this.config.getDataFormat().equalsIgnoreCase("NCHW") ? 0L : 1L
        });
    }

    /**
     * Reads a property value from {@link #config}, creating a default config first if none exists.
     *
     * @throws RuntimeException wrapping any exception raised by {@code Conv2DConfig.getValue}
     */
    @Override
    public Object getValue(Field field) {
        if (this.config == null) {
            this.config = Conv2DConfig.builder().build();
        }
        try {
            return this.config.getValue(field);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the config's properties. If the config is missing but integer arguments exist, it is
     * rebuilt from them first, using the layout written by {@link #addArgs()}.
     *
     * <p>NOTE(review): if both the config and the integer arguments are absent this still throws
     * NullPointerException, as before.
     */
    @Override
    public Map<String, Object> propertiesForFunction() {
        if (this.config == null && !this.iArguments.isEmpty()) {
            this.config = Conv2DConfig.builder()
                    .kH(this.iArguments.get(0).longValue())
                    .kW(this.iArguments.get(1).longValue())
                    .sH(this.iArguments.get(2).longValue())
                    .sW(this.iArguments.get(3).longValue())
                    .pH(this.iArguments.get(4).longValue())
                    .pW(this.iArguments.get(5).longValue())
                    .dH(this.iArguments.get(6).longValue())
                    .dW(this.iArguments.get(7).longValue())
                    .isSameMode(this.iArguments.get(8).longValue() == 1)
                    .dataFormat(this.iArguments.get(9).longValue() == 1 ? "NHWC" : "NCHW")
                    .build();
        }
        return this.config.toProperties();
    }

    /** Populates properties from the TensorFlow node attributes, then adds the integer args. */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff sameDiff, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        addArgs();
    }

    @Override
    public boolean isConfigProperties() {
        return true;
    }

    @Override
    public String configFieldName() {
        return "config";
    }

    /** ONNX import is intentionally a no-op for this op. */
    @Override
    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
    }

    /** Attribute adapters keyed by framework op name (TensorFlow, then ONNX). */
    @Override
    public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
        Map<String, Map<String, AttributeAdapter>> result = new HashMap<>();
        Map<String, Field> fieldsForFunction = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
        Field dataFormatField = fieldsForFunction.get("dataFormat");

        Map<String, AttributeAdapter> tfAdapters = new LinkedHashMap<>();
        tfAdapters.put("kH", new NDArrayShapeAdapter(0));
        tfAdapters.put("kW", new NDArrayShapeAdapter(1));
        // Stride/dilation index depends on the data format: (2, 3) for NCHW, (1, 2) otherwise.
        tfAdapters.put("sH", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 2, 1, dataFormatField));
        tfAdapters.put("sW", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 3, 2, dataFormatField));
        tfAdapters.put("dH", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 2, 1, dataFormatField));
        tfAdapters.put("dW", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 3, 2, dataFormatField));
        tfAdapters.put("isSameMode", new StringEqualsAdapter("SAME"));

        Map<String, AttributeAdapter> onnxAdapters = new HashMap<>();
        onnxAdapters.put("kH", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
        onnxAdapters.put("kW", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
        onnxAdapters.put("dH", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
        onnxAdapters.put("dW", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
        onnxAdapters.put("sH", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
        onnxAdapters.put("sW", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
        onnxAdapters.put("isSameMode", new StringEqualsAdapter("SAME"));

        try {
            result.put(tensorflowName(), tfAdapters);
        } catch (NoOpNameFoundException ignored) {
            // No TensorFlow name for this op: omit the TensorFlow entry.
        }
        try {
            result.put(onnxName(), onnxAdapters);
        } catch (NoOpNameFoundException ignored) {
            // No ONNX name for this op: omit the ONNX entry.
        }
        return result;
    }

    /** Property mappings, shared between the ONNX and TensorFlow entries. */
    @Override
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        Map<String, Map<String, PropertyMapping>> result = new HashMap<>();
        Map<String, PropertyMapping> mappings = new HashMap<>();

        PropertyMapping strides = PropertyMapping.builder().tfAttrName("strides").onnxAttrName("strides").propertyNames(new String[]{"sW", "sH"}).build();
        PropertyMapping kernelHeight = PropertyMapping.builder().propertyNames(new String[]{"kH"}).tfInputPosition(1).shapePosition(0).onnxAttrName("kernel_shape").build();
        PropertyMapping kernelWidth = PropertyMapping.builder().propertyNames(new String[]{"kW"}).tfInputPosition(1).shapePosition(1).onnxAttrName("kernel_shape").build();
        PropertyMapping dilations = PropertyMapping.builder().onnxAttrName("dilations").propertyNames(new String[]{"dW", "dH"}).tfAttrName("rates").build();
        PropertyMapping dataFormat = PropertyMapping.builder().onnxAttrName("data_format").tfAttrName("data_format").propertyNames(new String[]{"dataFormat"}).build();
        PropertyMapping sameMode = PropertyMapping.builder().onnxAttrName("auto_pad").propertyNames(new String[]{"isSameMode"}).tfAttrName("padding").build();
        PropertyMapping padding = PropertyMapping.builder().onnxAttrName("padding").propertyNames(new String[]{"pH", "pW"}).build();

        mappings.put("sW", strides);
        mappings.put("sH", strides);
        mappings.put("kH", kernelHeight);
        mappings.put("kW", kernelWidth);
        mappings.put("dW", dilations);
        mappings.put("dH", dilations);
        mappings.put("isSameMode", sameMode);
        mappings.put("pH", padding);
        mappings.put("pW", padding);
        mappings.put("dataFormat", dataFormat);

        try {
            result.put(onnxName(), mappings);
        } catch (NoOpNameFoundException ignored) {
            // No ONNX name for this op: omit the ONNX entry.
        }
        try {
            result.put(tensorflowName(), mappings);
        } catch (NoOpNameFoundException ignored) {
            // No TensorFlow name for this op: omit the TensorFlow entry.
        }
        return result;
    }

    @Override
    public String opName() {
        return "depthwise_conv2d";
    }

    /**
     * Builds the backprop op. When the op has exactly two args there is no bias, so null is
     * passed in its place.
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> gradients) {
        SDVariable bias = args().length == 2 ? null : arg(2);
        return Arrays.asList(new DepthwiseConv2DBp(this.sameDiff, arg(0), arg(1), bias, gradients.get(0), this.config).outputVariables());
    }

    @Override
    public String onnxName() {
        return "depth_conv";
    }

    @Override
    public String tensorflowName() {
        return "DepthwiseConv2dNative";
    }

    /**
     * Returns a single output type equal to the first input's type.
     *
     * @throws IllegalStateException (via Preconditions) if the number of types differs from the arg count
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        int n = args().length;
        Preconditions.checkState(dataTypes != null && dataTypes.size() == n, "Expected %s input data types for %s, got %s", Integer.valueOf(n), getClass(), dataTypes);
        return Collections.singletonList(dataTypes.get(0));
    }

    /** Returns a builder for the SameDiff-array constructor. */
    public static DepthwiseConv2DBuilder sameDiffBuilder() {
        return new DepthwiseConv2DBuilder();
    }

    /** Returns the current convolution configuration (may be null for a no-arg instance). */
    public Conv2DConfig getConfig() {
        return this.config;
    }
}
