package org.nd4j.autodiff.samediff;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import com.google.common.primitives.Ints;
import com.google.flatbuffers.FlatBufferBuilder;
import com.rits.cloning.Cloner;
import com.rits.cloning.IFastCloner;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;
import lombok.NonNull;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.output.CloseShieldOutputStream;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.execution.conf.ExecutionMode;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
import org.nd4j.autodiff.samediff.internal.AbstractSession;
import org.nd4j.autodiff.samediff.internal.DataTypesSession;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.samediff.ops.SDBaseOps;
import org.nd4j.autodiff.samediff.ops.SDCNN;
import org.nd4j.autodiff.samediff.ops.SDLoss;
import org.nd4j.autodiff.samediff.ops.SDMath;
import org.nd4j.autodiff.samediff.ops.SDNN;
import org.nd4j.autodiff.samediff.ops.SDRNN;
import org.nd4j.autodiff.samediff.ops.SDRandom;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.autodiff.util.cloner.DataBufferFastCloner;
import org.nd4j.autodiff.util.cloner.INDArrayFastCloner;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.graph.FlatArray;
import org.nd4j.graph.FlatGraph;
import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatVariable;
import org.nd4j.graph.IntPair;
import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.controlflow.If;
import org.nd4j.linalg.api.ops.impl.controlflow.While;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.collection.IntArrayKeyMap;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.DeviceLocalNDArray;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ConstantInitScheme;
import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/autodiff/samediff/SameDiff.class */
public class SameDiff extends SDBaseOps {
    // Every variable of this graph, keyed by variable name.
    private final Map<String, Variable> variables;
    // Every op of this graph, keyed by the op's own name.
    private final Map<String, SameDiffOp> ops;
    // Inference sessions keyed by Thread#getId(); ARRAY-type variable values are read from these.
    private final Map<Long, InferenceSession> sessions;
    // Arrays of CONSTANT variables, keyed by variable name.
    private final Map<String, DeviceLocalNDArray> constantArrays;
    // Arrays of VARIABLE-type variables, keyed by variable name.
    private final Map<String, DeviceLocalNDArray> variablesArrays;
    // Placeholder values: Thread#getId() -> (variable name -> array), so each thread has its own feeds.
    private final Map<Long, Map<String, INDArray>> placeholdersPerThread;
    // NOTE(review): presumably the names of the loss variables; not populated in this part of the file.
    private final List<String> lossVariables;
    private TrainingConfig trainingConfig;
    private boolean initializedTraining;
    private INDArray updaterState;
    private Map<String, INDArray> updaterViews;
    private Map<String, GradientUpdater> updaterMap;
    private Map<String, String> baseNameForFunctionInstanceId;
    // Returned by f().
    private DifferentialFunctionFactory functionFactory;

    // Legacy shape registry: consulted by getShapeForVarName when no array exists for the variable.
    @Deprecated
    private Map<String, long[]> variableNameToShape;

    // Legacy map keyed by variable name; entries are re-keyed by updateVariableName.
    @Deprecated
    private Map<String, SDVariable> forwardVarForGrad;
    private int variableId;
    // Op namespaces, also exposed through the same-named accessor methods.
    public final SDMath math;
    public final SDRandom random;
    public final SDNN nn;
    public final SDCNN cnn;
    public final SDRNN rnn;
    public final SDLoss loss;
    private Map<String, List<String>> propertiesToResolve;
    private Map<String, Map<String, Object>> propertiesForFunction;

    @Deprecated
    private Map<String, long[]> placeHolderOriginalShapes;
    private Map<String, SameDiffFunctionDefinition> sameDiffFunctionDefinitionMap;
    // Named sub-graphs registered through putSubFunction.
    private Map<String, SameDiff> sameDiffFunctionInstances;
    private Set<String> placeHolderFunctions;
    private Table<String, String, String> fieldVariableResolutionMapping;
    private transient AtomicBoolean wasRegistered;
    // Toggled by enableDebugMode()/disableDebugging().
    private boolean debugMode;
    private Map<int[], Op> opsForResult;
    private boolean resolvedVariables;
    boolean logExecution;
    private SameDiff parent;
    private SameDiff child;
    public static final String TRAINING_CONFIG_JSON_ZIP_ENTRY_NAME = "trainingConfig.json";
    public static final String SAMEDIFF_FILE_ENTRY_NAME = "samediff.fb";
    private static final Logger log = LoggerFactory.getLogger(SameDiff.class);
    // Shared deep-cloner built by newCloner(); used by invokeGraphOn to copy variables and ops.
    private static Cloner cloner = newCloner();
    // Reflective dispatch table for invoke(Op, ...), keyed by Op#opName(). Starts empty here;
    // NOTE(review): confirm where (if anywhere) it is populated.
    private static Map<String, Method> opMethods = new HashMap();

    /* loaded from: input_file:org/nd4j/autodiff/samediff/SameDiff$SameDiffBuilder.class */
    /**
     * Fluent builder for {@link SameDiff}.
     *
     * <p>Each setter records its argument and returns this builder. {@link #build()} passes every
     * recorded value, in a fixed order, to the {@code SameDiff} constructor. Unset fields keep their
     * Java defaults ({@code null}, {@code false}, {@code 0}).
     */
    public static class SameDiffBuilder {
        private TrainingConfig trainingConfig;
        private boolean initializedTraining;
        private INDArray updaterState;
        private Map<String, INDArray> updaterViews;
        private Map<String, GradientUpdater> updaterMap;
        private Map<String, String> baseNameForFunctionInstanceId;
        private DifferentialFunctionFactory functionFactory;
        private Map<String, long[]> variableNameToShape;
        private Map<String, SDVariable> forwardVarForGrad;
        private int variableId;
        private Map<String, List<String>> propertiesToResolve;
        private Map<String, Map<String, Object>> propertiesForFunction;
        private Map<String, long[]> placeHolderOriginalShapes;
        private Map<String, SameDiffFunctionDefinition> sameDiffFunctionDefinitionMap;
        private Map<String, SameDiff> sameDiffFunctionInstances;
        private Set<String> placeHolderFunctions;
        private Table<String, String, String> fieldVariableResolutionMapping;
        private AtomicBoolean wasRegistered;
        private boolean debugMode;
        private Map<int[], Op> opsForResult;
        private boolean resolvedVariables;
        private boolean logExecution;
        private SameDiff parent;
        private SameDiff child;

        SameDiffBuilder() {
        }

        public SameDiffBuilder trainingConfig(TrainingConfig value) {
            trainingConfig = value;
            return this;
        }

        public SameDiffBuilder initializedTraining(boolean value) {
            initializedTraining = value;
            return this;
        }

        public SameDiffBuilder updaterState(INDArray value) {
            updaterState = value;
            return this;
        }

        public SameDiffBuilder updaterViews(Map<String, INDArray> value) {
            updaterViews = value;
            return this;
        }

        public SameDiffBuilder updaterMap(Map<String, GradientUpdater> value) {
            updaterMap = value;
            return this;
        }

        public SameDiffBuilder baseNameForFunctionInstanceId(Map<String, String> value) {
            baseNameForFunctionInstanceId = value;
            return this;
        }

        public SameDiffBuilder functionFactory(DifferentialFunctionFactory value) {
            functionFactory = value;
            return this;
        }

        @Deprecated
        public SameDiffBuilder variableNameToShape(Map<String, long[]> value) {
            variableNameToShape = value;
            return this;
        }

        @Deprecated
        public SameDiffBuilder forwardVarForGrad(Map<String, SDVariable> value) {
            forwardVarForGrad = value;
            return this;
        }

        public SameDiffBuilder variableId(int value) {
            variableId = value;
            return this;
        }

        public SameDiffBuilder propertiesToResolve(Map<String, List<String>> value) {
            propertiesToResolve = value;
            return this;
        }

        public SameDiffBuilder propertiesForFunction(Map<String, Map<String, Object>> value) {
            propertiesForFunction = value;
            return this;
        }

        @Deprecated
        public SameDiffBuilder placeHolderOriginalShapes(Map<String, long[]> value) {
            placeHolderOriginalShapes = value;
            return this;
        }

        public SameDiffBuilder sameDiffFunctionDefinitionMap(Map<String, SameDiffFunctionDefinition> value) {
            sameDiffFunctionDefinitionMap = value;
            return this;
        }

        public SameDiffBuilder sameDiffFunctionInstances(Map<String, SameDiff> value) {
            sameDiffFunctionInstances = value;
            return this;
        }

        public SameDiffBuilder placeHolderFunctions(Set<String> value) {
            placeHolderFunctions = value;
            return this;
        }

        public SameDiffBuilder fieldVariableResolutionMapping(Table<String, String, String> value) {
            fieldVariableResolutionMapping = value;
            return this;
        }

        public SameDiffBuilder wasRegistered(AtomicBoolean value) {
            wasRegistered = value;
            return this;
        }

        public SameDiffBuilder debugMode(boolean value) {
            debugMode = value;
            return this;
        }

        public SameDiffBuilder opsForResult(Map<int[], Op> value) {
            opsForResult = value;
            return this;
        }

        public SameDiffBuilder resolvedVariables(boolean value) {
            resolvedVariables = value;
            return this;
        }

        public SameDiffBuilder logExecution(boolean value) {
            logExecution = value;
            return this;
        }

        public SameDiffBuilder parent(SameDiff value) {
            parent = value;
            return this;
        }

        public SameDiffBuilder child(SameDiff value) {
            child = value;
            return this;
        }

        /** Creates a new {@link SameDiff} from every value recorded on this builder. */
        public SameDiff build() {
            // Argument order must match the SameDiff constructor exactly.
            return new SameDiff(
                    trainingConfig, initializedTraining, updaterState, updaterViews, updaterMap,
                    baseNameForFunctionInstanceId, functionFactory, variableNameToShape, forwardVarForGrad,
                    variableId, propertiesToResolve, propertiesForFunction, placeHolderOriginalShapes,
                    sameDiffFunctionDefinitionMap, sameDiffFunctionInstances, placeHolderFunctions,
                    fieldVariableResolutionMapping, wasRegistered, debugMode, opsForResult,
                    resolvedVariables, logExecution, parent, child);
        }

        @Override
        public String toString() {
            return new StringBuilder("SameDiff.SameDiffBuilder(")
                    .append("trainingConfig=").append(trainingConfig)
                    .append(", initializedTraining=").append(initializedTraining)
                    .append(", updaterState=").append(updaterState)
                    .append(", updaterViews=").append(updaterViews)
                    .append(", updaterMap=").append(updaterMap)
                    .append(", baseNameForFunctionInstanceId=").append(baseNameForFunctionInstanceId)
                    .append(", functionFactory=").append(functionFactory)
                    .append(", variableNameToShape=").append(variableNameToShape)
                    .append(", forwardVarForGrad=").append(forwardVarForGrad)
                    .append(", variableId=").append(variableId)
                    .append(", propertiesToResolve=").append(propertiesToResolve)
                    .append(", propertiesForFunction=").append(propertiesForFunction)
                    .append(", placeHolderOriginalShapes=").append(placeHolderOriginalShapes)
                    .append(", sameDiffFunctionDefinitionMap=").append(sameDiffFunctionDefinitionMap)
                    .append(", sameDiffFunctionInstances=").append(sameDiffFunctionInstances)
                    .append(", placeHolderFunctions=").append(placeHolderFunctions)
                    .append(", fieldVariableResolutionMapping=").append(fieldVariableResolutionMapping)
                    .append(", wasRegistered=").append(wasRegistered)
                    .append(", debugMode=").append(debugMode)
                    .append(", opsForResult=").append(opsForResult)
                    .append(", resolvedVariables=").append(resolvedVariables)
                    .append(", logExecution=").append(logExecution)
                    .append(", parent=").append(parent)
                    .append(", child=").append(child)
                    .append(')')
                    .toString();
        }
    }

    /** Returns the math-op namespace; the same object as the public {@code math} field. */
    public SDMath math() {
        return math;
    }

    /** Returns the random-op namespace; the same object as the public {@code random} field. */
    public SDRandom random() {
        return random;
    }

    /** Returns the neural-network-op namespace; the same object as the public {@code nn} field. */
    public SDNN nn() {
        return nn;
    }

    /** Returns the CNN-op namespace; the same object as the public {@code cnn} field. */
    public SDCNN cnn() {
        return cnn;
    }

    /** Returns the RNN-op namespace; the same object as the public {@code rnn} field. */
    public SDRNN rnn() {
        return rnn;
    }

    /** Returns the loss-op namespace; the same object as the public {@code loss} field. */
    public SDLoss loss() {
        return loss;
    }

    /**
     * Creates a {@link Cloner} with fast cloners for ND4J array and buffer types.
     *
     * <p>Registers an {@link INDArrayFastCloner} for the backend's NDArray class. It also shares one
     * {@link DataBufferFastCloner} across the int/long/half/float/double buffer classes reported by
     * the current {@link DataBufferFactory}, plus {@link CompressedDataBuffer}. Buffer classes the
     * factory reports as {@code null} are skipped.
     *
     * @return a newly configured cloner
     */
    public static Cloner newCloner() {
        Cloner result = new Cloner();
        result.registerFastCloner(Nd4j.getBackend().getNDArrayClass(), new INDArrayFastCloner());

        DataBufferFastCloner bufferCloner = new DataBufferFastCloner();
        DataBufferFactory factory = Nd4j.getDataBufferFactory();
        Class<?>[] bufferClasses = {
                factory.intBufferClass(),
                factory.longBufferClass(),
                factory.halfBufferClass(),
                factory.floatBufferClass(),
                factory.doubleBufferClass(),
                CompressedDataBuffer.class
        };
        for (Class<?> bufferClass : bufferClasses) {
            doReg(result, bufferCloner, bufferClass);
        }
        return result;
    }

    /**
     * Registers {@code iFastCloner} for {@code cls} on {@code cloner2}.
     * A {@code null} class is silently ignored.
     */
    private static void doReg(Cloner cloner2, IFastCloner iFastCloner, Class<?> cls) {
        if (cls == null) {
            return;
        }
        cloner2.registerFastCloner(cls, iFastCloner);
    }

    /**
     * Renames a variable of this graph from {@code str} to {@code str2}.
     *
     * <p>The following are updated:
     * <ul>
     *   <li>the {@link SDVariable} and {@link Variable} records;</li>
     *   <li>every op's input and output name lists;</li>
     *   <li>the legacy shape and forward-variable maps;</li>
     *   <li>the X/Y/Z vertex ids of any {@link BaseOp} that consumes or produces the variable;</li>
     *   <li>the name-keyed array stores (variable arrays, constant arrays, every thread's
     *       placeholder feeds) and the loss-variable name list.</li>
     * </ul>
     * The array stores are included so the renamed variable keeps its array.
     * Renaming a variable to its current name is a no-op.
     *
     * @param str  current variable name; must exist
     * @param str2 new variable name; must not already be used by another variable
     * @throws IllegalStateException if {@code str} does not exist or {@code str2} is already taken
     */
    public void updateVariableName(String str, String str2) {
        Preconditions.checkState(this.variables.containsKey(str),
                "Cannot rename variable \"%s\": no variable with this name exists", str);
        if (str.equals(str2)) {
            return;
        }
        // Without this check the put below would silently replace (and orphan) another variable.
        Preconditions.checkState(!this.variables.containsKey(str2),
                "Cannot rename variable \"%s\" to \"%s\": a variable with that name already exists", str, str2);

        SDVariable sdVariable = getVariable(str);
        Variable record = this.variables.remove(str);
        sdVariable.setVarName(str2);
        record.setName(str2);
        this.variables.put(str2, record);

        for (SameDiffOp sameDiffOp : this.ops.values()) {
            replaceNameInList(sameDiffOp.getOutputsOfOp(), str, str2);
            replaceNameInList(sameDiffOp.getInputsToOp(), str, str2);
        }

        if (this.variableNameToShape.containsKey(str)) {
            this.variableNameToShape.put(str2, this.variableNameToShape.remove(str));
        }
        if (this.forwardVarForGrad.containsKey(str)) {
            this.forwardVarForGrad.put(str2, this.forwardVarForGrad.remove(str));
        }

        // The array stores are keyed by variable name (see setArrayForVariable). Without moving the
        // entries, getArrForVarName(str2) would not find the existing array.
        if (this.variablesArrays.containsKey(str)) {
            this.variablesArrays.put(str2, this.variablesArrays.remove(str));
        }
        if (this.constantArrays.containsKey(str)) {
            this.constantArrays.put(str2, this.constantArrays.remove(str));
        }
        for (Map<String, INDArray> threadPlaceholders : this.placeholdersPerThread.values()) {
            if (threadPlaceholders.containsKey(str)) {
                threadPlaceholders.put(str2, threadPlaceholders.remove(str));
            }
        }
        replaceNameInList(this.lossVariables, str, str2);

        if (record.getInputsForOp() != null) {
            for (String opName : record.getInputsForOp()) {
                renameBaseOpVertexIds(this.ops.get(opName).getOp(), str, str2);
            }
        }
        if (record.getOutputOfOp() != null) {
            renameBaseOpVertexIds(this.ops.get(record.getOutputOfOp()).getOp(), str, str2);
        }
    }

    /** Replaces, in place, every element equal to {@code from} with {@code to}; a null list is ignored. */
    private static void replaceNameInList(List<String> names, String from, String to) {
        if (names == null) {
            return;
        }
        for (int i = 0; i < names.size(); i++) {
            if (from.equals(names.get(i))) {
                names.set(i, to);
            }
        }
    }

    /** If {@code op} is a {@link BaseOp}, renames whichever of its X/Y/Z vertex ids equal {@code from}. */
    private static void renameBaseOpVertexIds(DifferentialFunction op, String from, String to) {
        if (!(op instanceof BaseOp)) {
            return;
        }
        BaseOp baseOp = (BaseOp) op;
        if (from.equals(baseOp.getXVertexId())) {
            baseOp.setXVertexId(to);
        }
        if (from.equals(baseOp.getYVertexId())) {
            baseOp.setYVertexId(to);
        }
        if (from.equals(baseOp.getZVertexId())) {
            baseOp.setZVertexId(to);
        }
    }

    /**
     * Switches debug mode off.
     *
     * @return this instance, for call chaining
     */
    public SameDiff disableDebugging() {
        debugMode = false;
        return this;
    }

    /**
     * Switches debug mode on.
     *
     * @return this instance, for call chaining
     */
    public SameDiff enableDebugMode() {
        debugMode = true;
        return this;
    }

    /** Returns the {@link DifferentialFunctionFactory} held by this instance. */
    @Override // org.nd4j.autodiff.samediff.ops.SDBaseOps
    public DifferentialFunctionFactory f() {
        return functionFactory;
    }

    /**
     * Copies this graph's variables and ops into {@code sameDiff}.
     *
     * <p>Variables: each one is deep-cloned (without cloning its owning SameDiff) and registered
     * through {@code sameDiff.var(...)}. If it holds an array and is not of type ARRAY, that array is
     * associated with the new variable. ARRAY values live in per-thread inference sessions, as
     * {@code getArrForVarName} shows.
     *
     * <p>Ops: each op that is not an {@link SDVariable} is deep-cloned the same way, re-owned by
     * {@code sameDiff}, and wired into it.
     *
     * @param sameDiff target instance that receives the copied graph
     * @return the last variable in {@code sameDiff.variables()} after the copy
     */
    public SDVariable invokeGraphOn(SameDiff sameDiff) {
        for (SDVariable original : variables()) {
            SDVariable clone = (SDVariable) cloner.deepCloneDontCloneInstances(original, new Object[]{original.getSameDiff()});
            SDVariable registered = sameDiff.var(clone);
            if (original.getArr() != null && original.getVariableType() != VariableType.ARRAY) {
                sameDiff.associateArrayWithVariable(original.getArr(), registered);
            }
            clone.setSameDiff(sameDiff);
        }

        for (SameDiffOp sameDiffOp : this.ops.values()) {
            DifferentialFunction op = sameDiffOp.getOp();
            if (op instanceof SDVariable) {
                continue;
            }
            DifferentialFunction clone = (DifferentialFunction) cloner.deepCloneDontCloneInstances(op, new Object[]{op.getSameDiff()});
            clone.setSameDiff(sameDiff);
            clone.setOwnName(op.getOwnName());
            // NOTE(review): this only runs when the id already exists. In that case putFunctionForId
            // does nothing unless the existing entry's op is null (then it throws). Kept for
            // behavioral compatibility.
            if (sameDiff.functionExists(op.getOwnName())) {
                sameDiff.putFunctionForId(op.getOwnName(), op);
            }
            SDVariable[] args = op.args();
            SDVariable[] outputVariables = op.outputVariables();
            // NOTE(review): args are registered against the clone but outputs against the original op,
            // and the original SameDiffOp (wrapping the original op, not the clone) is stored below.
            // Confirm this is intended.
            sameDiff.addArgsFor(args, clone);
            sameDiff.addOutgoingFor(outputVariables, op);
            for (SDVariable arg : clone.args()) {
                arg.setSameDiff(sameDiff);
            }
            for (SDVariable output : clone.outputVariables()) {
                output.setSameDiff(sameDiff);
            }
            sameDiff.ops.put(op.getOwnName(), sameDiffOp);
        }
        return sameDiff.variables().get(sameDiff.variables().size() - 1);
    }

    /**
     * Tells whether an op named {@code str} is registered in this graph.
     *
     * @param str op name
     * @return {@code true} if the op map contains that name
     */
    public boolean functionExists(String str) {
        return ops.containsKey(str);
    }

    /**
     * Returns the op that produces variable {@code str}.
     *
     * @param str variable name; the variable must exist (otherwise a NullPointerException is raised)
     * @return the producing op, or {@code null} if the variable is not the output of any op
     */
    public DifferentialFunction functionOutputFor(String str) {
        String producerName = variables.get(str).getOutputOfOp();
        if (producerName == null) {
            return null;
        }
        return ops.get(producerName).getOp();
    }

    /**
     * Looks up a registered op by its name.
     *
     * @param str op name; must not be null
     * @return the op registered under {@code str}
     * @throws NullPointerException     if {@code str} is null
     * @throws ND4JIllegalStateException if no op with that name exists
     */
    public DifferentialFunction getFunctionById(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("id is marked @NonNull but is null");
        }
        if (!ops.containsKey(str)) {
            throw new ND4JIllegalStateException("No function with id " + str + " found!");
        }
        return ops.get(str).getOp();
    }

    /**
     * Registers {@code differentialFunction} under op name {@code str}, unless that name is already present.
     *
     * <p>Checks are applied in this order. If an entry already exists for {@code str} and its op is
     * {@code null}, an exception is thrown. If the function is an {@link SDVariable}, an exception is
     * thrown. If an entry already exists, nothing changes. Otherwise a new {@link SameDiffOp} wrapping
     * the function is stored.
     *
     * @throws ND4JIllegalStateException as described above
     */
    public void putFunctionForId(String str, DifferentialFunction differentialFunction) {
        boolean alreadyRegistered = ops.containsKey(str);
        if (alreadyRegistered && ops.get(str).getOp() == null) {
            throw new ND4JIllegalStateException("Function by id already exists!");
        }
        if (differentialFunction instanceof SDVariable) {
            throw new ND4JIllegalStateException("Function must not be a variable!");
        }
        if (!alreadyRegistered) {
            SameDiffOp entry = SameDiffOp.builder().name(str).op(differentialFunction).build();
            ops.put(str, entry);
        }
    }

    /**
     * Returns the names of the variables that feed the given op.
     *
     * @param differentialFunction a registered op
     * @return the input names, or {@code null} if the op has no recorded input list
     * @throws ND4JIllegalStateException if the op is not registered in this graph
     */
    public String[] getInputsForFunction(DifferentialFunction differentialFunction) {
        String opName = differentialFunction.getOwnName();
        if (!ops.containsKey(opName)) {
            throw new ND4JIllegalStateException("Illegal function instance id found " + opName);
        }
        List<String> inputNames = ops.get(opName).getInputsToOp();
        return inputNames == null ? null : inputNames.toArray(new String[0]);
    }

    /**
     * Returns the names of the variables produced by the given op.
     *
     * @param differentialFunction a registered op
     * @return the output names, or {@code null} if the op has no recorded output list
     * @throws ND4JIllegalStateException if the op is not registered in this graph
     */
    public String[] getOutputsForFunction(DifferentialFunction differentialFunction) {
        String opName = differentialFunction.getOwnName();
        if (!ops.containsKey(opName)) {
            throw new ND4JIllegalStateException("Illegal function instance id found " + opName);
        }
        List<String> outputNames = ops.get(opName).getOutputsOfOp();
        return outputNames == null ? null : outputNames.toArray(new String[0]);
    }

    /**
     * Resolves the output variables of the given op.
     *
     * @param differentialFunction a registered op
     * @return one {@link SDVariable} per output name, in output order. Entries are whatever
     *         {@code getVariable} returns for each name.
     * @throws ND4JIllegalStateException if the op is unregistered or has no recorded outputs
     */
    public SDVariable[] getOutputVariablesForFunction(DifferentialFunction differentialFunction) {
        String[] outputNames = getOutputsForFunction(differentialFunction);
        if (outputNames == null) {
            throw new ND4JIllegalStateException("No outputs found for function " + differentialFunction);
        }
        SDVariable[] result = new SDVariable[outputNames.length];
        for (int i = 0; i < outputNames.length; i++) {
            result[i] = getVariable(outputNames[i]);
        }
        return result;
    }

    /**
     * Resolves the input variables of the given op.
     *
     * @param differentialFunction a registered op
     * @return one {@link SDVariable} per input name, in input order
     * @throws ND4JIllegalStateException if the op is unregistered, has no recorded inputs, or an
     *                                   input name does not resolve to a variable
     */
    public SDVariable[] getInputVariablesForFunction(DifferentialFunction differentialFunction) {
        String[] inputNames = getInputsForFunction(differentialFunction);
        if (inputNames == null) {
            throw new ND4JIllegalStateException("No inputs found for function " + differentialFunction);
        }
        SDVariable[] resolved = new SDVariable[inputNames.length];
        int index = 0;
        for (String inputName : inputNames) {
            SDVariable candidate = getVariable(inputName);
            if (candidate == null) {
                throw new ND4JIllegalStateException("Found null variable at index " + index);
            }
            resolved[index++] = candidate;
        }
        return resolved;
    }

    /**
     * Stores {@code iNDArray} as the value of variable {@code str}.
     *
     * <p>Where the array goes depends on the variable's type:
     * <ul>
     *   <li>constants go to the constant store;</li>
     *   <li>VARIABLE-type variables go to the variable store;</li>
     *   <li>placeholders go to the calling thread's placeholder map, which is created on first use.</li>
     * </ul>
     *
     * @throws NullPointerException          if either argument is null
     * @throws IllegalStateException         if no variable named {@code str} exists
     * @throws UnsupportedOperationException for any other variable type
     */
    public void setArrayForVariable(@NonNull String str, @NonNull INDArray iNDArray) {
        if (str == null) {
            throw new NullPointerException("varName is marked @NonNull but is null");
        }
        if (iNDArray == null) {
            throw new NullPointerException("arr is marked @NonNull but is null");
        }
        Preconditions.checkState(variables.containsKey(str), "No variable with name \"%s\" exists", str);
        SDVariable target = getVariable(str);

        if (target.isConstant()) {
            constantArrays.put(str, new DeviceLocalNDArray(iNDArray));
        } else if (target.getVariableType() == VariableType.VARIABLE) {
            variablesArrays.put(str, new DeviceLocalNDArray(iNDArray));
        } else if (target.isPlaceHolder()) {
            Long threadId = Thread.currentThread().getId();
            Map<String, INDArray> threadFeeds = placeholdersPerThread.get(threadId);
            if (threadFeeds == null) {
                threadFeeds = new HashMap<String, INDArray>();
                placeholdersPerThread.put(threadId, threadFeeds);
            }
            threadFeeds.put(str, iNDArray);
        } else {
            throw new UnsupportedOperationException("Cannot set variable of type " + target.getVariableType() + " using this method");
        }
    }

    /**
     * Returns the shape of variable {@code str}.
     *
     * @return the current array's shape if an array exists, otherwise the legacy recorded shape
     *         (possibly {@code null})
     */
    public long[] getShapeForVarName(String str) {
        if (arrayAlreadyExistsForVarName(str)) {
            return getVariable(str).getArr().shape();
        }
        return variableNameToShape.get(str);
    }

    /**
     * Returns a shape descriptor for variable {@code str}.
     *
     * <p>If the variable has an array, that array's descriptor is returned. Otherwise a descriptor is
     * built from the legacy recorded shape and {@link Nd4j#dataType()}.
     * NOTE(review): the fallback uses the global default data type, not the variable's own; confirm
     * this is intended.
     */
    public LongShapeDescriptor getShapeDescriptorForVarName(String str) {
        // Fetch once: getArr() may resolve through a session lookup or allocate a VARIABLE's array,
        // so repeating it is wasted work.
        INDArray arr = getVariable(str).getArr();
        if (arr != null) {
            return arr.shapeDescriptor();
        }
        return LongShapeDescriptor.fromShape(this.variableNameToShape.get(str), Nd4j.dataType());
    }

    /**
     * Records a legacy shape for variable {@code str}.
     *
     * @throws ND4JIllegalStateException if {@code jArr} is null or a shape is already recorded for {@code str}
     */
    @Deprecated
    public void putShapeForVarName(String str, long[] jArr) {
        boolean missingShape = jArr == null;
        if (missingShape) {
            throw new ND4JIllegalStateException("Shape must not be null!");
        }
        boolean duplicate = variableNameToShape.containsKey(str);
        if (duplicate) {
            throw new ND4JIllegalStateException("Shape for " + str + " already exists!");
        }
        variableNameToShape.put(str, jArr);
    }

    /**
     * Records the shape from {@code longShapeDescriptor} for variable {@code str}.
     * It then sets the variable's data type to the descriptor's data type.
     */
    public void putShapeForVarName(String str, LongShapeDescriptor longShapeDescriptor) {
        SDVariable target = getVariable(str);
        long[] shape = longShapeDescriptor.getShape();
        putShapeForVarName(str, shape);
        target.setDataType(longShapeDescriptor.dataType());
    }

    /**
     * Records a legacy shape for {@code str} if none is recorded yet.
     *
     * <p>NOTE(review): despite the name, an existing shape is left untouched, and the boolean flag is
     * never read. The "update" path is not implemented.
     *
     * @throws NullPointerException (via Preconditions) if {@code jArr} is null
     */
    @Deprecated
    public void putOrUpdateShapeForVarName(String str, long[] jArr, boolean z) {
        Preconditions.checkNotNull(jArr, "Cannot put null shape for variable: %s", str);
        if (!variableNameToShape.containsKey(str)) {
            putShapeForVarName(str, jArr);
        }
    }

    /**
     * Tells whether a shape is known for {@code str}, either as a legacy recorded shape or through
     * an existing array.
     */
    public boolean shapeAlreadyExistsForVarName(String str) {
        if (variableNameToShape.containsKey(str)) {
            return true;
        }
        return arrayAlreadyExistsForVarName(str);
    }

    /**
     * Tells whether an array is currently stored for variable {@code str}.
     *
     * <p>The store that is checked depends on the variable's type:
     * <ul>
     *   <li>VARIABLE and CONSTANT: their name-keyed stores;</li>
     *   <li>ARRAY: the calling thread's inference session;</li>
     *   <li>PLACEHOLDER: the calling thread's placeholder map.</li>
     * </ul>
     *
     * @throws RuntimeException for an unrecognized variable type
     */
    public boolean arrayAlreadyExistsForVarName(String str) {
        VariableType type = getVariable(str).getVariableType();
        Long threadId = Thread.currentThread().getId();
        switch (type) {
            case VARIABLE:
                return variablesArrays.containsKey(str);
            case CONSTANT:
                return constantArrays.containsKey(str);
            case ARRAY: {
                InferenceSession session = sessions.get(threadId);
                return session != null && session.contains(str, AbstractSession.OUTER_FRAME, 0, null);
            }
            case PLACEHOLDER: {
                Map<String, INDArray> threadFeeds = placeholdersPerThread.get(threadId);
                return threadFeeds != null && threadFeeds.containsKey(str);
            }
            default:
                throw new RuntimeException("Unknown variable type: " + type);
        }
    }

    /**
     * Returns the array currently held for variable {@code str}, dispatching on its type:
     * <ul>
     *   <li>VARIABLE: stored array; if none is stored yet, {@code storeAndAllocateNewArray()} is
     *       called first (presumably it populates the store; verify).</li>
     *   <li>ARRAY: value from the calling thread's inference session, or {@code null} without one.</li>
     *   <li>CONSTANT: stored array, or {@code null}.</li>
     *   <li>PLACEHOLDER: the calling thread's feed, or {@code null}.</li>
     * </ul>
     *
     * @throws NullPointerException  if {@code str} is null
     * @throws IllegalStateException if no variable named {@code str} exists
     * @throws RuntimeException      for an unrecognized variable type
     */
    public INDArray getArrForVarName(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("varName is marked @NonNull but is null");
        }
        Preconditions.checkState(variables.containsKey(str), "No variable found with name \"%s\"", str);
        SDVariable target = variables.get(str).getVariable();
        VariableType type = target.getVariableType();
        switch (type) {
            case VARIABLE:
                if (!variablesArrays.containsKey(str)) {
                    target.storeAndAllocateNewArray();
                }
                return variablesArrays.get(str).get();
            case ARRAY: {
                InferenceSession session = sessions.get(Thread.currentThread().getId());
                if (session == null) {
                    return null;
                }
                return session.get(str, AbstractSession.OUTER_FRAME, 0, null, false);
            }
            case CONSTANT: {
                DeviceLocalNDArray stored = constantArrays.get(str);
                return stored == null ? null : stored.get();
            }
            case PLACEHOLDER: {
                Map<String, INDArray> threadFeeds = placeholdersPerThread.get(Thread.currentThread().getId());
                return threadFeeds == null ? null : threadFeeds.get(str);
            }
            default:
                throw new RuntimeException("Unknown variable type: " + type);
        }
    }

    /**
     * Name-based overload: resolves variable {@code str} and delegates to
     * {@code associateArrayWithVariable(INDArray, SDVariable)}.
     *
     * @throws NullPointerException  if {@code str} is null
     * @throws IllegalStateException if no variable named {@code str} exists
     */
    public void associateArrayWithVariable(INDArray iNDArray, @NonNull String str) {
        if (str == null) {
            throw new NullPointerException("variable is marked @NonNull but is null");
        }
        boolean known = variables.containsKey(str);
        Preconditions.checkState(known, "Cannot associate array with variable \"%s\": variable \"%s\" does not exist in this SameDiff instance", str, str);
        SDVariable target = getVariable(str);
        associateArrayWithVariable(iNDArray, target);
    }

    /**
     * Presumably binds array {@code r7} as the value of variable {@code r8}; the name suggests this,
     * but the intended behavior cannot be confirmed from this file.
     *
     * <p>NOTE(review): the decompiler could not recover this method's body (see the JADX notes below).
     * As it stands it ALWAYS throws {@link UnsupportedOperationException}. Every caller fails at
     * runtime, including:
     * <ul>
     *   <li>{@code associateArrayWithVariable(INDArray, String)};</li>
     *   <li>{@code invokeGraphOn(SameDiff)}, for any non-ARRAY variable that holds an array.</li>
     * </ul>
     * The original implementation must be restored from source before this class is usable.
     */
    /* JADX WARN: Can't fix incorrect switch cases order, some code will duplicate */
    /* JADX WARN: Failed to find 'out' block for switch in B:39:0x0104. Please report as an issue. */
    /* JADX WARN: Removed duplicated region for block: B:43:0x01fd  */
    /* JADX WARN: Removed duplicated region for block: B:58:0x0256 A[ADDED_TO_REGION, ORIG_RETURN, RETURN] */
    /* JADX WARN: Removed duplicated region for block: B:63:0x01a3  */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    public void associateArrayWithVariable(org.nd4j.linalg.api.ndarray.INDArray r7, org.nd4j.autodiff.samediff.SDVariable r8) {
        /*
            Method dump skipped, instructions count: 599
            To view this dump add '--comments-level debug' option
        */
        // Placeholder left by the decompiler: this unconditional throw is the entire current body.
        throw new UnsupportedOperationException("Method not decompiled: org.nd4j.autodiff.samediff.SameDiff.associateArrayWithVariable(org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.autodiff.samediff.SDVariable):void");
    }

    /**
     * Registers a sub-function instance under the given name.
     *
     * <p>Re-registering the exact same instance is a no-op replacement. Registering a
     * different instance under an existing name is rejected.
     *
     * @param str the sub-function name
     * @param sameDiff the instance to register
     * @throws ND4JIllegalStateException if a different instance is already registered under {@code str}
     */
    public void putSubFunction(String str, SameDiff sameDiff) {
        boolean clashes = this.sameDiffFunctionInstances.containsKey(str)
                && this.sameDiffFunctionInstances.get(str) != sameDiff;
        if (clashes) {
            throw new ND4JIllegalStateException("Unable to replace samediff namespace. Please choose another opName");
        }
        this.sameDiffFunctionInstances.put(str, sameDiff);
    }

    /**
     * Builds a fresh map from variable name to {@link SDVariable}.
     *
     * <p>The map preserves the iteration order of this instance's variable table.
     *
     * @return a new, mutable name-to-variable map; later changes to it do not affect this instance
     */
    public Map<String, SDVariable> variableMap() {
        Map<String, SDVariable> byName = new LinkedHashMap<>();
        this.variables.values().forEach(meta -> byName.put(meta.getName(), meta.getVariable()));
        return byName;
    }

    /**
     * Reflectively invokes the op method registered under {@code op.opName()}.
     *
     * <p>If either variable is null, the one-argument form is called with
     * {@code sDVariable}. Otherwise the two-argument form is called.
     *
     * @param op the op whose name selects the method to call
     * @param sDVariable first argument
     * @param sDVariable2 optional second argument
     * @return the variable produced by the invoked method
     * @throws ND4JIllegalStateException if no method is registered for the op name, or if the
     *         reflective call fails. In the failure case the underlying exception is attached
     *         as the cause instead of being discarded.
     * @deprecated reflective dispatch by op name
     */
    @Deprecated
    public SDVariable invoke(Op op, SDVariable sDVariable, SDVariable sDVariable2) {
        String opName = op.opName();
        if (!opMethods.containsKey(opName)) {
            throw new ND4JIllegalStateException("Illegal method opName " + opName);
        }
        try {
            if (sDVariable == null || sDVariable2 == null) {
                return (SDVariable) opMethods.get(opName).invoke(this, sDVariable);
            }
            return (SDVariable) opMethods.get(opName).invoke(this, sDVariable, sDVariable2);
        } catch (Exception e) {
            // Previously swallowed. Keep the same exception type and message, but preserve the
            // cause so reflection and argument errors are diagnosable.
            throw new ND4JIllegalStateException("Illegal method opName " + opName, e);
        }
    }

    /**
     * Returns the names of all registered sub-function instances.
     *
     * @return the key set of the sub-function map. This is a live view, so later
     *         registrations are visible through it.
     */
    public Collection<String> definedFunctionNames() {
        return sameDiffFunctionInstances.keySet();
    }

    /**
     * Single-argument convenience form of {@link #invoke(Op, SDVariable, SDVariable)}.
     *
     * @param op the op whose name selects the method to call
     * @param sDVariable the sole argument
     * @return the variable produced by the invoked method
     */
    public SDVariable invoke(Op op, SDVariable sDVariable) {
        SDVariable noSecondArg = null;
        return this.invoke(op, sDVariable, noSecondArg);
    }

    /**
     * Creates an empty graph: no variables, no ops, no loss variables.
     *
     * <p>Every bookkeeping collection is initialised empty. The op namespaces
     * (math, random, nn, cnn, rnn, loss) and the function factory are constructed with a
     * reference to this instance. Use {@link #create()} to obtain an instance.
     */
    private SameDiff() {
        // Core graph structure: variables and ops
        this.variables = new LinkedHashMap();
        this.ops = new LinkedHashMap();
        // Concurrent maps. NOTE(review): presumably for multi-threaded access; e.g. the
        // placeholder lookup elsewhere in this class keys placeholdersPerThread by thread id.
        this.sessions = new ConcurrentHashMap();
        this.constantArrays = new ConcurrentHashMap();
        this.variablesArrays = new ConcurrentHashMap();
        this.placeholdersPerThread = new ConcurrentHashMap();
        this.lossVariables = new ArrayList();
        this.variableId = 0;
        // Op namespaces, each bound to this instance
        this.math = new SDMath(this);
        this.random = new SDRandom(this);
        this.nn = new SDNN(this);
        this.cnn = new SDCNN(this);
        this.rnn = new SDRNN(this);
        this.loss = new SDLoss(this);
        this.wasRegistered = new AtomicBoolean(false);
        this.resolvedVariables = false;
        this.logExecution = true;
        this.functionFactory = new DifferentialFunctionFactory(this);
        // Function/property bookkeeping
        this.sameDiffFunctionDefinitionMap = new LinkedHashMap();
        this.sameDiffFunctionInstances = new LinkedHashMap();
        this.forwardVarForGrad = new LinkedHashMap();
        this.opsForResult = new IntArrayKeyMap();
        this.variableNameToShape = new LinkedHashMap();
        this.placeHolderOriginalShapes = new LinkedHashMap();
        this.placeHolderFunctions = new LinkedHashSet();
        this.baseNameForFunctionInstanceId = new LinkedHashMap();
        this.propertiesToResolve = new LinkedHashMap();
        this.propertiesForFunction = new LinkedHashMap();
        // (function own-name, field name) -> variable name
        this.fieldVariableResolutionMapping = HashBasedTable.create();
    }

    /**
     * Records that property {@code str} of the given function still needs to be resolved.
     *
     * @param differentialFunction the owning function, keyed by its own name
     * @param str the property name to append to that function's pending list
     */
    public void addPropertyToResolve(DifferentialFunction differentialFunction, String str) {
        String fnName = differentialFunction.getOwnName();
        if (!this.propertiesToResolve.containsKey(fnName)) {
            List<String> pending = new ArrayList<>();
            pending.add(str);
            this.propertiesToResolve.put(fnName, pending);
            return;
        }
        this.propertiesToResolve.get(fnName).add(str);
    }

    /**
     * Returns the properties still awaiting resolution for a function.
     *
     * @param differentialFunction the function, keyed by its own name
     * @return the stored (mutable) list, or an immutable empty list if none were recorded
     */
    public List<String> propertiesToResolveForFunction(DifferentialFunction differentialFunction) {
        String fnName = differentialFunction.getOwnName();
        if (this.propertiesToResolve.containsKey(fnName)) {
            return this.propertiesToResolve.get(fnName);
        }
        return Collections.emptyList();
    }

    /**
     * Stores a property value for a function, refusing to overwrite an existing entry.
     *
     * @param differentialFunction the owning function, keyed by its own name
     * @param str property name
     * @param obj property value
     * @throws ND4JIllegalStateException if the property is already set for this function
     */
    private void addPropertyForFunction(DifferentialFunction differentialFunction, String str, Object obj) {
        String fnName = differentialFunction.getOwnName();
        if (this.propertiesForFunction.containsKey(fnName)) {
            Map<String, Object> existing = this.propertiesForFunction.get(fnName);
            if (existing.containsKey(str)) {
                throw new ND4JIllegalStateException("Attempting to override property " + str);
            }
            existing.put(str, obj);
            return;
        }
        Map<String, Object> created = new LinkedHashMap<>();
        created.put(str, obj);
        this.propertiesForFunction.put(fnName, created);
    }

    /**
     * Maps a (function, field) pair to the name of the variable that supplies that field.
     *
     * @param differentialFunction the owning function, keyed by its own name
     * @param str the field name
     * @param str2 the variable name
     */
    public void addVariableMappingForField(DifferentialFunction differentialFunction, String str, String str2) {
        String fnName = differentialFunction.getOwnName();
        this.fieldVariableResolutionMapping.put(fnName, str, str2);
    }

    /**
     * Looks up the variable name mapped to a (function, field) pair.
     *
     * @param differentialFunction the owning function, keyed by its own name
     * @param str the field name
     * @return the mapped variable name, or null if none was recorded
     */
    public String getVarNameForFieldAndFunction(DifferentialFunction differentialFunction, String str) {
        String fnName = differentialFunction.getOwnName();
        Object mapped = this.fieldVariableResolutionMapping.get(fnName, str);
        return (String) mapped;
    }

    /**
     * Records the base name associated with a function instance.
     *
     * @param str the base name
     * @param differentialFunction the function, keyed by its own name
     */
    public void setBaseNameForFunctionInstanceId(String str, DifferentialFunction differentialFunction) {
        String fnName = differentialFunction.getOwnName();
        this.baseNameForFunctionInstanceId.put(fnName, str);
    }

    /**
     * Returns the base name recorded for a function instance.
     *
     * @param differentialFunction the function, keyed by its own name
     * @return the recorded base name, or null if none was set
     */
    public String getBaseNameForFunction(DifferentialFunction differentialFunction) {
        String fnName = differentialFunction.getOwnName();
        return this.baseNameForFunctionInstanceId.get(fnName);
    }

    /**
     * Ensures the given variable is owned by this SameDiff instance.
     *
     * @param x a non-null variable
     * @return {@code x}, re-parented to this instance if it belonged elsewhere
     */
    public <X extends SDVariable> X setupFunction(X x) {
        Preconditions.checkNotNull(x, "Passed in function must not be null!");
        // x is non-null here and statically an SDVariable, so the only thing left is ownership.
        if (x.getSameDiff() == this) {
            return x;
        }
        x.setSameDiff(this);
        return x;
    }

    /**
     * Declares the given variables as the outputs of a function.
     *
     * <p>Validates its input the same way {@code addArgsFor(SDVariable[], ...)} does: a null
     * element is reported with its index instead of surfacing as a bare NullPointerException.
     *
     * @param sDVariableArr the output variables, in order
     * @param differentialFunction the producing function
     * @throws ND4JIllegalStateException if the array or any element is null, or if the
     *         String-based overload rejects the names
     */
    public void addOutgoingFor(SDVariable[] sDVariableArr, DifferentialFunction differentialFunction) {
        if (sDVariableArr == null) {
            throw new ND4JIllegalStateException("Var names can not be null!");
        }
        String[] strArr = new String[sDVariableArr.length];
        for (int i = 0; i < strArr.length; i++) {
            if (sDVariableArr[i] == null) {
                throw new ND4JIllegalStateException("Found null variable at index " + i);
            }
            strArr[i] = sDVariableArr[i].getVarName();
        }
        addOutgoingFor(strArr, differentialFunction);
    }

    /**
     * Declares the named variables as the outputs of a function.
     *
     * <p>All inputs are validated before any state is mutated. A bad name therefore can no
     * longer leave the op's output list set while some variables are left unlinked.
     *
     * @param strArr names of the output variables, in order
     * @param differentialFunction the producing function; it must already be registered as an op
     * @throws ND4JIllegalStateException if the function has no name or is not registered, if
     *         outputs were already declared for it, if the name array or any name is null, or
     *         if a name does not refer to an existing variable
     */
    public void addOutgoingFor(String[] strArr, DifferentialFunction differentialFunction) {
        String opName = differentialFunction.getOwnName();
        if (opName == null) {
            throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
        }
        SameDiffOp op = this.ops.get(opName);
        if (op == null) {
            // Previously an unexplained NullPointerException
            throw new ND4JIllegalStateException("No op with name \"" + opName + "\" is registered in this SameDiff instance: " + differentialFunction);
        }
        List<String> declared = op.getOutputsOfOp();
        if (declared != null && !declared.isEmpty()) {
            throw new ND4JIllegalStateException("Outgoing arguments already declared for " + differentialFunction);
        }
        if (strArr == null) {
            throw new ND4JIllegalStateException("Var names can not be null!");
        }
        for (String name : strArr) {
            if (name == null) {
                throw new ND4JIllegalStateException("Variable name elements can not be null!");
            }
            if (this.variables.get(name) == null) {
                throw new ND4JIllegalStateException("No variable with name \"" + name + "\" exists; cannot declare it as an output of op \"" + opName + "\"");
            }
        }
        op.setOutputsOfOp(Arrays.asList(strArr));
        for (String name : strArr) {
            this.variables.get(name).setOutputOfOp(opName);
        }
    }

    /**
     * Declares the named variables as the inputs of a function.
     *
     * <p>Registers the op if it is not yet known and records the input list on it. Each
     * input variable then lists this op among its consumers, without duplicates. If any
     * input is a placeholder, the op is added to the placeholder-function set. All names are
     * validated up front so that a bad name cannot leave partial state behind.
     *
     * @param strArr names of the input variables, in order
     * @param differentialFunction the consuming function
     * @throws ND4JIllegalStateException if the function has no name, if the name array or any
     *         name is null, or if a name does not refer to an existing variable
     */
    public void addArgsFor(String[] strArr, DifferentialFunction differentialFunction) {
        String opName = differentialFunction.getOwnName();
        if (opName == null) {
            throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
        }
        if (strArr == null) {
            throw new ND4JIllegalStateException("Var names can not be null!");
        }
        for (String name : strArr) {
            if (name == null) {
                throw new ND4JIllegalStateException("Variable name elements can not be null!");
            }
            if (this.variables.get(name) == null) {
                throw new ND4JIllegalStateException("No variable with name \"" + name + "\" exists; cannot use it as an input to op \"" + opName + "\"");
            }
        }
        for (String name : strArr) {
            if (isPlaceHolder(name)) {
                this.placeHolderFunctions.add(opName);
            }
        }
        SameDiffOp op = this.ops.get(opName);
        if (op == null) {
            op = SameDiffOp.builder().name(opName).op(differentialFunction).build();
            this.ops.put(opName, op);
        }
        op.setInputsToOp(Arrays.asList(strArr));
        for (String name : strArr) {
            Variable meta = this.variables.get(name);
            List<String> consumers = meta.getInputsForOp();
            if (consumers == null) {
                consumers = new ArrayList<>();
                meta.setInputsForOp(consumers);
            }
            if (!consumers.contains(opName)) {
                consumers.add(opName);
            }
        }
    }

    /**
     * Declares the given variables as the inputs of a function.
     *
     * @param sDVariableArr the input variables, in order
     * @param differentialFunction the consuming function
     * @throws ND4JIllegalStateException if any element is null (reported with its index)
     */
    public void addArgsFor(SDVariable[] sDVariableArr, DifferentialFunction differentialFunction) {
        String[] names = new String[sDVariableArr.length];
        int idx = 0;
        for (SDVariable v : sDVariableArr) {
            if (v == null) {
                throw new ND4JIllegalStateException("Found null variable at index " + idx);
            }
            names[idx++] = v.getVarName();
        }
        addArgsFor(names, differentialFunction);
    }

    /**
     * Returns the function that produces the named variable.
     *
     * @param str variable name; it must exist in this instance
     * @return the producing function, or null if the variable is not the output of any op
     */
    public DifferentialFunction getVariableOutputFunction(String str) {
        Preconditions.checkState(this.variables.containsKey(str), "No variable with name \"%s\" found in graph", str);
        String producer = this.variables.get(str).getOutputOfOp();
        return producer == null ? null : this.ops.get(producer).getOp();
    }

    /**
     * Reports whether a registered function has at least one declared input.
     *
     * @param differentialFunction a function already registered as an op
     * @return true if its input list is non-null and non-empty
     */
    public boolean hasArgs(DifferentialFunction differentialFunction) {
        List<String> declaredInputs = this.ops.get(differentialFunction.getOwnName()).getInputsToOp();
        if (declaredInputs == null) {
            return false;
        }
        return !declaredInputs.isEmpty();
    }

    /**
     * Returns the functions of all registered ops.
     *
     * @return a new array in op-table iteration order
     */
    public DifferentialFunction[] functions() {
        return this.ops.values().stream()
                .map(SameDiffOp::getOp)
                .toArray(DifferentialFunction[]::new);
    }

    /**
     * Hash code consistent with {@link #equals(Object)}.
     *
     * <p>{@code equals} compares the variable table, the function-definition map and the
     * sub-function map. Two equal instances therefore always have equal variable tables, so
     * hashing the variable table alone satisfies the contract. The previous version also
     * mixed in {@code super.hashCode()}. That value is Object's identity hash unless a
     * superclass overrides it, so equal instances could hash differently, which breaks hashed
     * collections.
     *
     * @return a hash derived from the variable table (0 if it is null)
     */
    @Override
    public int hashCode() {
        return this.variables != null ? this.variables.hashCode() : 0;
    }

    /**
     * Creates a new instance seeded from an existing one.
     *
     * <p>The new instance receives the source's sub-function map through the builder
     * (NOTE(review): probably shared rather than copied; confirm with the builder). It also
     * gets all of the source's variable-table entries, which reference the same
     * {@code Variable} objects, and a fresh function factory bound to itself.
     *
     * @param sameDiff the source instance
     * @return the new instance
     */
    public static SameDiff create(SameDiff sameDiff) {
        SameDiff copy = builder()
                .sameDiffFunctionInstances(sameDiff.sameDiffFunctionInstances)
                .build();
        copy.variables.putAll(sameDiff.variables);
        copy.functionFactory = new DifferentialFunctionFactory(copy);
        return copy;
    }

    /**
     * Two instances are equal when they have the same runtime class and equal variable
     * tables, function-definition maps and sub-function maps. Null maps are equal only to
     * null maps.
     *
     * @param obj the object to compare against
     * @return whether {@code obj} is an equal SameDiff
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj == null || obj.getClass() != getClass()) {
            return false;
        }
        SameDiff other = (SameDiff) obj;
        boolean sameVariables = this.variables == null
                ? other.variables == null
                : this.variables.equals(other.variables);
        if (!sameVariables) {
            return false;
        }
        boolean sameDefinitions = this.sameDiffFunctionDefinitionMap == null
                ? other.sameDiffFunctionDefinitionMap == null
                : this.sameDiffFunctionDefinitionMap.equals(other.sameDiffFunctionDefinitionMap);
        if (!sameDefinitions) {
            return false;
        }
        if (this.sameDiffFunctionInstances == null) {
            return other.sameDiffFunctionInstances == null;
        }
        return this.sameDiffFunctionInstances.equals(other.sameDiffFunctionInstances);
    }

    /**
     * Creates a new, empty SameDiff graph.
     *
     * @return a fresh instance with no variables or ops
     */
    public static SameDiff create() {
        SameDiff fresh = new SameDiff();
        return fresh;
    }

    /**
     * Returns a deep copy of this graph.
     *
     * <p>Execution sessions are not carried over: the copy's session map is emptied.
     *
     * @return the deep-cloned instance
     */
    public SameDiff dup() {
        SameDiff copy = (SameDiff) newCloner().deepClone(this);
        copy.sessions.clear();
        return copy;
    }

    /**
     * Returns the total number of elements across all variables that have a known shape.
     *
     * <p>The decompiled version referenced an undefined local ({@code r0}) and did not
     * compile. It also used an int-valued product, which can overflow for large shapes. This
     * version multiplies each shape out as a long.
     *
     * @return sum over variables with non-null shape of the product of their dimensions
     */
    public long numElements() {
        long total = 0;
        for (SDVariable v : variables()) {
            long[] shape = v.getShape();
            if (shape != null) {
                total += ArrayUtil.prodLong(shape);
            }
        }
        return total;
    }

    /**
     * Returns the names of all placeholder variables, in variable-table order.
     *
     * @return a new mutable list of placeholder names
     */
    public List<String> inputs() {
        List<String> placeholders = new ArrayList<>();
        for (String name : this.variables.keySet()) {
            if (!isPlaceHolder(name)) {
                continue;
            }
            placeholders.add(name);
        }
        return placeholders;
    }

    /**
     * Returns the names of the graph's "terminal" variables.
     *
     * <p>A terminal variable meets all of these conditions:
     * <ul>
     *   <li>it is not a constant or placeholder;</li>
     *   <li>it is not consumed as an input, op control dependency or variable control
     *       dependency;</li>
     *   <li>it is not produced by an {@code Assert} or {@code Switch} op.</li>
     * </ul>
     *
     * <p>The decompiled version computed the {@code Assert}/{@code Switch} test with an empty
     * body (a {@code continue} was lost). As a result, those ops' outputs were wrongly
     * included. They are now excluded.
     *
     * @return a new mutable list of output names, in variable-table order
     */
    public List<String> outputs() {
        List<String> out = new ArrayList<>();
        for (Variable v : this.variables.values()) {
            SDVariable sdv = v.getVariable();
            if (sdv.isConstant() || sdv.isPlaceHolder()) {
                continue;
            }
            if (v.getInputsForOp() != null && !v.getInputsForOp().isEmpty()) {
                continue;
            }
            if (v.getControlDepsForOp() != null && !v.getControlDepsForOp().isEmpty()) {
                continue;
            }
            if (v.getControlDepsForVar() != null && !v.getControlDepsForVar().isEmpty()) {
                continue;
            }
            if (v.getOutputOfOp() != null) {
                DifferentialFunction producer = this.ops.get(v.getOutputOfOp()).getOp();
                // Assert outputs are not meaningful results. An unconsumed Switch branch may
                // never be produced at execution time.
                if (producer instanceof Assert || producer instanceof Switch) {
                    continue;
                }
            }
            out.add(v.getName());
        }
        return out;
    }

    /**
     * Returns all variables in variable-table order.
     *
     * @return a new mutable list
     */
    public List<SDVariable> variables() {
        Map<String, SDVariable> byName = variableMap();
        return new ArrayList<>(byName.values());
    }

    /**
     * Returns the names of the variables marked as losses.
     *
     * @return an unmodifiable view of the loss-variable list
     */
    public List<String> getLossVariables() {
        List<String> view = Collections.unmodifiableList(lossVariables);
        return view;
    }

    /**
     * Replaces the set of loss variables.
     *
     * <p>Each name goes through {@link #addLossVariable(String)} validation. Afterwards the
     * cached {@code "grad"} sub-function is dropped from the sub-function map.
     *
     * @param strArr the new loss-variable names
     */
    public void setLossVariables(String... strArr) {
        this.lossVariables.clear();
        Arrays.stream(strArr).forEach(this::addLossVariable);
        this.sameDiffFunctionInstances.remove("grad");
    }

    /**
     * Marks an existing floating-point ARRAY variable as a loss to be minimised.
     * Adding a name that is already present has no effect.
     *
     * @param str the variable name
     * @throws NullPointerException if {@code str} is null
     */
    public void addLossVariable(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("variableName is marked @NonNull but is null");
        }
        Preconditions.checkState(hasVariable(str), "No variable with name \"%s\" exists", str);
        SDVariable candidate = getVariable(str);
        Preconditions.checkState(candidate.dataType().isFPType(), "Only floating point type variables can be marked as losses to be minimized. SDVariable \"%s\" has datatype %s", str, candidate.dataType());
        Preconditions.checkState(candidate.getVariableType() == VariableType.ARRAY, "Only ARRAY type SDVariables can be marked as losses to be minimized. SDVariable \"%s\" has variable type %s", str, candidate.getVariableType());
        if (!this.lossVariables.contains(str)) {
            this.lossVariables.add(str);
        }
    }

    /**
     * Sets the training configuration used by {@code fit}, {@code evaluate} and {@code output}.
     *
     * @param trainingConfig the configuration to store
     */
    public void setTrainingConfig(TrainingConfig trainingConfig) {
        TrainingConfig config = trainingConfig;
        this.trainingConfig = config;
    }

    /**
     * Trains for a single pass over one DataSet. The epoch counter is not incremented.
     *
     * @param dataSet the minibatch to train on
     */
    public void fit(DataSet dataSet) {
        this.fit(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), 1, false);
    }

    /**
     * Trains for a single pass over one MultiDataSet. The epoch counter is not incremented.
     *
     * @param multiDataSet the minibatch to train on
     */
    public void fit(MultiDataSet multiDataSet) {
        SingletonMultiDataSetIterator single = new SingletonMultiDataSetIterator(multiDataSet);
        this.fit(single, 1, false);
    }

    /**
     * Trains for {@code i} epochs over a DataSetIterator, incrementing the epoch counter.
     *
     * @param dataSetIterator the data source (adapted to a MultiDataSetIterator)
     * @param i number of epochs
     */
    public void fit(DataSetIterator dataSetIterator, int i) {
        MultiDataSetIterator adapted = new MultiDataSetIteratorAdapter(dataSetIterator);
        this.fit(adapted, i, true);
    }

    /**
     * Trains for {@code i} epochs over a MultiDataSetIterator, incrementing the epoch counter.
     *
     * @param multiDataSetIterator the data source
     * @param i number of epochs
     */
    public void fit(MultiDataSetIterator multiDataSetIterator, int i) {
        final boolean countEpochs = true;
        this.fit(multiDataSetIterator, i, countEpochs);
    }

    /**
     * Core training loop shared by the public {@code fit} overloads.
     *
     * <p>For each epoch, every minibatch from the iterator is processed as follows:
     * <ol>
     *   <li>it is mapped to placeholders;</li>
     *   <li>gradients are computed with {@code execBackwards};</li>
     *   <li>each trainable parameter that has a gradient is updated in place. The update
     *       applies BEFORE_UPDATER regularization, then the parameter's GradientUpdater, then
     *       POST_UPDATER regularization, and finally subtracts the gradient (or adds it when
     *       maximising).</li>
     * </ol>
     *
     * <p>If a trainable parameter has no initialised updater, the previous code hit a
     * NullPointerException inside the {@code try}. The {@code catch} handler then threw a
     * second NPE while formatting its message ({@code gradientUpdater.getClass()}), which hid
     * the real problem. That case is now reported explicitly.
     *
     * @param multiDataSetIterator source of minibatches; must support reset when {@code i > 1}
     * @param i number of epochs to train for; must be positive
     * @param z whether to increment the training config's epoch counter after each epoch
     * @throws RuntimeException wrapping any failure raised while updating a parameter
     */
    protected synchronized void fit(MultiDataSetIterator multiDataSetIterator, int i, boolean z) {
        Preconditions.checkNotNull(multiDataSetIterator, "Iterator must not be null");
        Preconditions.checkState(i > 0, "Number of training epochs must be a positive number. Got: %s", i);
        Preconditions.checkState(this.trainingConfig != null, "No training configuration has been set. A training configuration must be set before training. Use setTrainingConfig(TrainingConfig)");
        Preconditions.checkState(i == 1 || multiDataSetIterator.resetSupported(), "Cannot train for multiple epochs on an iterator that does not support resetting");
        // Rewind an already-exhausted iterator when possible so that training sees data
        if (!multiDataSetIterator.hasNext() && multiDataSetIterator.resetSupported()) {
            multiDataSetIterator.reset();
        }
        // Mapping sizes are validated once, against the first minibatch
        boolean mappingsValidated = false;
        for (int epoch = 0; epoch < i; epoch++) {
            while (multiDataSetIterator.hasNext()) {
                MultiDataSet batch = multiDataSetIterator.next();
                if (!mappingsValidated) {
                    int numFeatureMappings = this.trainingConfig.getDataSetFeatureMapping().size();
                    Preconditions.checkState(numFeatureMappings == batch.numFeatureArrays(), "The number of dataset feature mapping variables set in the training configuration (%s) must match the number of dataset feature arrays (%s)", numFeatureMappings, batch.numFeatureArrays());
                    List<String> labelMapping = this.trainingConfig.getDataSetLabelMapping();
                    int numLabelMappings = labelMapping == null ? 0 : labelMapping.size();
                    Preconditions.checkState(numLabelMappings == batch.numLabelsArrays(), "The number of dataset label mapping variables set in the training configuration (%s) must match the number of dataset label arrays (%s)", numLabelMappings, batch.numLabelsArrays());
                    mappingsValidated = true;
                }
                Map<String, INDArray> placeholders = toPlaceholderMap(batch);
                Preconditions.checkState(placeholders.size() > 0, "No placeholder variables were set for training");
                resolveVariablesWith(placeholders);
                execBackwards(placeholders);
                if (!this.initializedTraining) {
                    initializeTraining();
                }
                int iteration = this.trainingConfig.getIterationCount();
                int epochCount = this.trainingConfig.getEpochCount();
                for (String paramName : this.trainingConfig.getTrainableParams()) {
                    INDArray param = this.variables.get(paramName).getVariable().getArr();
                    SDVariable gradVar = this.variables.get(paramName).getVariable().getGradient();
                    if (gradVar == null) {
                        // NOTE(review): presumably the parameter does not contribute to the loss; skip it
                        continue;
                    }
                    INDArray grad = gradVar.getArr();
                    List<Regularization> regularization = this.trainingConfig.getRegularization();
                    boolean hasRegularization = regularization != null && regularization.size() > 0;
                    double learningRate = this.trainingConfig.getUpdater().hasLearningRate()
                            ? this.trainingConfig.getUpdater().getLearningRate(iteration, epochCount)
                            : 1.0d;
                    if (hasRegularization) {
                        for (Regularization reg : regularization) {
                            if (reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER) {
                                reg.apply(param, grad, learningRate, iteration, epochCount);
                            }
                        }
                    }
                    // The updater is applied to a [1, length] reshape of the gradient obtained
                    // without copying. The later subi/addi uses `grad`, which relies on that
                    // reshape sharing grad's buffer (NOTE(review): confirm newShapeNoCopy semantics).
                    INDArray gradRow = Shape.newShapeNoCopy(grad, new long[]{1, grad.length()}, grad.ordering() == 'f');
                    Preconditions.checkState(gradRow != null, "Error reshaping array for parameter \"%s\": array is a view?", paramName);
                    GradientUpdater updater = this.updaterMap.get(paramName);
                    Preconditions.checkState(updater != null, "No updater has been initialized for trainable parameter \"%s\": were the trainable parameters changed after training was initialized?", paramName);
                    try {
                        updater.applyUpdater(gradRow, iteration, epochCount);
                        if (hasRegularization) {
                            for (Regularization reg : regularization) {
                                if (reg.applyStep() == Regularization.ApplyStep.POST_UPDATER) {
                                    reg.apply(param, grad, learningRate, iteration, epochCount);
                                }
                            }
                        }
                        if (this.trainingConfig.isMinimize()) {
                            param.subi(grad);
                        } else {
                            param.addi(grad);
                        }
                    } catch (Throwable th) {
                        throw new RuntimeException("Error applying updater " + updater.getClass().getSimpleName() + " to parameter \"" + paramName + "\": either parameter size is inconsistent between iterations, or \"" + paramName + "\" should not be a trainable parameter?", th);
                    }
                }
                this.trainingConfig.incrementIterationCount();
            }
            if (epoch < i - 1) {
                multiDataSetIterator.reset();
            }
            if (z) {
                this.trainingConfig.incrementEpochCount();
            }
        }
    }

    /**
     * Sums the regularization score of every configured regularization over every trainable
     * parameter.
     *
     * <p>If no trainable parameters have been set yet, {@code initializeTraining()} is called
     * first so that they can be inferred.
     *
     * @return the total score, or 0 if no regularization is configured
     */
    public double calcRegularizationScore() {
        Preconditions.checkState(this.trainingConfig != null, "No training configuration has been set. A training configuration must be set before calculating the L2 loss. Use setTrainingConfig(TrainingConfig)");
        List<Regularization> configured = this.trainingConfig.getRegularization();
        if (configured == null || configured.isEmpty()) {
            return 0.0d;
        }
        List<String> params = this.trainingConfig.getTrainableParams();
        if (params == null || params.isEmpty()) {
            initializeTraining();
        }
        double score = 0.0d;
        for (String name : this.trainingConfig.getTrainableParams()) {
            for (Regularization reg : configured) {
                score += reg.score(getVariable(name).getArr(), this.trainingConfig.getIterationCount(), this.trainingConfig.getEpochCount());
            }
        }
        return score;
    }

    /**
     * Prepares updater state for training. Once it has succeeded, later calls do nothing.
     *
     * <p>The method does three things:
     * <ol>
     *   <li>If the training config lists no trainable parameters, they are inferred. A
     *       variable qualifies when it is not an op output, placeholder or constant, and is
     *       not named in any dataset feature, label or mask mapping.</li>
     *   <li>A single updater-state row vector is allocated outside any workspace. Its size is
     *       the updater's state size for the total parameter length, and its dtype is that of
     *       the first trainable parameter.</li>
     *   <li>The row vector is sliced into one contiguous view per parameter, and a
     *       GradientUpdater is instantiated for each view.</li>
     * </ol>
     *
     * <p>Changes from the decompiled version:
     * <ul>
     *   <li>The workspace scope now uses a proper try-with-resources. The previous expansion
     *       contained dead {@code 0 != 0} branches and a null {@code addSuppressed} target.</li>
     *   <li>A missing variable now fails the intended precondition instead of causing a
     *       NullPointerException before it.</li>
     * </ul>
     *
     * @throws ND4JIllegalStateException if no training config has been set
     */
    protected void initializeTraining() {
        if (this.initializedTraining) {
            return;
        }
        if (this.trainingConfig == null) {
            throw new ND4JIllegalStateException("Please specify a training config with setTrainingConfig");
        }
        if (this.trainingConfig.getTrainableParams() == null || this.trainingConfig.getTrainableParams().size() == 0) {
            List<String> featureMapping = this.trainingConfig.getDataSetFeatureMapping();
            List<String> labelMapping = this.trainingConfig.getDataSetLabelMapping();
            List<String> featureMaskMapping = this.trainingConfig.getDataSetFeatureMaskMapping();
            List<String> labelMaskMapping = this.trainingConfig.getDataSetLabelMaskMapping();
            List<String> inferred = new ArrayList<>();
            for (Variable entry : this.variables.values()) {
                String varName = entry.getVariable().getVarName();
                Variable meta = this.variables.get(varName);
                if (meta.getOutputOfOp() != null || isPlaceHolder(varName) || meta.getVariable().isConstant()) {
                    continue;
                }
                if ((featureMapping != null && featureMapping.contains(varName))
                        || (labelMapping != null && labelMapping.contains(varName))
                        || (featureMaskMapping != null && featureMaskMapping.contains(varName))
                        || (labelMaskMapping != null && labelMaskMapping.contains(varName))) {
                    continue;
                }
                inferred.add(varName);
            }
            this.trainingConfig.setTrainableParams(inferred);
            log.info("Inferred trainable variables: {}", inferred);
        }

        long totalParamLength = 0;
        DataType stateType = null;
        for (String name : this.trainingConfig.getTrainableParams()) {
            Variable meta = this.variables.get(name);
            SDVariable variable = meta == null ? null : meta.getVariable();
            Preconditions.checkState(variable != null, "No variable found for trainable parameter name \"%s\"", name);
            INDArray arr = variable.getArr();
            Preconditions.checkState(arr != null, "No array found for trainable parameter \"%s\"", name);
            totalParamLength += arr.length();
            if (stateType == null) {
                stateType = arr.dataType();
            }
        }

        long stateSize = this.trainingConfig.getUpdater().stateSize(totalParamLength);
        if (stateSize > 0) {
            // Allocate outside any active workspace so the state outlives workspace scopes
            try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                this.updaterState = Nd4j.createUninitialized(stateType, 1, stateSize);
            }
        }

        long offset = 0;
        this.updaterViews = new HashMap<>();
        this.updaterMap = new HashMap<>();
        for (String name : this.trainingConfig.getTrainableParams()) {
            long paramStateSize = this.trainingConfig.getUpdater().stateSize(this.variables.get(name).getVariable().getArr().length());
            // Contiguous slice [offset, offset + paramStateSize) of the shared state row, or
            // null when there is no state to hold
            INDArray view = (stateSize == 0 || paramStateSize == 0)
                    ? null
                    : this.updaterState.get(NDArrayIndex.interval(0, 1), NDArrayIndex.interval(offset, offset + paramStateSize));
            this.updaterViews.put(name, view);
            this.updaterMap.put(name, this.trainingConfig.getUpdater().instantiate(view, true));
            offset += paramStateSize;
        }
        this.initializedTraining = true;
    }

    /**
     * Maps a minibatch's arrays onto placeholder names using the training config's mappings.
     *
     * <p>Features and labels are assigned positionally. In the mask mappings a null name
     * skips that position: the index still advances, but nothing is stored.
     *
     * @param multiDataSet the minibatch
     * @return a new map from placeholder name to array
     */
    private Map<String, INDArray> toPlaceholderMap(MultiDataSet multiDataSet) {
        Map<String, INDArray> placeholders = new HashMap<>();

        int featureIdx = 0;
        for (String name : this.trainingConfig.getDataSetFeatureMapping()) {
            placeholders.put(name, multiDataSet.getFeatures(featureIdx++));
        }

        List<String> labelNames = this.trainingConfig.getDataSetLabelMapping();
        if (labelNames != null) {
            int labelIdx = 0;
            for (String name : labelNames) {
                placeholders.put(name, multiDataSet.getLabels(labelIdx++));
            }
        }

        List<String> featureMaskNames = this.trainingConfig.getDataSetFeatureMaskMapping();
        if (featureMaskNames != null && !featureMaskNames.isEmpty()) {
            int maskIdx = 0;
            for (String name : featureMaskNames) {
                if (name != null) {
                    placeholders.put(name, multiDataSet.getFeaturesMaskArray(maskIdx));
                }
                maskIdx++;
            }
        }

        List<String> labelMaskNames = this.trainingConfig.getDataSetLabelMaskMapping();
        if (labelMaskNames != null && !labelMaskNames.isEmpty()) {
            int maskIdx = 0;
            for (String name : labelMaskNames) {
                if (name != null) {
                    placeholders.put(name, multiDataSet.getLabelsMaskArray(maskIdx));
                }
                maskIdx++;
            }
        }
        return placeholders;
    }

    /**
     * Evaluates a single output variable against label array 0 of each minibatch.
     *
     * @param dataSetIterator the data source
     * @param str the output variable to evaluate
     * @param iEvaluationArr one or more evaluations to update
     */
    public void evaluate(DataSetIterator dataSetIterator, String str, IEvaluation... iEvaluationArr) {
        Preconditions.checkArgument(iEvaluationArr != null && iEvaluationArr.length > 0, "No evaluations were passed to the evaluate method");
        Map<String, List<IEvaluation>> evals = Collections.singletonMap(str, Arrays.asList(iEvaluationArr));
        Map<String, Integer> labelIdx = Collections.singletonMap(str, 0);
        evaluate(new MultiDataSetIteratorAdapter(dataSetIterator), evals, labelIdx);
    }

    /**
     * Evaluates several output variables, one evaluation each, all against label array 0.
     *
     * @param dataSetIterator the data source
     * @param map output variable name to the evaluation to update
     */
    public void evaluate(DataSetIterator dataSetIterator, Map<String, IEvaluation> map) {
        Map<String, Integer> labelIdx = new HashMap<>();
        Map<String, List<IEvaluation>> evals = new HashMap<>();
        map.forEach((name, evaluation) -> {
            labelIdx.put(name, 0);
            evals.put(name, Collections.singletonList(evaluation));
        });
        evaluate(new MultiDataSetIteratorAdapter(dataSetIterator), evals, labelIdx);
    }

    /**
     * Evaluates several output variables, each with a list of evaluations, all against label
     * array 0.
     *
     * @param dataSetIterator the data source
     * @param map output variable name to the evaluations to update
     */
    public void evaluateMultiple(DataSetIterator dataSetIterator, Map<String, List<IEvaluation>> map) {
        Map<String, Integer> labelIdx = new HashMap<>();
        for (String name : map.keySet()) {
            labelIdx.put(name, 0);
        }
        evaluate(new MultiDataSetIteratorAdapter(dataSetIterator), map, labelIdx);
    }

    /**
     * Evaluates a single output variable against label array {@code i} of each minibatch.
     *
     * @param multiDataSetIterator the data source
     * @param str the output variable to evaluate
     * @param i index of the label array to compare against
     * @param iEvaluationArr one or more evaluations to update
     */
    public void evaluate(MultiDataSetIterator multiDataSetIterator, String str, int i, IEvaluation... iEvaluationArr) {
        Preconditions.checkArgument(iEvaluationArr != null && iEvaluationArr.length > 0, "No evaluations were passed to the evaluate method");
        Map<String, List<IEvaluation>> evals = Collections.singletonMap(str, Arrays.asList(iEvaluationArr));
        Map<String, Integer> labelIdx = Collections.singletonMap(str, Integer.valueOf(i));
        evaluate(multiDataSetIterator, evals, labelIdx);
    }

    /**
     * Runs inference over every minibatch and feeds predictions plus labels to evaluations.
     *
     * @param multiDataSetIterator the data source; rewound first if exhausted and resettable
     * @param map output variable name to the evaluations to update
     * @param map2 output variable name to the index of the label array to compare against;
     *             must have the same key set as {@code map}
     */
    public void evaluate(MultiDataSetIterator multiDataSetIterator, Map<String, List<IEvaluation>> map, Map<String, Integer> map2) {
        Preconditions.checkState(this.trainingConfig != null, "Training config has not been set");
        Preconditions.checkState(map.keySet().equals(map2.keySet()), "Keysets for variable evaluations and for the prediction label mapping must be equal. Keys for variables to evaluate: %s vs. keys for label mapping: %s", map.keySet(), map2.keySet());
        boolean exhausted = !multiDataSetIterator.hasNext();
        if (exhausted && multiDataSetIterator.resetSupported()) {
            multiDataSetIterator.reset();
        }
        List<String> requested = new ArrayList<>(map.keySet());
        while (multiDataSetIterator.hasNext()) {
            MultiDataSet batch = multiDataSetIterator.next();
            Map<String, INDArray> predictions = exec(toPlaceholderMap(batch), requested);
            for (Map.Entry<String, List<IEvaluation>> entry : map.entrySet()) {
                String outName = entry.getKey();
                INDArray prediction = predictions.get(outName);
                for (IEvaluation evaluation : entry.getValue()) {
                    evaluation.eval(batch.getLabels(map2.get(outName).intValue()), prediction);
                }
            }
        }
    }

    /**
     * Runs inference on a single DataSet.
     *
     * @param dataSet the input minibatch
     * @param strArr names of the variables to return
     * @return variable name to computed array
     */
    public Map<String, INDArray> output(DataSet dataSet, String... strArr) {
        List<Map<String, INDArray>> perBatch = output(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), strArr);
        return perBatch.get(0);
    }

    /**
     * Runs inference over every minibatch of a DataSetIterator.
     *
     * @param dataSetIterator the data source (adapted to a MultiDataSetIterator)
     * @param strArr names of the variables to return
     * @return one result map per minibatch
     */
    public List<Map<String, INDArray>> output(DataSetIterator dataSetIterator, String... strArr) {
        MultiDataSetIterator adapted = new MultiDataSetIteratorAdapter(dataSetIterator);
        return output(adapted, strArr);
    }

    /**
     * Runs inference over every minibatch of a MultiDataSetIterator.
     *
     * @param multiDataSetIterator the data source; rewound first if exhausted and resettable
     * @param strArr names of the variables to return. A null array means "all of
     *               {@link #outputs()}". An empty varargs array requests nothing.
     * @return one result map per minibatch
     */
    public List<Map<String, INDArray>> output(MultiDataSetIterator multiDataSetIterator, String... strArr) {
        Preconditions.checkState(this.trainingConfig != null, "Training config has not been set");
        List<String> requested = (strArr == null) ? outputs() : Arrays.asList(strArr);
        List<Map<String, INDArray>> results = new ArrayList<>();
        boolean exhausted = !multiDataSetIterator.hasNext();
        if (exhausted && multiDataSetIterator.resetSupported()) {
            multiDataSetIterator.reset();
        }
        while (multiDataSetIterator.hasNext()) {
            Map<String, INDArray> batchResult = exec(toPlaceholderMap(multiDataSetIterator.next()), requested);
            results.add(batchResult);
        }
        return results;
    }

    /**
     * Creates a variable initialised with the constant 1.0, using the default floating point type.
     *
     * @param str  variable name
     * @param iArr variable shape
     * @return the new variable
     */
    public SDVariable one(String str, int... iArr) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return one(str, defaultType, iArr);
    }

    /**
     * Creates a variable initialised with the constant 1.0, using the default floating point type.
     *
     * @param str  variable name
     * @param jArr variable shape
     * @return the new variable
     */
    public SDVariable one(String str, long... jArr) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return one(str, defaultType, jArr);
    }

    /**
     * Creates a variable of the given type initialised with a {@code ConstantInitScheme} of 1.0.
     *
     * @param str      variable name
     * @param dataType variable data type
     * @param iArr     variable shape (converted to {@code long[]})
     * @return the new variable
     */
    public SDVariable one(String str, DataType dataType, int... iArr) {
        long[] shape = ArrayUtil.toLongArray(iArr);
        return var(str, new ConstantInitScheme('f', 1.0d), dataType, shape);
    }

    /**
     * Creates a variable of the given type initialised with a {@code ConstantInitScheme} of 1.0.
     *
     * @param str      variable name
     * @param dataType variable data type
     * @param jArr     variable shape
     * @return the new variable
     */
    public SDVariable one(String str, DataType dataType, long... jArr) {
        WeightInitScheme ones = new ConstantInitScheme('f', 1.0d);
        return var(str, ones, dataType, jArr);
    }

    /**
     * Creates a zero-initialised variable using the default floating point type.
     *
     * @param str  variable name
     * @param jArr variable shape
     * @return the new variable
     */
    public SDVariable zero(String str, long... jArr) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return zero(str, defaultType, jArr);
    }

    /**
     * Creates a zero-initialised variable using the default floating point type.
     *
     * @param str  variable name
     * @param iArr variable shape
     * @return the new variable
     */
    public SDVariable zero(String str, int... iArr) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return zero(str, defaultType, iArr);
    }

    /**
     * Creates a variable of the given type initialised with a {@code ZeroInitScheme}.
     *
     * @param str      variable name
     * @param dataType variable data type
     * @param jArr     variable shape
     * @return the new variable
     */
    public SDVariable zero(String str, DataType dataType, long... jArr) {
        WeightInitScheme zeros = new ZeroInitScheme();
        return var(str, zeros, dataType, jArr);
    }

    /**
     * Creates a variable of the given type initialised with a {@code ZeroInitScheme}.
     *
     * @param str      variable name
     * @param dataType variable data type
     * @param iArr     variable shape (converted to {@code long[]})
     * @return the new variable
     */
    public SDVariable zero(String str, DataType dataType, int... iArr) {
        long[] shape = ArrayUtil.toLongArray(iArr);
        return var(str, new ZeroInitScheme(), dataType, shape);
    }

    /**
     * Creates a constant with an auto-generated name holding the given array.
     *
     * @param iNDArray the constant's value; must not be null
     * @return the new constant variable
     * @throws NullPointerException if {@code iNDArray} is null
     */
    public SDVariable constant(@NonNull INDArray iNDArray) {
        if (iNDArray == null) {
            throw new NullPointerException("constant is marked @NonNull but is null");
        }
        String generatedName = getNewVarName();
        return constant(generatedName, iNDArray);
    }

    /**
     * Registers a CONSTANT variable backed by the given array.
     * <p>
     * The array is stored in {@code constantArrays} wrapped in a {@code DeviceLocalNDArray}. A null
     * or empty name is replaced with a generated one.
     *
     * @param str      desired name; may be null or empty
     * @param iNDArray the constant's value; must not be null
     * @return the new constant variable
     * @throws NullPointerException  if {@code iNDArray} is null
     * @throws IllegalStateException if a variable named {@code str} already exists
     */
    public SDVariable constant(String str, @NonNull INDArray iNDArray) {
        if (iNDArray == null) {
            throw new NullPointerException("constant is marked @NonNull but is null");
        }
        Preconditions.checkState(!this.variables.containsKey(str), "Variable with name \"%s\" already exists", str);
        final String name = (str == null || str.isEmpty()) ? getNewVarName() : str;
        SDVariable constantVar = new SDVariable(name, VariableType.CONSTANT, this, iNDArray.shape(), iNDArray.dataType(), null);
        Variable metadata = Variable.builder().name(name).variable(constantVar).build();
        variables.put(name, metadata);
        constantArrays.put(name, new DeviceLocalNDArray(iNDArray));
        return constantVar;
    }

    /**
     * Creates an unnamed constant derived from {@code sDVariable} with the given shape.
     *
     * @param sDVariable source variable
     * @param jArr       shape of the constant
     * @return the new constant variable
     * @deprecated retained for backward compatibility; no replacement is indicated here
     */
    @Deprecated
    public SDVariable constant(SDVariable sDVariable, long... jArr) {
        String noName = null;
        return constant(noName, sDVariable, jArr);
    }

    /**
     * Creates a constant op via the function factory and renames its output to {@code str}.
     *
     * @param str        name for the resulting variable
     * @param sDVariable source variable
     * @param jArr       shape of the constant
     * @return the renamed variable
     * @deprecated retained for backward compatibility; no replacement is indicated here
     */
    @Deprecated
    public SDVariable constant(String str, SDVariable sDVariable, long... jArr) {
        SDVariable created = f().constant(sDVariable, jArr);
        return updateVariableNameAndReference(created, str);
    }

    /**
     * Registers a PLACEHOLDER variable with the given name, type and shape.
     * <p>
     * Any existing entry with the same name in {@code variables} is replaced.
     *
     * @param str      placeholder name
     * @param dataType placeholder data type
     * @param jArr     placeholder shape
     * @return the new placeholder variable
     */
    public SDVariable placeHolder(String str, DataType dataType, long... jArr) {
        SDVariable placeholder = new SDVariable(str, VariableType.PLACEHOLDER, this, jArr, dataType, null);
        Variable metadata = Variable.builder().name(str).variable(placeholder).build();
        variables.put(str, metadata);
        return placeholder;
    }

    /**
     * Creates a trainable VARIABLE initialised by {@code weightInitScheme}.
     *
     * @param str              variable name; must not be null
     * @param weightInitScheme initialisation scheme; must not be null
     * @param dataType         data type; must not be null
     * @param jArr             shape; must not be null
     * @return the new variable
     * @throws NullPointerException if any argument is null
     */
    public SDVariable var(@NonNull String str, @NonNull WeightInitScheme weightInitScheme, @NonNull DataType dataType, @NonNull long... jArr) {
        if (str == null) {
            throw new NullPointerException("name is marked @NonNull but is null");
        }
        if (weightInitScheme == null) {
            throw new NullPointerException("weightInitScheme is marked @NonNull but is null");
        }
        if (dataType == null) {
            throw new NullPointerException("dataType is marked @NonNull but is null");
        }
        if (jArr == null) {
            throw new NullPointerException("shape is marked @NonNull but is null");
        }
        final VariableType trainable = VariableType.VARIABLE;
        return var(str, trainable, weightInitScheme, dataType, jArr);
    }

    /**
     * Creates and registers a variable of the given {@link VariableType}.
     * <p>
     * For PLACEHOLDER variables, the shape is also recorded as the original placeholder shape and as
     * the variable's current shape. An empty name is replaced with a generated one.
     *
     * @param str              variable name; must not be null
     * @param variableType     kind of variable; must not be null
     * @param weightInitScheme initialisation scheme (may be null)
     * @param dataType         data type
     * @param jArr             shape
     * @return the new variable
     * @throws NullPointerException     if {@code str} or {@code variableType} is null
     * @throws IllegalArgumentException if a variable with this name already has an array
     */
    public SDVariable var(@NonNull String str, @NonNull VariableType variableType, WeightInitScheme weightInitScheme, DataType dataType, long... jArr) {
        if (str == null) {
            throw new NullPointerException("name is marked @NonNull but is null");
        }
        if (variableType == null) {
            throw new NullPointerException("variableType is marked @NonNull but is null");
        }
        boolean nameTaken = variables.containsKey(str) && variables.get(str).getVariable().getArr() != null;
        if (nameTaken) {
            throw new IllegalArgumentException("Another variable with the name " + str + " already exists.");
        }
        final String name = str.isEmpty() ? getNewVarName() : str;
        SDVariable created = new SDVariable(name, variableType, this, jArr, dataType, weightInitScheme);
        addVariable(created);
        if (variableType != VariableType.PLACEHOLDER) {
            return created;
        }
        setOriginalPlaceHolderShape(name, jArr);
        putShapeForVarName(name, jArr);
        return created;
    }

    /**
     * Creates a trainable variable whose type and shape come from a {@link LongShapeDescriptor}.
     *
     * @param str                 variable name; must not be null
     * @param longShapeDescriptor source of data type and shape; must not be null
     * @param weightInitScheme    initialisation scheme
     * @return the new variable
     * @throws NullPointerException if {@code str} or {@code longShapeDescriptor} is null
     */
    public SDVariable var(@NonNull String str, @NonNull LongShapeDescriptor longShapeDescriptor, WeightInitScheme weightInitScheme) {
        if (str == null) {
            throw new NullPointerException("name is marked @NonNull but is null");
        }
        if (longShapeDescriptor == null) {
            throw new NullPointerException("shape is marked @NonNull but is null");
        }
        DataType descriptorType = longShapeDescriptor.dataType();
        long[] descriptorShape = longShapeDescriptor.getShape();
        return var(str, weightInitScheme, descriptorType, descriptorShape);
    }

    /**
     * Creates a placeholder when {@code jArr} is a placeholder shape (per {@code Shape.isPlaceholderShape}),
     * and otherwise a zero-initialised trainable variable.
     *
     * @param str      variable name
     * @param dataType data type
     * @param jArr     shape; must not be null
     * @return the new placeholder or variable
     * @throws NullPointerException if {@code jArr} is null
     */
    public SDVariable var(String str, DataType dataType, long... jArr) {
        // Validate the array itself; a boxed Boolean passed here would never be null, so the check would never fire.
        Preconditions.checkNotNull(jArr, "Invalid shape: shape may not be null");
        return Shape.isPlaceholderShape(jArr) ? placeHolder(str, dataType, jArr) : var(str, new ZeroInitScheme(), dataType, jArr);
    }

    /**
     * Creates a zero-initialised trainable variable described by {@code longShapeDescriptor}.
     *
     * @param str                 variable name
     * @param longShapeDescriptor data type and shape; must not be null
     * @return the new variable
     * @throws NullPointerException if {@code longShapeDescriptor} is null
     */
    public SDVariable var(String str, LongShapeDescriptor longShapeDescriptor) {
        // Validate the descriptor itself; a boxed Boolean passed here would never be null, so the check would never fire.
        Preconditions.checkNotNull(longShapeDescriptor, "Invalid shape: shape may not be null");
        return var(str, longShapeDescriptor, new ZeroInitScheme());
    }

    /**
     * Creates a variable (or placeholder) of the default floating point type.
     *
     * @param str  variable name
     * @param iArr shape
     * @return the new variable or placeholder
     */
    public SDVariable var(String str, int... iArr) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return var(str, defaultType, iArr);
    }

    /**
     * Creates a variable (or placeholder) of the default floating point type.
     *
     * @param str  variable name
     * @param jArr shape
     * @return the new variable or placeholder
     */
    public SDVariable var(String str, long... jArr) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return var(str, defaultType, jArr);
    }

    /**
     * Creates a placeholder when {@code iArr} is a placeholder shape, and otherwise a zero-initialised
     * trainable variable. The shape is converted to {@code long[]}.
     *
     * @param str      variable name
     * @param dataType data type
     * @param iArr     shape; must not be null
     * @return the new placeholder or variable
     */
    public SDVariable var(String str, DataType dataType, int... iArr) {
        Preconditions.checkNotNull(iArr, "Invalid shape: shape may not be null");
        if (Shape.isPlaceholderShape(iArr)) {
            return placeHolder(str, dataType, ArrayUtil.toLongArray(iArr));
        }
        return var(str, new ZeroInitScheme(), dataType, ArrayUtil.toLongArray(iArr));
    }

    /**
     * Imports an existing {@link SDVariable} into this graph.
     * <p>
     * If a variable with the same name already has an array, that existing variable is returned.
     * CONSTANT and PLACEHOLDER inputs are delegated to {@code constant} and {@code placeHolder}.
     * VARIABLE inputs are re-created with an {@code NDArraySupplierInitScheme} over their current
     * array. ARRAY inputs are re-created with no init scheme.
     *
     * @param sDVariable variable to import; must not be null
     * @return the registered variable
     * @throws NullPointerException     if {@code sDVariable} is null
     * @throws IllegalArgumentException if the variable has no name
     * @throws RuntimeException         for unsupported variable types
     */
    public SDVariable var(@NonNull SDVariable sDVariable) {
        if (sDVariable == null) {
            throw new NullPointerException("v is marked @NonNull but is null");
        }
        String name = sDVariable.getVarName();
        if (variables.containsKey(name)) {
            SDVariable existing = variables.get(name).getVariable();
            if (existing.getArr() != null) {
                return existing;
            }
        }
        if (name == null || name.isEmpty()) {
            throw new IllegalArgumentException("Name for variable must be defined");
        }
        VariableType type = sDVariable.getVariableType();
        NDArraySupplierInitScheme initScheme;
        switch (type) {
            case CONSTANT:
                return constant(name, sDVariable.getArr());
            case PLACEHOLDER:
                return placeHolder(name, sDVariable.dataType(), sDVariable.placeholderShape());
            case VARIABLE:
                initScheme = new NDArraySupplierInitScheme(sDVariable.getArr());
                break;
            case ARRAY:
                initScheme = null;
                break;
            default:
                throw new RuntimeException("Unknown/not supported variable type: " + type);
        }
        SDVariable copy = new SDVariable(name, sDVariable.getVariableType(), this, sDVariable.getShape(), sDVariable.dataType(), initScheme);
        return addVariable(copy);
    }

    /**
     * Returns the first name of the form {@code sd_var_<id>} that is not yet in {@code variables}.
     * <p>
     * This advances {@code variableId} past taken names, but does not reserve the returned name.
     *
     * @return an unused variable name
     */
    private String getNewVarName() {
        String candidate = "sd_var_" + this.variableId;
        while (this.variables.containsKey(candidate)) {
            this.variableId++;
            candidate = "sd_var_" + this.variableId;
        }
        return candidate;
    }

    /**
     * Creates a variable (or placeholder) with an auto-generated name.
     *
     * @param dataType data type
     * @param iArr     shape
     * @return the new variable or placeholder
     */
    public SDVariable var(DataType dataType, int... iArr) {
        String generatedName = getNewVarName();
        return var(generatedName, dataType, iArr);
    }

    /**
     * Creates a variable (or placeholder) with an auto-generated name.
     *
     * @param dataType data type
     * @param jArr     shape
     * @return the new variable or placeholder
     */
    public SDVariable var(DataType dataType, long... jArr) {
        String generatedName = getNewVarName();
        return var(generatedName, dataType, jArr);
    }

    /**
     * Creates a trainable variable with an auto-generated name and the given init scheme.
     *
     * @param weightInitScheme initialisation scheme
     * @param dataType         data type
     * @param jArr             shape
     * @return the new variable
     */
    public SDVariable var(WeightInitScheme weightInitScheme, DataType dataType, long... jArr) {
        String generatedName = getNewVarName();
        return var(generatedName, weightInitScheme, dataType, jArr);
    }

    /**
     * Creates a trainable variable with an auto-generated name, backed by {@code iNDArray}.
     *
     * @param iNDArray initial value
     * @return the new variable
     */
    public SDVariable var(INDArray iNDArray) {
        String generatedName = getNewVarName();
        return var(generatedName, iNDArray);
    }

    /**
     * Creates a trainable VARIABLE backed by the given array.
     * <p>
     * The array is detached if attached, and duplicated if it is a view. If neither applies, it is also
     * duplicated when the same instance already backs another variable in {@code variablesArrays}. For
     * single-element arrays, a scalar value is also recorded on the variable; that scalar is created
     * with workspaces scoped out.
     *
     * @param str      desired name; a null or empty name is replaced with a generated one
     * @param iNDArray initial value; must not be null and must be a floating point type
     * @return the new variable
     * @throws NullPointerException     if {@code iNDArray} is null
     * @throws IllegalArgumentException if a variable with this name already has an array
     * @throws IllegalStateException    if {@code iNDArray} is not a floating point type
     */
    public SDVariable var(String str, @NonNull INDArray iNDArray) {
        if (iNDArray == null) {
            throw new NullPointerException("arr is marked @NonNull but is null");
        }
        if (this.variables.containsKey(str) && this.variables.get(str).getVariable().getArr() != null) {
            throw new IllegalArgumentException("Another variable with the name " + str + " already exists.");
        }
        Preconditions.checkState(iNDArray.dataType().isFPType(), "Cannot create variable with non-floating point type: provided array has datatype %s. Variables must be floating point type to be trainable by backpropagation.\nFor non floating point types, these should be created as placeholders or constants instead.", iNDArray.dataType());
        if (str == null || str.length() < 1) {
            str = getNewVarName();
        }

        INDArray arr = iNDArray;
        boolean alreadyCopied = false;
        if (arr.isAttached()) {
            arr = arr.detach();
            alreadyCopied = true;
        }
        if (arr.isView()) {
            arr = arr.dup();
            alreadyCopied = true;
        }
        if (!alreadyCopied) {
            // Don't let two variables share the same backing array instance.
            for (DeviceLocalNDArray existing : this.variablesArrays.values()) {
                if (existing.get() == arr) {
                    arr = arr.dup();
                    break;
                }
            }
        }

        SDVariable sDVariable = new SDVariable(str, VariableType.VARIABLE, this, arr.shape(), arr.dataType(), new NDArraySupplierInitScheme(arr));
        associateArrayWithVariable(arr, sDVariable);
        if (ArrayUtil.prod(arr.shape()) == 1) {
            // try-with-resources closes the scope exactly once, even if setScalarValue or close() throws.
            try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                sDVariable.setScalarValue(Nd4j.scalar(arr.getDouble(0L)));
            }
        }
        addVariable(sDVariable);
        if (getShapeForVarName(str) == null) {
            putShapeForVarName(str, arr.shape());
        }
        return sDVariable;
    }

    /**
     * Converts a single variable to a CONSTANT in place. See {@code convertToConstants}.
     *
     * @param sDVariable variable to convert; must not be null
     * @return the same (now constant) variable instance
     * @throws NullPointerException if {@code sDVariable} is null
     */
    public SDVariable convertToConstant(@NonNull SDVariable sDVariable) {
        if (sDVariable == null) {
            throw new NullPointerException("variable is marked @NonNull but is null");
        }
        List<SDVariable> single = Collections.singletonList(sDVariable);
        convertToConstants(single);
        return sDVariable;
    }

    /**
     * Converts the given variables into CONSTANTs in place.
     * <p>
     * The method returns immediately when the list is empty or every entry is already a CONSTANT.
     * Otherwise it:
     * <ul>
     *   <li>discards cached sessions and the "grad" function;</li>
     *   <li>moves each array from the variable store to the constant store and drops it from any
     *       per-thread placeholder maps;</li>
     *   <li>changes each variable's type to CONSTANT.</li>
     * </ul>
     * When a training config is set, the converted names are removed from the trainable parameters. If
     * training has been initialized, the updater state of the remaining parameters is compacted and the
     * updater views and map are rebuilt.
     *
     * @param list variables to convert; must not be null and must not contain ARRAY-type variables
     * @throws NullPointerException  if {@code list} is null
     * @throws IllegalStateException if an entry is of type ARRAY
     */
    public void convertToConstants(@NonNull List<SDVariable> list) {
        if (list == null) {
            throw new NullPointerException("variables is marked @NonNull but is null");
        }
        if (list.isEmpty()) {
            return;
        }
        boolean allConstant = true;
        for (SDVariable v : list) {
            if (v.getVariableType() != VariableType.CONSTANT) {
                allConstant = false;
                Preconditions.checkState(v.getVariableType() != VariableType.ARRAY, "Cannot convert variable of type ARRAY to a constant: %s", v);
            }
        }
        if (allConstant) {
            return;
        }

        // Cached sessions and the gradient function were built for the old variable types.
        this.sessions.clear();
        this.sameDiffFunctionInstances.remove("grad");

        for (SDVariable v : list) {
            String name = v.getVarName();
            INDArray arr = v.getArr();
            Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", v);
            this.constantArrays.put(name, new DeviceLocalNDArray(arr));
            this.variablesArrays.remove(name);
            for (Map<String, INDArray> perThread : this.placeholdersPerThread.values()) {
                perThread.remove(name);
            }
            v.setVariableType(VariableType.CONSTANT);
        }

        if (this.trainingConfig == null) {
            return;
        }

        Set<String> removed = new HashSet<>();
        boolean anyWasTrainable = false;
        List<String> origTrainable = this.trainingConfig.getTrainableParams();
        for (SDVariable v : list) {
            removed.add(v.getVarName());
            if (!anyWasTrainable && origTrainable.contains(v.getVarName())) {
                anyWasTrainable = true;
            }
        }
        if (anyWasTrainable) {
            List<String> remaining = new ArrayList<>();
            for (String s : origTrainable) {
                if (!removed.contains(s)) {
                    remaining.add(s);
                }
            }
            this.trainingConfig.setTrainableParams(remaining);
        }

        if (this.initializedTraining) {
            List<INDArray> keptState = new ArrayList<>();
            for (String s : origTrainable) {
                if (removed.contains(s)) {
                    continue;
                }
                INDArray view = this.updaterViews.get(s);
                // Views are null when a parameter has no updater state (state size 0); there is nothing to keep.
                if (view != null) {
                    keptState.add(view);
                }
            }
            // Each view is a [1, stateSize] slice of a single-row state array, so rejoin along dimension 1.
            this.updaterState = keptState.isEmpty() ? null : Nd4j.concat(1, keptState.toArray(new INDArray[0]));

            long offset = 0;
            this.updaterViews = new HashMap<>();
            this.updaterMap = new HashMap<>();
            for (String s : this.trainingConfig.getTrainableParams()) {
                long stateSize = this.trainingConfig.getUpdater().stateSize(this.variables.get(s).getVariable().getArr().length());
                INDArray view = (this.updaterState == null || stateSize == 0) ? null : this.updaterState.get(NDArrayIndex.interval(0, 1), NDArrayIndex.interval(offset, offset + stateSize));
                this.updaterViews.put(s, view);
                // false: reuse existing state rather than re-initialising it.
                this.updaterMap.put(s, this.trainingConfig.getUpdater().instantiate(view, false));
                offset += stateSize;
            }
        }
    }

    /**
     * Converts a single floating point variable (typically a constant) into a trainable VARIABLE.
     * See {@code convertToVariables}.
     *
     * @param sDVariable variable to convert; must not be null and must be a floating point type
     * @return the same (now trainable) variable instance
     * @throws NullPointerException  if {@code sDVariable} is null
     * @throws IllegalStateException if the variable is not a floating point type
     */
    public SDVariable convertToVariable(@NonNull SDVariable sDVariable) {
        if (sDVariable == null) {
            throw new NullPointerException("constant is marked @NonNull but is null");
        }
        DataType type = sDVariable.dataType();
        Preconditions.checkState(type.isFPType(), "Only floating point SDVariables can be converted to variables, datatype of %s is %s", sDVariable.getVarName(), type);
        List<SDVariable> single = Collections.singletonList(sDVariable);
        convertToVariables(single);
        return sDVariable;
    }

    /**
     * Converts the given variables (typically constants) into trainable VARIABLEs in place.
     * <p>
     * The method returns immediately when the list is empty or every entry is already a VARIABLE.
     * Otherwise it:
     * <ul>
     *   <li>discards cached sessions and the "grad" function;</li>
     *   <li>moves each array from the constant store to the variable store and drops it from any
     *       per-thread placeholder maps;</li>
     *   <li>changes each variable's type to VARIABLE.</li>
     * </ul>
     * When a training config is set, names that are not already trainable are appended to the
     * trainable parameters. If training has been initialized, updater state is extended for those
     * names only, and the updater views and map are rebuilt.
     *
     * @param list variables to convert; must not be null and must not contain ARRAY-type variables
     * @throws NullPointerException  if {@code list} is null
     * @throws IllegalStateException if an entry is of type ARRAY
     */
    public void convertToVariables(@NonNull List<SDVariable> list) {
        if (list == null) {
            throw new NullPointerException("constants is marked @NonNull but is null");
        }
        if (list.isEmpty()) {
            return;
        }
        boolean allVariables = true;
        for (SDVariable v : list) {
            if (v.getVariableType() != VariableType.VARIABLE) {
                allVariables = false;
            }
            Preconditions.checkState(v.getVariableType() != VariableType.ARRAY, "Cannot convert variable of type ARRAY to a variable: %s", v);
        }
        if (allVariables) {
            return;
        }

        // Cached sessions and the gradient function were built for the old variable types.
        this.sessions.clear();
        this.sameDiffFunctionInstances.remove("grad");

        for (SDVariable v : list) {
            String name = v.getVarName();
            INDArray arr = v.getArr();
            Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", v);
            this.variablesArrays.put(name, new DeviceLocalNDArray(arr));
            this.constantArrays.remove(name);
            for (Map<String, INDArray> perThread : this.placeholdersPerThread.values()) {
                perThread.remove(name);
            }
            v.setVariableType(VariableType.VARIABLE);
        }

        if (this.trainingConfig == null) {
            return;
        }

        // Only append names that are not already trainable. Re-adding a name would duplicate it in the
        // trainable list and allocate (and re-initialise) a second slice of updater state for it.
        List<String> trainable = new ArrayList<>(this.trainingConfig.getTrainableParams());
        Set<String> seen = new HashSet<>(trainable);
        List<String> newlyTrainable = new ArrayList<>();
        for (SDVariable v : list) {
            String name = v.getVarName();
            if (seen.add(name)) {
                trainable.add(name);
                newlyTrainable.add(name);
            }
        }
        this.trainingConfig.setTrainableParams(trainable);

        if (!this.initializedTraining) {
            return;
        }

        long extraStateSize = 0;
        for (String name : newlyTrainable) {
            extraStateSize += this.trainingConfig.getUpdater().stateSize(getVariable(name).getArr().length());
        }
        if (extraStateSize > 0) {
            // updaterState can be null (e.g. after every trainable parameter was converted to a constant),
            // so only read its dtype when it exists.
            DataType stateType = this.updaterState != null
                    ? this.updaterState.dataType()
                    : getVariable(newlyTrainable.get(0)).getArr().dataType();
            INDArray extraState = Nd4j.createUninitialized(stateType, 1, extraStateSize);
            this.updaterState = this.updaterState == null ? extraState : Nd4j.concat(1, this.updaterState, extraState);
        }

        // Rebuild unconditionally, as convertToConstants does. This way newly trainable parameters also get
        // an updaterMap entry when their state size is 0.
        long offset = 0;
        this.updaterViews = new HashMap<>();
        this.updaterMap = new HashMap<>();
        for (String name : this.trainingConfig.getTrainableParams()) {
            long stateSize = this.trainingConfig.getUpdater().stateSize(this.variables.get(name).getVariable().getArr().length());
            INDArray view = (this.updaterState == null || stateSize == 0) ? null : this.updaterState.get(NDArrayIndex.interval(0, 1), NDArrayIndex.interval(offset, offset + stateSize));
            this.updaterViews.put(name, view);
            // Initialise state only for newly trainable parameters; existing state is kept as-is.
            this.updaterMap.put(name, this.trainingConfig.getUpdater().instantiate(view, newlyTrainable.contains(name)));
            offset += stateSize;
        }
    }

    /**
     * Removes every occurrence of input {@code str} from the op record of {@code differentialFunction}.
     * <p>
     * Nothing happens unless one of the function's args has that name. Only the first
     * {@code args.length} entries of the op's input list are examined.
     *
     * @param str                  name of the input to remove
     * @param differentialFunction function whose op inputs should be updated
     */
    public void removeArgFromFunction(String str, DifferentialFunction differentialFunction) {
        SDVariable[] args = differentialFunction.args();
        boolean referenced = false;
        for (SDVariable arg : args) {
            if (arg.getVarName().equals(str)) {
                referenced = true;
                break;
            }
        }
        if (!referenced) {
            return;
        }
        SameDiffOp op = this.ops.get(differentialFunction.getOwnName());
        List<String> inputNames = op.getInputsToOp();
        List<String> kept = new ArrayList<>(args.length - 1);
        for (int idx = 0; idx < args.length; idx++) {
            String input = inputNames.get(idx);
            if (!input.equals(str)) {
                kept.add(input);
            }
        }
        op.setInputsToOp(kept);
    }

    /**
     * Looks up a variable by name.
     *
     * @param str variable name
     * @return the variable, or {@code null} if no variable with that name is registered
     */
    public SDVariable getVariable(String str) {
        Variable entry = variables.get(str);
        return entry == null ? null : entry.getVariable();
    }

    /**
     * Checks whether a variable with the given name is registered.
     *
     * @param str variable name
     * @return {@code true} if {@code variables} contains the name
     */
    public boolean hasVariable(String str) {
        boolean present = variables.containsKey(str);
        return present;
    }

    /**
     * Returns the gradient variable for {@code str}.
     * <p>
     * The gradient recorded on this graph's entry is used if present. Otherwise the variable's entry
     * in the "grad" function, if that exists, is consulted.
     *
     * @param str variable name
     * @return the gradient variable, or {@code null} if none is recorded
     * @throws IllegalStateException if no such variable exists or it is not a floating point type
     */
    public SDVariable getGradForVariable(String str) {
        Preconditions.checkState(this.variables.containsKey(str), "No variable with name \"%s\" exists", str);
        SDVariable variable = getVariable(str);
        // Arguments are (dataType, name) to match "%s variable \"%s\"" in the message.
        Preconditions.checkState(variable.dataType().isFPType(), "Cannot get gradient of %s variable \"%s\": only floating point variables have gradients", variable.dataType(), str);
        SDVariable gradient = this.variables.get(str).getGradient();
        if (gradient != null) {
            return gradient;
        }
        if (this.sameDiffFunctionInstances.containsKey("grad")) {
            SameDiff gradFn = this.sameDiffFunctionInstances.get("grad");
            if (gradFn.variables.containsKey(str)) {
                return gradFn.variables.get(str).getGradient();
            }
        }
        return null;
    }

    /**
     * Checks whether {@code str} is a non-constant floating point variable with a recorded gradient.
     *
     * @param str variable name
     * @return {@code true} if a gradient is available
     * @throws IllegalStateException if no such variable exists
     */
    public boolean variableHasGradient(String str) {
        Preconditions.checkState(this.variables.containsKey(str), "No variable with name \"%s\" exists", str);
        SDVariable target = getVariable(str);
        if (!target.dataType().isFPType() || target.isConstant()) {
            return false;
        }
        return getGradForVariable(str) != null;
    }

    /**
     * Records {@code sDVariable} as the gradient of the variable named {@code str}.
     *
     * @param str        variable name; must exist
     * @param sDVariable gradient variable; must not be null
     * @throws IllegalStateException     if no variable with that name exists
     * @throws ND4JIllegalStateException if {@code sDVariable} is null
     */
    public void setGradientForVariableName(String str, SDVariable sDVariable) {
        Preconditions.checkState(this.variables.containsKey(str), "No variable exists with name \"%s\"", str);
        if (sDVariable != null) {
            Variable target = this.variables.get(str);
            target.setGradient(sDVariable);
            return;
        }
        throw new ND4JIllegalStateException("Unable to set null gradient for variable name " + str);
    }

    /**
     * Records the forward variable associated with the gradient name {@code str} in {@code forwardVarForGrad}.
     *
     * @param str        key name
     * @param sDVariable forward variable to associate
     */
    public void setForwardVariableForVarName(String str, SDVariable sDVariable) {
        forwardVarForGrad.put(str, sDVariable);
    }

    /**
     * Returns the gradient of {@code str} as recorded in the "grad" function.
     *
     * @param str variable name
     * @return the gradient variable, or {@code null} if none is recorded
     * @throws IllegalStateException if the "grad" function does not exist yet, or does not contain
     *                               a variable named {@code str}
     */
    public SDVariable grad(String str) {
        if (!this.sameDiffFunctionInstances.containsKey("grad")) {
            throw new IllegalStateException("Unable to obtain gradient. Please run execBackwards() first.");
        }
        SameDiff gradFn = getFunction("grad");
        SDVariable inGradFn = gradFn.getVariable(str);
        // Fail with a clear message instead of a NullPointerException on an unknown name.
        if (inGradFn == null) {
            throw new IllegalStateException("No variable with name \"" + str + "\" exists in the gradient function");
        }
        return gradFn.getGradForVariable(inGradFn.getVarName());
    }

    /**
     * Creates a trainable scalar variable holding {@code d}.
     * <p>
     * The backing scalar is created with workspaces scoped out. The scope is closed exactly once by
     * try-with-resources.
     *
     * @param str variable name (null or empty generates a name)
     * @param d   scalar value
     * @return the new variable
     */
    public SDVariable scalar(String str, double d) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return var(str, Nd4j.scalar(d));
        }
    }

    /**
     * Creates a trainable scalar variable holding {@code f}.
     * <p>
     * The backing scalar is created with workspaces scoped out. The scope is closed exactly once by
     * try-with-resources.
     *
     * @param str variable name (null or empty generates a name)
     * @param f   scalar value
     * @return the new variable
     */
    public SDVariable scalar(String str, float f) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return var(str, Nd4j.scalar(f));
        }
    }

    /**
     * Creates a scalar variable holding {@code i}.
     * <p>
     * The backing scalar is created with workspaces scoped out. The scope is closed exactly once by
     * try-with-resources.
     * <p>
     * NOTE(review): var(String, INDArray) rejects non-floating-point arrays. If Nd4j.scalar(int)
     * produces an integer-typed array, this call will fail that precondition. Confirm, and consider
     * a constant instead.
     *
     * @param str variable name (null or empty generates a name)
     * @param i   scalar value
     * @return the new variable
     */
    public SDVariable scalar(String str, int i) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return var(str, Nd4j.scalar(i));
        }
    }

    /**
     * Creates a scalar variable holding {@code j}.
     * <p>
     * The backing scalar is created with workspaces scoped out. The scope is closed exactly once by
     * try-with-resources.
     * <p>
     * NOTE(review): var(String, INDArray) rejects non-floating-point arrays. If Nd4j.scalar(long)
     * produces an integer-typed array, this call will fail that precondition. Confirm, and consider
     * a constant instead.
     *
     * @param str variable name (null or empty generates a name)
     * @param j   scalar value
     * @return the new variable
     */
    public SDVariable scalar(String str, long j) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return var(str, Nd4j.scalar(j));
        }
    }

    /**
     * Creates a scalar variable of the given data type holding {@code number}.
     * <p>
     * The backing scalar is created with workspaces scoped out. The scope is closed exactly once by
     * try-with-resources. var(String, INDArray) only accepts floating point data types.
     *
     * @param str      variable name (null or empty generates a name)
     * @param dataType data type of the scalar
     * @param number   scalar value
     * @return the new variable
     */
    public SDVariable scalar(String str, DataType dataType, Number number) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return var(str, Nd4j.scalar(dataType, number));
        }
    }

    /**
     * Creates a scalar constant holding {@code d}, with an auto-generated name.
     *
     * @param d constant value
     * @return the new constant
     */
    public SDVariable constant(double d) {
        String noName = null;
        return constant(noName, d);
    }

    /**
     * Creates a scalar constant holding {@code d}.
     * <p>
     * The backing scalar is created with workspaces scoped out. The scope is closed exactly once by
     * try-with-resources.
     *
     * @param str constant name (null or empty generates a name)
     * @param d   constant value
     * @return the new constant
     */
    public SDVariable constant(String str, double d) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return constant(str, Nd4j.scalar(d));
        }
    }

    /**
     * Creates a scalar constant holding {@code f}, with an auto-generated name.
     *
     * @param f constant value
     * @return the new constant
     */
    public SDVariable constant(float f) {
        String noName = null;
        return constant(noName, f);
    }

    /**
     * Creates a scalar constant holding {@code f}.
     * <p>
     * The backing scalar is created with workspaces scoped out. The scope is closed exactly once by
     * try-with-resources.
     *
     * @param str constant name (null or empty generates a name)
     * @param f   constant value
     * @return the new constant
     */
    public SDVariable constant(String str, float f) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return constant(str, Nd4j.scalar(f));
        }
    }

    /**
     * Creates a scalar constant holding {@code i}, with an auto-generated name.
     *
     * @param i constant value
     * @return the new constant
     */
    public SDVariable constant(int i) {
        String noName = null;
        return constant(noName, i);
    }

    /**
     * Creates a scalar constant holding {@code i}.
     * <p>
     * The backing scalar is created with workspaces scoped out. The scope is closed exactly once by
     * try-with-resources.
     *
     * @param str constant name (null or empty generates a name)
     * @param i   constant value
     * @return the new constant
     */
    public SDVariable constant(String str, int i) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return constant(str, Nd4j.scalar(i));
        }
    }

    /**
     * Creates an unnamed scalar {@code long} constant.
     *
     * @param j value of the constant
     * @return the constant variable created by {@code constant(String, long)} with a null name
     */
    public SDVariable constant(long j) {
        final String unnamed = null;
        return constant(unnamed, j);
    }

    /**
     * Creates a scalar {@code long} constant with the given name.
     *
     * <p>The scalar array is created while the scope returned by
     * {@code Nd4j.getMemoryManager().scopeOutOfWorkspaces()} is open. Judging by its name, that scope
     * keeps the array out of any workspace active on this thread (NOTE(review): confirm). The scope is
     * always closed before this method returns or throws.
     *
     * @param str name of the constant (may be null)
     * @param j value of the constant
     * @return the new constant variable
     */
    public SDVariable constant(String str, long j) {
        // try-with-resources closes the scope exactly once and skips a null scope. A failure in
        // close() is attached as a suppressed exception instead of masking a failure from the body.
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return constant(str, Nd4j.scalar(j));
        }
    }

    /**
     * Creates a scalar constant of the given datatype and value.
     *
     * <p>The scalar array is created while the scope returned by
     * {@code Nd4j.getMemoryManager().scopeOutOfWorkspaces()} is open. Judging by its name, that scope
     * keeps the array out of any workspace active on this thread (NOTE(review): confirm). The scope is
     * always closed before this method returns or throws.
     *
     * @param str name of the constant (may be null)
     * @param dataType datatype of the scalar array
     * @param number value of the constant
     * @return the new constant variable
     */
    public SDVariable constant(String str, DataType dataType, Number number) {
        // try-with-resources closes the scope exactly once and skips a null scope. A failure in
        // close() is attached as a suppressed exception instead of masking a failure from the body.
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return constant(str, Nd4j.scalar(dataType, number));
        }
    }

    /**
     * Registers a variable with this SameDiff instance under its own variable name.
     *
     * <p>Re-adding the same variable (per {@code equals}) is allowed and simply re-registers it.
     *
     * @param sDVariable the variable to register; must belong to this SameDiff instance
     * @return {@code sDVariable}, for chaining
     * @throws IllegalArgumentException if a different variable is already registered under the same name
     */
    public SDVariable addVariable(SDVariable sDVariable) {
        // A single ownership check is enough; the original repeated this precondition further down.
        Preconditions.checkState(sDVariable.getSameDiff() == this, "Samediff instance must be the same.");
        String varName = sDVariable.getVarName();
        if (this.variables.containsKey(varName) && !this.variables.get(varName).getVariable().equals(sDVariable)) {
            throw new IllegalArgumentException("Variable already found with variable opName " + varName);
        }
        this.variables.put(varName, Variable.builder().name(varName).variable(sDVariable).build());
        return sDVariable;
    }

    /**
     * Returns a variable name derived from {@code str} that is not yet used by any variable.
     *
     * <p>For output index 0, {@code str} itself is returned when no variable of that name exists.
     * Otherwise the candidates are tried in order: {@code str + sfx}, {@code str + "_1" + sfx},
     * {@code str + "_2" + sfx}, and so on. Here {@code sfx} is {@code ":" + i} when {@code i > 0}
     * and empty otherwise.
     *
     * @param str base name
     * @param i output index of the variable being named
     * @return the first unused candidate name
     */
    @Override // org.nd4j.autodiff.samediff.ops.SDBaseOps
    public String generateNewVarName(String str, int i) {
        if (!this.variables.containsKey(str) && i == 0) {
            return str;
        }
        String outputSuffix = i > 0 ? ":" + i : "";
        String candidate = str + outputSuffix;
        for (int attempt = 1; getVariable(candidate) != null; attempt++) {
            candidate = str + "_" + attempt + outputSuffix;
        }
        // Defensive sanity check: the loop above only exits once the candidate is free.
        if (getVariable(candidate) != null) {
            throw new ND4JIllegalStateException("Converged on already generated variable!");
        }
        return candidate;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Creates, or reuses existing, output variables for the given op and sets the op as their creator.
     *
     * <p>The base output name is {@code str}. If {@code str} is null, or empty while
     * {@code getBaseNameForFunction} returns a name, that base name is used instead, falling back
     * to {@code differentialFunction.opName()}. Output 0 is looked up under the base name and
     * output {@code i > 0} under {@code base + ":" + i}.
     *
     * <p>Three paths exist:
     * <ul>
     *   <li>no output shapes (null or empty) and the op is a {@code CustomOp}: the number of outputs
     *       comes from {@code getNumOutputs()} or the op descriptor, and variables are created
     *       without a shape;</li>
     *   <li>an empty output-shape list and the op is a {@code BaseOp}: a single output is created;</li>
     *   <li>otherwise one variable per calculated output shape is created, or the shape is recorded
     *       for an existing variable.</li>
     * </ul>
     *
     * @param differentialFunction the op whose output variables are generated
     * @param str requested base name for the outputs (may be null)
     * @param z when {@code false}, output datatypes are computed from the input datatypes via
     *     {@code calculateOutputDataTypes} and checked against the shape calculation; when
     *     {@code true}, that computation and check are skipped
     * @return the op's output variables, indexed by output number
     * @throws ND4UnresolvedOutputVariables if the number of outputs of a custom op cannot be determined
     */
    public SDVariable[] generateOutputVariableForOp(DifferentialFunction differentialFunction, String str, boolean z) {
        if (str == null || (str.isEmpty() && getBaseNameForFunction(differentialFunction) != null)) {
            str = getBaseNameForFunction(differentialFunction);
        }
        if (str == null) {
            str = differentialFunction.opName();
        }
        // Output datatypes, derived from the datatypes of the op's inputs; stays null when z is true.
        List<DataType> list = null;
        if (!z) {
            ArrayList arrayList = new ArrayList();
            List<String> inputsToOp = this.ops.get(differentialFunction.getOwnName()).getInputsToOp();
            if (inputsToOp != null) {
                Iterator<String> it = inputsToOp.iterator();
                while (it.hasNext()) {
                    arrayList.add(this.variables.get(it.next()).getVariable().dataType());
                }
            }
            list = differentialFunction.calculateOutputDataTypes(arrayList);
        }
        // Output shapes as calculated by the op itself; may be null or empty.
        List<LongShapeDescriptor> calculateOutputShape = differentialFunction.calculateOutputShape();
        if (calculateOutputShape == null || calculateOutputShape.isEmpty()) {
            if (differentialFunction instanceof CustomOp) {
                CustomOp customOp = (CustomOp) differentialFunction;
                int numOutputs = differentialFunction.getNumOutputs();
                if (numOutputs <= 0) {
                    // Fall back to the op descriptor when the op does not report its output count.
                    CustomOpDescriptor descriptor = customOp.getDescriptor();
                    if (descriptor != null) {
                        numOutputs = descriptor.getNumOutputs();
                    }
                    if (numOutputs <= 0) {
                        throw new ND4UnresolvedOutputVariables("Could not determine number of output variables for op " + differentialFunction.getOwnName() + " - " + differentialFunction.getClass().getSimpleName() + ". Ops can override getNumOutputs() to specify number of outputs if required");
                    }
                }
                // NOTE(review): the ordering() result is discarded, so this statement appears to have no effect.
                SDVariable[] args = differentialFunction.args();
                if (args != null && args.length > 0 && args[0].getArr() != null) {
                    differentialFunction.args()[0].getArr().ordering();
                }
                SDVariable[] sDVariableArr = new SDVariable[numOutputs];
                Preconditions.checkState(z || numOutputs == 0 || (list != null && list.size() == numOutputs), "Incorrect number of output datatypes: got %s but expected datatypes for %s outputs - %s (op: %s)", list == null ? null : Integer.valueOf(list.size()), Integer.valueOf(numOutputs), list, differentialFunction.getClass().getSimpleName());
                int i = 0;
                while (i < sDVariableArr.length) {
                    // Reuse an existing variable of the expected name, otherwise create a shapeless one.
                    SDVariable variable = i == 0 ? getVariable(str) : getVariable(str + ":" + i);
                    if (variable == null) {
                        variable = var(generateNewVarName(str, i), VariableType.ARRAY, null, z ? null : list.get(i), (long[]) null);
                    }
                    variable.setOutputIndex(i);
                    variable.setCreator(differentialFunction);
                    sDVariableArr[i] = variable;
                    i++;
                }
                // Record the op -> outputs mapping only if it has not been recorded yet.
                if (getOutputsForFunction(differentialFunction) == null) {
                    addOutgoingFor(sDVariableArr, differentialFunction);
                }
                return sDVariableArr;
            }
            // NOTE(review): if calculateOutputShape is null, the isEmpty() call below throws a
            // NullPointerException for BaseOp instances. Also, list is null when z is true, so
            // list.get(0) below would throw.
            if ((differentialFunction instanceof BaseOp) && calculateOutputShape.isEmpty()) {
                SDVariable[] sDVariableArr2 = new SDVariable[1];
                SDVariable variable2 = getVariable(str);
                // NOTE(review): the ordering() result is discarded, so this statement appears to have no effect.
                SDVariable[] args2 = differentialFunction.args();
                if (args2 != null && args2.length > 0 && differentialFunction.args()[0].getArr() != null) {
                    differentialFunction.args()[0].getArr().ordering();
                }
                if (variable2 == null) {
                    variable2 = var(str, VariableType.ARRAY, null, list.get(0), (long[]) null);
                }
                // NOTE(review): duplicate of the null check above; only reachable if var(...) returned null.
                if (variable2 == null) {
                    variable2 = var(str, VariableType.ARRAY, null, list.get(0), (long[]) null);
                }
                variable2.setOutputIndex(0);
                variable2.setCreator(differentialFunction);
                sDVariableArr2[0] = variable2;
                if (getOutputsForFunction(differentialFunction) == null) {
                    addOutgoingFor(sDVariableArr2, differentialFunction);
                }
                return sDVariableArr2;
            }
        }
        // Shape-based path. NOTE(review): reached with calculateOutputShape == null for ops that are
        // neither CustomOp nor BaseOp, which would throw a NullPointerException below. Unlike the two
        // paths above, this path does not call addOutgoingFor.
        if (!z) {
            // The datatypes from the shape calculation must agree with calculateOutputDataTypes.
            for (int i2 = 0; i2 < calculateOutputShape.size(); i2++) {
                DataType dataType = calculateOutputShape.get(i2).dataType();
                DataType dataType2 = list.get(i2);
                Preconditions.checkState(dataType2 == dataType, "Calculated output data types do not match for shape calculation vs. datatype calculation: %s vs %s for op %s output %s", dataType, dataType2, differentialFunction.getClass().getName(), Integer.valueOf(i2));
            }
        }
        // Order of the first input's array ('c' if unavailable); used only by the fallback at the end of the loop.
        char c = 'c';
        if (differentialFunction.args() != null && differentialFunction.args().length > 0 && differentialFunction.args()[0].getArr() != null) {
            c = differentialFunction.args()[0].getArr().ordering();
        }
        SDVariable[] sDVariableArr3 = new SDVariable[calculateOutputShape.size()];
        // NOTE(review): return value discarded; this call has no visible effect here.
        differentialFunction.getOwnName();
        String str2 = str;
        int i3 = 0;
        while (i3 < sDVariableArr3.length) {
            LongShapeDescriptor longShapeDescriptor = calculateOutputShape.get(i3);
            String str3 = str2 + (i3 > 0 ? ":" + i3 : "");
            SDVariable variable3 = getVariable(str3);
            // Create the variable with the calculated datatype and shape, or record the shape for an
            // existing variable that has none stored yet.
            if (variable3 == null) {
                variable3 = var(str3, VariableType.ARRAY, null, longShapeDescriptor.dataType(), longShapeDescriptor.getShape());
            } else if (longShapeDescriptor != null && !shapeAlreadyExistsForVarName(variable3.getVarName())) {
                putShapeForVarName(variable3.getVarName(), longShapeDescriptor);
            } else if (longShapeDescriptor == null || shapeAlreadyExistsForVarName(variable3.getVarName())) {
                // NOTE(review): empty branch left over from decompilation.
            }
            // NOTE(review): only reachable if var(...) above returned null; it also appends the ":i"
            // suffix a second time and always uses FLOAT.
            if (variable3 == null) {
                variable3 = var(str3 + (i3 > 0 ? ":" + i3 : ""), new ZeroInitScheme(c), DataType.FLOAT, longShapeDescriptor.getShape());
            }
            variable3.setOutputIndex(i3);
            variable3.setCreator(differentialFunction);
            sDVariableArr3[i3] = variable3;
            i3++;
        }
        return sDVariableArr3;
    }

    /**
     * Generates output variables for an op, using its op name as the base name and computing
     * output datatypes from its inputs.
     *
     * @param differentialFunction the op whose outputs are generated
     * @return the op's output variables
     */
    public SDVariable[] generateOutputVariableForOp(DifferentialFunction differentialFunction) {
        String baseName = differentialFunction.opName();
        boolean skipDataTypeCalculation = false;
        return generateOutputVariableForOp(differentialFunction, baseName, skipDataTypeCalculation);
    }

    /**
     * Looks up a previously defined sub-function.
     *
     * @param str name the function was defined under
     * @return the function's SameDiff instance, or {@code null} if none is registered under that name
     */
    public SameDiff getFunction(String str) {
        SameDiff registered = this.sameDiffFunctionInstances.get(str);
        return registered;
    }

    /**
     * Builds a {@code While} construct attached to this SameDiff instance.
     *
     * @param sameDiffConditional predicate of the loop
     * @param sameDiffFunctionDefinition condition body
     * @param sameDiffFunctionDefinition2 loop body (passed as {@code trueBody})
     * @param sDVariableArr input variables of the loop
     * @return the built {@code While}, with a block name of the form {@code "while-<random UUID>"}
     */
    public While whileStatement(SameDiffConditional sameDiffConditional, SameDiffFunctionDefinition sameDiffFunctionDefinition, SameDiffFunctionDefinition sameDiffFunctionDefinition2, SDVariable[] sDVariableArr) {
        String blockName = "while-" + UUID.randomUUID().toString();
        return While.builder()
                .inputVars(sDVariableArr)
                .condition(sameDiffFunctionDefinition)
                .predicate(sameDiffConditional)
                .trueBody(sameDiffFunctionDefinition2)
                .parent(this)
                .blockName(blockName)
                .build();
    }

    /**
     * Builds an {@code If} construct attached to this SameDiff instance.
     *
     * @param sameDiffConditional predicate of the branch
     * @param sameDiffFunctionDefinition condition body
     * @param sameDiffFunctionDefinition2 body used when the predicate holds
     * @param sameDiffFunctionDefinition3 body used otherwise
     * @param sDVariableArr input variables of the construct
     * @return the built {@code If}, with a block name of the form {@code "if-<random UUID>"}
     */
    public If ifStatement(SameDiffConditional sameDiffConditional, SameDiffFunctionDefinition sameDiffFunctionDefinition, SameDiffFunctionDefinition sameDiffFunctionDefinition2, SameDiffFunctionDefinition sameDiffFunctionDefinition3, SDVariable[] sDVariableArr) {
        String blockName = "if-" + UUID.randomUUID().toString();
        return If.builder()
                .conditionBody(sameDiffFunctionDefinition)
                .falseBody(sameDiffFunctionDefinition3)
                .trueBody(sameDiffFunctionDefinition2)
                .predicate(sameDiffConditional)
                .inputVars(sDVariableArr)
                .parent(this)
                .blockName(blockName)
                .build();
    }

    /**
     * Creates a {@code TensorArray} of the given datatype owned by this SameDiff instance.
     *
     * @param dataType element datatype of the tensor array
     * @return the new tensor array
     */
    public TensorArray tensorArray(DataType dataType) {
        TensorArray array = new TensorArray(this, dataType);
        // Called for its side effect; the returned output variables are not used here.
        // NOTE(review): presumably this registers the op's outputs with this graph — confirm.
        array.outputVariables();
        return array;
    }

    /**
     * Invokes a previously defined sub-function's graph on another SameDiff instance.
     *
     * @param str name of the defined function
     * @param sameDiff target instance passed to {@code invokeGraphOn}
     * @return the result of {@code invokeGraphOn}
     */
    public SDVariable invokeFunctionOn(String str, SameDiff sameDiff) {
        SameDiff function = this.sameDiffFunctionInstances.get(str);
        return function.invokeGraphOn(sameDiff);
    }

    /**
     * Defines a named sub-function whose inputs are copies of the given variables, unless a
     * function with that name already exists.
     *
     * <p>While the function body is being defined, {@code this.child} points at the new sub-graph.
     * It is reset to {@code null} afterwards, including when the definition throws. The original
     * version left a stale pointer to a half-built sub-graph in that case.
     *
     * @param str name of the function
     * @param sameDiffFunctionDefinition body that populates the sub-graph
     * @param sDVariableArr variables whose copies (via {@code var(SDVariable)}) become the inputs of the sub-graph
     * @return the SameDiff instance registered under {@code str}
     */
    public SameDiff defineFunction(String str, SameDiffFunctionDefinition sameDiffFunctionDefinition, SDVariable[] sDVariableArr) {
        try {
            if (!this.sameDiffFunctionInstances.containsKey(str)) {
                SameDiff sub = create();
                this.child = sub;
                sub.parent = this;
                SDVariable[] subInputs = new SDVariable[sDVariableArr.length];
                for (int i = 0; i < subInputs.length; i++) {
                    subInputs[i] = sub.var(sDVariableArr[i]);
                }
                sameDiffFunctionDefinition.define(sub, null, subInputs);
                this.sameDiffFunctionInstances.put(str, sub);
            }
        } finally {
            this.child = null;
        }
        return this.sameDiffFunctionInstances.get(str);
    }

    /**
     * Defines a named sub-function with an empty input map, unless one already exists.
     *
     * @param str name of the function
     * @param sameDiffFunctionDefinition body that populates the sub-graph
     */
    public void defineFunction(String str, SameDiffFunctionDefinition sameDiffFunctionDefinition) {
        Map<String, INDArray> noInputs = new LinkedHashMap<String, INDArray>();
        defineFunction(str, sameDiffFunctionDefinition, noInputs);
    }

    /**
     * Defines a named sub-function from an input map, unless one already exists under that name.
     *
     * @param str name of the function
     * @param sameDiffFunctionDefinition body that populates the sub-graph
     * @param map inputs handed to the definition
     */
    public void defineFunction(String str, SameDiffFunctionDefinition sameDiffFunctionDefinition, Map<String, INDArray> map) {
        if (!this.sameDiffFunctionInstances.containsKey(str)) {
            SameDiff sub = create();
            sameDiffFunctionDefinition.define(sub, map, null);
            this.sameDiffFunctionInstances.put(str, sub);
        }
    }

    /**
     * Executes the graph and returns the array of its single output, using the placeholders
     * stored for the current thread.
     *
     * @return the output array
     * @deprecated kept for backward compatibility only.
     */
    @Deprecated
    public INDArray execAndEndResult() {
        List<String> outputNames = outputs();
        Preconditions.checkState(outputNames.size() == 1, "Method can only be used with SameDiff instances with a single output");
        Long currentThreadId = Thread.currentThread().getId();
        return execSingle(this.placeholdersPerThread.get(currentThreadId), outputNames.get(0));
    }

    /**
     * Runs the backward pass, computing gradients for every {@code VARIABLE}-type variable that has
     * a gradient variable.
     *
     * <p>The "grad" function is created first if it does not exist yet. A warning is logged and
     * nothing is executed when there are no such gradients.
     *
     * @param map placeholder values passed to the backward execution
     */
    public void execBackwards(Map<String, INDArray> map) {
        if (getFunction("grad") == null) {
            createGradFunction();
        }
        Set<String> gradientNames = new HashSet<String>();
        for (Variable entry : this.variables.values()) {
            SDVariable candidate = entry.getVariable();
            if (candidate.getVariableType() != VariableType.VARIABLE) {
                continue;
            }
            SDVariable grad = candidate.gradient();
            if (grad != null) {
                gradientNames.add(grad.getVarName());
            }
        }
        if (gradientNames.isEmpty()) {
            log.warn("Skipping gradient execution (backward pass) - no variables to be calculated (graph does not contain any VARIABLE type SDVariables).\nIf gradients for other variables (such as placeholders) are required, use execBackwards(Map, List) instead");
            return;
        }
        execBackwards(map, new ArrayList<String>(gradientNames));
    }

    /**
     * Runs the backward pass for the named gradient variables.
     *
     * @param map placeholder values passed to the backward execution
     * @param strArr names of the gradient variables to compute
     */
    public void execBackwards(Map<String, INDArray> map, String... strArr) {
        List<String> requested = Arrays.asList(strArr);
        execBackwards(map, requested);
    }

    /**
     * Runs the backward pass for the listed gradient variables.
     *
     * <p>The "grad" function is created first if needed. When {@code list} is empty, a warning is
     * logged and nothing is executed.
     *
     * @param map placeholder values passed to the backward execution
     * @param list names of the variables to compute in the "grad" function
     */
    public void execBackwards(Map<String, INDArray> map, List<String> list) {
        if (getFunction("grad") == null) {
            createGradFunction();
        }
        log.trace("About to execute backward function");
        if (list.isEmpty()) {
            log.warn("Skipping gradient calculation (backward pass) - no variables to be calculated (variableGradNamesList is empty)");
            return;
        }
        SameDiff gradFunction = this.sameDiffFunctionInstances.get("grad");
        gradFunction.exec(map, list);
    }

    /**
     * Builds the backward-pass sub-graph and registers it as the function named {@code "grad"}.
     *
     * <p>Loss variables are resolved first when none are set. They are taken from the training
     * configuration if it lists any; otherwise, when the graph has exactly one output, that output
     * is used as the loss.
     *
     * <p>Inside the "grad" definition:
     * <ol>
     *   <li>this graph is copied into the new instance;</li>
     *   <li>each loss is reduced with {@code sum()} and given the scalar {@code 1.0f} as its gradient;</li>
     *   <li>the floating-point variables that the losses depend on, and that themselves depend on a
     *       trainable variable, are determined;</li>
     *   <li>ops are differentiated in reverse, starting from the loss ops. An op is differentiated
     *       once all gradients for its required outputs are available.</li>
     * </ol>
     *
     * <p>Fix: the definition previously called {@code this.invokeGraphOn(sameDiff)}. Inside the
     * anonymous {@code SameDiffFunctionDefinition}, {@code this} is that anonymous object, not the
     * enclosing SameDiff. The graph to copy is the enclosing instance, as the debug-mode keyset check
     * against {@code SameDiff.this.ops} shows.
     *
     * <p>Fails via {@code Preconditions} if no loss variables can be determined, if a loss variable
     * is missing or not floating point, or if the losses depend on no trainable variable.
     *
     * @throws ND4JIllegalStateException if the copied graph contains no ops
     * @throws IllegalStateException if a required variable receives no gradient
     */
    public void createGradFunction() {
        // Resolve loss variables if none were set explicitly.
        if (this.lossVariables.isEmpty()) {
            if (this.trainingConfig == null || this.trainingConfig.getLossVariables() == null || this.trainingConfig.getLossVariables().isEmpty()) {
                List<String> outputs = outputs();
                if (outputs.size() == 1) {
                    String outputOfOp = this.variables.get(outputs.get(0)).getOutputOfOp();
                    if (outputOfOp == null || !(this.ops.get(outputOfOp).getOp() instanceof ExternalErrorsFunction)) {
                        log.info("Inferring output \"{}\" as loss variable as none were previously set. Use SameDiff.setLossVariables() to override", outputs.get(0));
                    }
                    this.lossVariables.add(outputs.get(0));
                }
            } else {
                this.lossVariables.addAll(this.trainingConfig.getLossVariables());
            }
        }
        Preconditions.checkState(!this.lossVariables.isEmpty(), "Cannot create gradient function: No loss variables (variables to minimize) have been specified. Loss variables are the variables that represent the loss/cost/score to be minimized during training, and that all gradients are calculated with respect to.\n Losses can be specified either in TrainingConfiguration (Builder.minimize(...)) or via SameDiff.setLossVariables()/addLossVariable()");
        if (log.isTraceEnabled()) {
            log.trace("Defining function \"grad\"");
        }
        defineFunction("grad", new SameDiffFunctionDefinition() {
            @Override // org.nd4j.autodiff.samediff.SameDiffFunctionDefinition
            public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> map, SDVariable[] sDVariableArr) {
                List<String> inputsToOp;
                List<String> outputsOfOp;
                List<String> outputsOfOp2;
                List<String> inputsToOp2;
                if (SameDiff.this.debugMode) {
                    sameDiff.enableDebugMode();
                }
                // Copy the enclosing graph into the new "grad" instance. This must be the outer
                // SameDiff: a bare `this` here would be the anonymous definition object.
                SameDiff.this.invokeGraphOn(sameDiff);
                if (SameDiff.this.debugMode) {
                    Preconditions.checkState(sameDiff.ops.keySet().equals(SameDiff.this.ops.keySet()), "ops keysets not equal");
                }
                ArrayList arrayList = new ArrayList(sameDiff.ops.values());
                if (arrayList.isEmpty()) {
                    throw new ND4JIllegalStateException("No ops found!");
                }
                // Point every copied op, and its argument and output variables, at the new instance.
                Iterator it = arrayList.iterator();
                while (it.hasNext()) {
                    DifferentialFunction op = ((SameDiffOp) it.next()).getOp();
                    if (!(op instanceof SDVariable)) {
                        for (SDVariable sDVariable : op.args()) {
                            sDVariable.setSameDiff(sameDiff);
                        }
                        for (SDVariable sDVariable2 : op.outputVariables()) {
                            sDVariable2.setSameDiff(sameDiff);
                        }
                        op.setSameDiff(sameDiff);
                    }
                }
                // Reduce each loss with sum() and seed its gradient with the scalar 1.0f (cast if needed).
                ArrayList arrayList2 = new ArrayList(SameDiff.this.lossVariables.size());
                SDVariable var = sameDiff.var("one-var", Nd4j.scalar(1.0f));
                for (String str : SameDiff.this.lossVariables) {
                    Preconditions.checkNotNull(str, "Encountered null value in loss variables. Null loss variables are not allowed. Use SameDiff.setLossVariables with non-null array names to fix");
                    Preconditions.checkState(SameDiff.this.variables.containsKey(str), "Specified loss function variable \"%s\" does not exist", str);
                    SDVariable variable = ((Variable) SameDiff.this.variables.get(str)).getVariable();
                    Preconditions.checkState(variable.dataType().isFPType(), "Specified loss function variable \"%s\" is not a floatingpoint variable (datatype: %s). Only floating point variables may be used as loss function variable", str, variable.dataType());
                    SDVariable sum = variable.sum(new int[0]);
                    if (sum.dataType() == var.dataType()) {
                        sameDiff.setGradientForVariableName(sum.getVarName(), var);
                    } else {
                        sameDiff.setGradientForVariableName(sum.getVarName(), var.castTo(sum.dataType()));
                    }
                    if (arrayList2.contains(sum)) {
                        SameDiff.log.warn("Loss function variable \"{}\" appears multiple times in list of loss variables - using only first instance", str);
                    } else {
                        arrayList2.add(sum);
                    }
                }
                if (SameDiff.log.isTraceEnabled()) {
                    String[] outputVariablesNames = ((SameDiffOp) arrayList.get(arrayList.size() - 1)).getOp().outputVariablesNames();
                    SameDiff.log.trace("Defining backward function: initial outputs {}", outputVariablesNames == null ? "null" : Arrays.toString(outputVariablesNames));
                }
                // Step 1: walk backwards from the losses through op inputs, collecting all
                // floating-point variables the losses depend on (hashSet).
                HashSet<String> hashSet = new HashSet();
                LinkedList linkedList = new LinkedList();
                for (String str2 : SameDiff.this.lossVariables) {
                    if (!linkedList.contains(str2)) {
                        linkedList.add(str2);
                    }
                }
                while (!linkedList.isEmpty()) {
                    String str3 = (String) linkedList.remove();
                    if (!hashSet.contains(str3)) {
                        Variable variable2 = (Variable) SameDiff.this.variables.get(str3);
                        if (variable2.getVariable().dataType().isFPType()) {
                            hashSet.add(variable2.getName());
                            if (variable2.getOutputOfOp() != null && (inputsToOp2 = ((SameDiffOp) SameDiff.this.ops.get(variable2.getOutputOfOp())).getInputsToOp()) != null) {
                                for (String str4 : inputsToOp2) {
                                    if (((Variable) SameDiff.this.variables.get(str4)).getVariable().dataType().isFPType()) {
                                        linkedList.add(str4);
                                    }
                                }
                            }
                        }
                    }
                }
                // Step 2: starting from constants, placeholders, and ARRAY variables whose producing op
                // has no inputs in hashSet, propagate forwards and remove every variable whose inputs
                // have all been removed. What remains in hashSet2 is the set of variables whose
                // gradients are required (those depending on a VARIABLE-type variable).
                HashSet<String> hashSet2 = new HashSet(hashSet);
                LinkedList linkedList2 = new LinkedList();
                for (String str5 : hashSet) {
                    Variable variable3 = (Variable) SameDiff.this.variables.get(str5);
                    if (variable3.getVariable().getVariableType() == VariableType.ARRAY) {
                        List<String> inputsToOp3 = ((SameDiffOp) SameDiff.this.ops.get(variable3.getOutputOfOp())).getInputsToOp();
                        boolean z = false;
                        if (inputsToOp3 != null) {
                            Iterator<String> it2 = inputsToOp3.iterator();
                            while (true) {
                                if (!it2.hasNext()) {
                                    break;
                                }
                                if (hashSet.contains(it2.next())) {
                                    z = true;
                                    break;
                                }
                            }
                        }
                        if (!z) {
                            linkedList2.add(str5);
                        }
                    }
                    if (variable3.getVariable().getVariableType() == VariableType.CONSTANT || variable3.getVariable().getVariableType() == VariableType.PLACEHOLDER) {
                        linkedList2.add(str5);
                    }
                }
                while (!linkedList2.isEmpty()) {
                    String str6 = (String) linkedList2.remove();
                    Variable variable4 = (Variable) SameDiff.this.variables.get(str6);
                    hashSet2.remove(str6);
                    List<String> inputsForOp = variable4.getInputsForOp();
                    if (inputsForOp != null && !inputsForOp.isEmpty()) {
                        Iterator<String> it3 = inputsForOp.iterator();
                        while (it3.hasNext()) {
                            SameDiffOp sameDiffOp = (SameDiffOp) SameDiff.this.ops.get(it3.next());
                            boolean z2 = false;
                            Iterator<String> it4 = sameDiffOp.getInputsToOp().iterator();
                            while (true) {
                                if (!it4.hasNext()) {
                                    break;
                                }
                                if (hashSet2.contains(it4.next())) {
                                    z2 = true;
                                    break;
                                }
                            }
                            if (!z2 && (outputsOfOp2 = sameDiffOp.getOutputsOfOp()) != null) {
                                for (String str7 : outputsOfOp2) {
                                    if (!linkedList2.contains(str7)) {
                                        linkedList2.add(str7);
                                    }
                                }
                            }
                        }
                    }
                }
                Preconditions.checkState(!hashSet2.isEmpty(), "Cannot differentiate graph relative to the specified loss function variables %s: graph does not contain any trainable SDVariables (floating point VARIABLE type SDVariables) that the loss function depend on.", SameDiff.this.lossVariables);
                // Step 3: seed the op queue with the ops that produce the summed loss variables.
                LinkedList linkedList3 = new LinkedList();
                Iterator it5 = arrayList2.iterator();
                while (it5.hasNext()) {
                    Variable variable5 = (Variable) sameDiff.variables.get(((SDVariable) it5.next()).getVarName());
                    if (variable5.getOutputOfOp() != null) {
                        linkedList3.add(variable5.getOutputOfOp());
                    }
                }
                // Step 4: for each required variable, list the consumer ops that produce at least one
                // required variable. Those ops must be differentiated before the variable's gradient
                // is complete.
                HashMap hashMap = new HashMap();
                Iterator it6 = hashSet2.iterator();
                while (it6.hasNext()) {
                    Variable variable6 = (Variable) SameDiff.this.variables.get((String) it6.next());
                    List<String> inputsForOp2 = variable6.getInputsForOp();
                    if (inputsForOp2 != null) {
                        ArrayList arrayList3 = new ArrayList();
                        for (String str8 : inputsForOp2) {
                            List<String> outputsOfOp3 = ((SameDiffOp) SameDiff.this.ops.get(str8)).getOutputsOfOp();
                            boolean z3 = false;
                            if (outputsOfOp3 != null) {
                                Iterator<String> it7 = outputsOfOp3.iterator();
                                while (true) {
                                    if (!it7.hasNext()) {
                                        break;
                                    }
                                    if (hashSet2.contains(it7.next())) {
                                        z3 = true;
                                        break;
                                    }
                                }
                            }
                            if (z3) {
                                arrayList3.add(str8);
                            }
                        }
                        hashMap.put(variable6.getName(), arrayList3);
                    }
                }
                // Step 5: differentiate ops in reverse. hashSet3 holds the names of ops already differentiated.
                HashSet hashSet3 = new HashSet();
                while (!linkedList3.isEmpty()) {
                    DifferentialFunction op2 = ((SameDiffOp) sameDiff.ops.get((String) linkedList3.remove())).getOp();
                    if (op2 instanceof GradientBackwardsMarker) {
                        inputsToOp = ((SameDiffOp) sameDiff.ops.get(op2.getOwnName())).getInputsToOp();
                        outputsOfOp = Collections.emptyList();
                    } else {
                        inputsToOp = ((SameDiffOp) sameDiff.ops.get(op2.getOwnName())).getInputsToOp();
                        outputsOfOp = ((SameDiffOp) sameDiff.ops.get(op2.getOwnName())).getOutputsOfOp();
                    }
                    // Gradients w.r.t. each output: the existing gradient, zeros for a floating-point
                    // output without one, or null for non-floating-point outputs.
                    ArrayList arrayList4 = new ArrayList();
                    Iterator<String> it8 = outputsOfOp.iterator();
                    while (it8.hasNext()) {
                        SDVariable variable7 = sameDiff.getVariable(it8.next());
                        SDVariable gradient = variable7.hasGradient() ? variable7.gradient() : null;
                        if (gradient != null) {
                            arrayList4.add(gradient);
                        } else if (variable7.dataType().isFPType()) {
                            arrayList4.add(sameDiff.zerosLike(variable7));
                        } else {
                            arrayList4.add(null);
                        }
                    }
                    op2.diff(arrayList4);
                    hashSet3.add(op2.getOwnName());
                    // Enqueue the producers of this op's inputs once all of their required outputs
                    // have gradients and all consumers of those outputs are done.
                    Iterator<String> it9 = inputsToOp.iterator();
                    while (it9.hasNext()) {
                        String outputOfOp2 = ((Variable) sameDiff.variables.get(it9.next())).getOutputOfOp();
                        if (outputOfOp2 != null && !hashSet3.contains(outputOfOp2)) {
                            boolean z4 = false;
                            SameDiffOp sameDiffOp2 = (SameDiffOp) SameDiff.this.ops.get(outputOfOp2);
                            if (sameDiffOp2.getInputsToOp() != null) {
                                boolean z5 = false;
                                Iterator<String> it10 = sameDiffOp2.getInputsToOp().iterator();
                                while (true) {
                                    if (!it10.hasNext()) {
                                        break;
                                    }
                                    if (hashSet2.contains(it10.next())) {
                                        z5 = true;
                                        break;
                                    }
                                }
                                if (z5 && !hashSet3.contains(sameDiffOp2.getName())) {
                                    z4 = true;
                                }
                            }
                            if (z4) {
                                boolean z6 = true;
                                SameDiffOp sameDiffOp3 = (SameDiffOp) sameDiff.ops.get(outputOfOp2);
                                Iterator<String> it11 = sameDiffOp3.getOutputsOfOp().iterator();
                                while (true) {
                                    if (!it11.hasNext()) {
                                        break;
                                    }
                                    Variable variable8 = (Variable) SameDiff.this.variables.get(it11.next());
                                    if (variable8.getVariable().dataType().isFPType() && hashSet2.contains(variable8.getName())) {
                                        if (variable8.getVariable().gradient() == null) {
                                            z6 = false;
                                            break;
                                        }
                                        List list = (List) hashMap.get(variable8.getName());
                                        if (list != null) {
                                            z6 &= hashSet3.containsAll(list);
                                            if (!z6) {
                                                break;
                                            }
                                        } else {
                                            continue;
                                        }
                                    }
                                }
                                if (z6 && !linkedList3.contains(sameDiffOp3.getOp().getOwnName())) {
                                    linkedList3.add(sameDiffOp3.getOp().getOwnName());
                                }
                            }
                        }
                    }
                }
                // Step 6: every required non-loss variable must have received a gradient.
                for (String str9 : hashSet2) {
                    if (!SameDiff.this.lossVariables.contains(str9) && ((Variable) SameDiff.this.variables.get(str9)).getVariable().gradient() == null) {
                        throw new IllegalStateException("Error encountered during differentiation: no gradient for required variable \"" + str9 + "\" was calculated");
                    }
                }
                return new SDVariable[]{sameDiff.var("grad", DataType.FLOAT, 1)};
            }
        });
        // NOTE(review): presumably restores op/variable associations with this instance after the
        // definition above re-pointed them at the "grad" instance — confirm.
        associateSameDiffWithOpsAndVariables();
    }

    /**
     * Registers the original (declared) shape of a placeholder variable.
     * <p>
     * A shape can be registered only once per placeholder: registering an identical shape again is
     * accepted, while registering a different shape is rejected.
     *
     * @param variableName name of the placeholder variable
     * @param shape        shape to record; must not be null
     * @throws ND4JIllegalStateException if the variable is not a placeholder, if {@code shape} is null,
     *                                   or if a different shape was already recorded for it
     */
    public void setOriginalPlaceHolderShape(String variableName, long[] shape) {
        if (!isPlaceHolder(variableName)) {
            throw new ND4JIllegalStateException("Vertex id " + variableName + " does not appear to be a place holder. Did you forget to call addPlaceHolder?");
        }
        if (shape == null) {
            throw new ND4JIllegalStateException("Null and 0 length shape arrays not allowed");
        }
        boolean conflictsWithExisting = placeHolderOriginalShapes.containsKey(variableName)
                && !Arrays.equals(placeHolderOriginalShapes.get(variableName), shape);
        if (conflictsWithExisting) {
            throw new ND4JIllegalStateException("Unable to add a new shape for vertex id " + variableName);
        }
        placeHolderOriginalShapes.put(variableName, shape);
    }

    /**
     * Looks up the shape previously recorded via {@code setOriginalPlaceHolderShape}.
     *
     * @param str name of the placeholder variable
     * @return the recorded shape, or {@code null} if none was recorded
     * @deprecated this lookup is deprecated in this class; no replacement is referenced here
     */
    @Deprecated
    public long[] getOriginalShapeForPlaceHolder(String str) {
        long[] recorded = placeHolderOriginalShapes.get(str);
        return recorded;
    }

    /**
     * Reports whether the named variable is a placeholder.
     *
     * @param str name of a variable that must exist in this instance
     * @return true if the variable is a placeholder
     * @throws IllegalStateException (via Preconditions) if no variable with that name exists
     */
    public boolean isPlaceHolder(String str) {
        Preconditions.checkState(variables.containsKey(str), "No variable present in SameDiff instance with name \"%s\"", str);
        Variable entry = variables.get(str);
        SDVariable sdVariable = entry.getVariable();
        return sdVariable.isPlaceHolder();
    }

    /**
     * Binds concrete arrays to placeholder variables.
     * <p>
     * Every entry is validated before any array is bound, so a validation failure leaves all
     * placeholders untouched. Checks are applied in this order:
     * <ol>
     *   <li>every key names an existing variable, and for placeholders with a known placeholder shape
     *       the array rank matches;</li>
     *   <li>every key names a placeholder, and the array shape equals the recorded original shape
     *       (unless that shape is itself a placeholder shape according to {@code Shape.isPlaceholderShape}).</li>
     * </ol>
     * After binding, {@code resolvedVariables} is set to true.
     *
     * @param map placeholder name to the array to bind to it
     * @throws ND4JIllegalStateException if a name is unknown, is not a placeholder, or has an incompatible shape
     * @throws IllegalStateException     (via Preconditions) on a rank mismatch
     */
    public void resolveVariablesWith(Map<String, INDArray> map) {
        // Pass 1: existence and rank compatibility
        for (Map.Entry<String, INDArray> entry : map.entrySet()) {
            String name = entry.getKey();
            SDVariable variable = getVariable(name);
            if (variable == null) {
                throw new ND4JIllegalStateException("No variable name found for " + name);
            }
            if (variable.getVariableType() == VariableType.PLACEHOLDER) {
                long[] placeholderShape = variable.placeholderShape();
                // A null placeholder shape means the shape is not known, so there is no rank to compare against
                if (placeholderShape != null) {
                    long[] shape = entry.getValue().shape();
                    Preconditions.checkState(placeholderShape.length == shape.length, "Placeholder shape not compatible (mismatched rank): placeholder \"%s\" shape %s, got incompatible shape %s", name, placeholderShape, shape);
                }
            }
        }
        // Pass 2: placeholder status and exact-shape compatibility with the recorded original shape
        for (Map.Entry<String, INDArray> entry : map.entrySet()) {
            String name = entry.getKey();
            if (!this.variables.get(name).getVariable().isPlaceHolder()) {
                throw new ND4JIllegalStateException("Illegal variable " + name + " passed in. Variable found not to be a place holder variable");
            }
            long[] originalShape = getOriginalShapeForPlaceHolder(name);
            long[] arrayShape = entry.getValue().shape();
            if (!Shape.isPlaceholderShape(originalShape) && !Shape.shapeEquals(originalShape, arrayShape)) {
                throw new ND4JIllegalStateException("Place holder shape specified was " + Arrays.toString(originalShape) + " but array shape was " + Arrays.toString(arrayShape));
            }
        }
        // Pass 3: everything validated, now bind
        for (Map.Entry<String, INDArray> entry : map.entrySet()) {
            associateArrayWithVariable(entry.getValue(), getVariable(entry.getKey()));
            setArrayForVariable(entry.getKey(), entry.getValue());
        }
        this.resolvedVariables = true;
    }

    /**
     * Renames a variable, keeping this instance's bookkeeping in sync.
     * <p>
     * If {@code str} is null and the variable's current name is already registered, a fresh name is
     * generated from the current one. If the resulting name is null or unchanged, the variable is
     * returned as-is.
     *
     * @param sDVariable variable to rename; must not be null
     * @param str        requested new name, or null to request a generated one
     * @return the same variable instance, possibly renamed
     * @throws NullPointerException  if {@code sDVariable} is null
     * @throws IllegalStateException if {@code str} is already used by a different variable
     */
    @Override
    public SDVariable updateVariableNameAndReference(SDVariable sDVariable, String str) {
        if (sDVariable == null) {
            throw new NullPointerException("Null input: No variable found for updating!");
        }
        String currentName = sDVariable.getVarName();
        String targetName = str;
        if (targetName != null) {
            boolean takenByOther = variables.containsKey(targetName) && sDVariable != variables.get(targetName).getVariable();
            if (takenByOther) {
                throw new IllegalStateException("Variable name \"" + targetName + "\" already exists for a different SDVariable");
            }
        } else if (variables.containsKey(currentName)) {
            targetName = generateNewVarName(currentName, 0);
        }
        boolean nothingToDo = targetName == null || currentName.equals(targetName);
        if (nothingToDo) {
            return sDVariable;
        }
        sDVariable.setVarName(targetName);
        updateVariableName(currentName, targetName);
        return sDVariable;
    }

    /**
     * Returns the SameDiff instance the inherited op factory methods operate on, which is this instance.
     *
     * @return {@code this}
     */
    @Override
    protected SameDiff sd() {
        return SameDiff.this;
    }

    /**
     * Batch form of {@link #updateVariableNameAndReference(SDVariable, String)}.
     *
     * @param sDVariableArr variables to rename
     * @param strArr        new names aligned by index with {@code sDVariableArr}, or null to request
     *                      generated names for all of them
     * @return a new array holding the (possibly renamed) variables, in the same order
     */
    @Override
    public SDVariable[] updateVariableNamesAndReferences(SDVariable[] sDVariableArr, String[] strArr) {
        SDVariable[] renamed = new SDVariable[sDVariableArr.length];
        for (int idx = 0; idx < renamed.length; idx++) {
            String requestedName = (strArr != null) ? strArr[idx] : null;
            renamed[idx] = updateVariableNameAndReference(sDVariableArr[idx], requestedName);
        }
        return renamed;
    }

    /**
     * Points every variable, every op, and every op's input/output variables back at this instance.
     * <p>
     * Each op's own SameDiff is set before its {@code args()} and {@code outputVariables()} are read.
     */
    protected void associateSameDiffWithOpsAndVariables() {
        for (SDVariable variable : variableMap().values()) {
            variable.setSameDiff(this);
        }
        for (SameDiffOp sameDiffOp : this.ops.values()) {
            DifferentialFunction function = sameDiffOp.getOp();
            function.setSameDiff(this);

            SDVariable[] inputs = function.args();
            if (inputs != null) {
                for (SDVariable input : inputs) {
                    input.setSameDiff(this);
                }
            }

            SDVariable[] outputs = function.outputVariables();
            if (outputs != null) {
                for (SDVariable output : outputs) {
                    output.setSameDiff(this);
                }
            }
        }
    }

    /**
     * Executes the graph and requests every variable registered in this instance as an output.
     *
     * @param map placeholder values, or null (see {@code exec(Map, String...)} for the fallback)
     * @return variable name to computed array
     */
    public Map<String, INDArray> execAll(Map<String, INDArray> map) {
        List<String> allNames = new ArrayList<>(this.variables.size());
        for (Variable variable : this.variables.values()) {
            allNames.add(variable.getName());
        }
        return exec(map, allNames.toArray(new String[0]));
    }

    /**
     * Executes the graph for a single requested output and returns just that array.
     *
     * @param map placeholder values
     * @param str name of the output variable
     * @return the computed array for {@code str}
     */
    public INDArray execSingle(Map<String, INDArray> map, String str) {
        Map<String, INDArray> results = exec(map, str);
        return results.get(str);
    }

    /**
     * List-based overload of {@code exec(Map, String...)}.
     *
     * @param map  placeholder values
     * @param list names of the requested outputs
     * @return output name to computed array
     */
    public Map<String, INDArray> exec(Map<String, INDArray> map, List<String> list) {
        String[] outputNames = list.toArray(new String[0]);
        return exec(map, outputNames);
    }

    /**
     * Executes the graph and returns the requested outputs.
     * <p>
     * An {@code InferenceSession} is created lazily for the calling thread (keyed by thread id) and
     * reused on later calls. If {@code map} is null and {@code inputs()} is non-null, the placeholder
     * values stored for the calling thread are used instead. Every name returned by {@code inputs()}
     * must then have a value.
     *
     * @param map    placeholder values, or null to use the per-thread placeholders
     * @param strArr names of the requested outputs; at least one is required
     * @return output name to computed array
     * @throws IllegalStateException (via Preconditions) if no outputs are requested or a placeholder is missing
     */
    public Map<String, INDArray> exec(Map<String, INDArray> map, String... strArr) {
        Preconditions.checkState(strArr != null && strArr.length > 0, "No outputs were specified");
        Long threadId = Thread.currentThread().getId();
        if (!sessions.containsKey(threadId)) {
            log.info("Creating new InferenceSession for thread {}", threadId);
            sessions.put(threadId, new InferenceSession(this));
        }

        List<String> requiredInputs = inputs();
        Map<String, INDArray> placeholderValues = map;
        if (placeholderValues == null && requiredInputs != null) {
            placeholderValues = placeholdersPerThread.get(threadId);
        }

        boolean hasRequiredInputs = requiredInputs != null && !requiredInputs.isEmpty();
        if (hasRequiredInputs) {
            Preconditions.checkNotNull(placeholderValues, "No placeholders were provided. Network has placeholders: %s", requiredInputs);
            for (String inputName : requiredInputs) {
                Preconditions.checkState(placeholderValues.containsKey(inputName), "No placeholder variable was provided for variable \"%s\". Cannot execute without all placeholders set", inputName);
            }
        }

        List<String> requestedOutputs = Arrays.asList(strArr);
        return sessions.get(threadId).output(requestedOutputs, placeholderValues);
    }

    /**
     * Emits a minimal {@code FlatNode} standing in for a named sub-graph (used for entries of
     * {@code sameDiffFunctionInstances} during export).
     * <p>
     * The name's string offset is used both as the node id and as the node name, and all vectors are
     * left empty. NOTE(review): the constants {@code (byte) 119} (op type) and {@code 10L} (op number)
     * presumably identify a logic/scope op; confirm against the FlatBuffers schema.
     *
     * @param str               scope name
     * @param sameDiff          the sub-graph instance (only null-checked here)
     * @param flatBufferBuilder builder receiving the node
     * @return offset of the created FlatNode
     */
    protected int asFlatNode(String str, @NonNull SameDiff sameDiff, @NonNull FlatBufferBuilder flatBufferBuilder) {
        if (sameDiff == null) {
            throw new NullPointerException("scope is marked @NonNull but is null");
        }
        if (flatBufferBuilder == null) {
            throw new NullPointerException("bufferBuilder is marked @NonNull but is null");
        }
        int scopeNameOffset = flatBufferBuilder.createString(str);
        byte opType = (byte) 119;
        long opNum = 10L;
        return FlatNode.createFlatNode(flatBufferBuilder, scopeNameOffset, scopeNameOffset, opType, opNum,
                0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0);
    }

    /**
     * Splits a variable reference of the form {@code name:index} into its name and output index.
     * <p>
     * Only the last {@code ':'} is treated as the separator, so {@code "a:b:2"} yields
     * {@code ("a:b", 2)}. If there is no {@code ':'}, or the text after the last one is not an integer,
     * the whole string is the name and the index is 0. This covers names that merely contain a colon,
     * such as {@code "scope:weights"}; these previously caused a {@code NumberFormatException}.
     *
     * @param str variable reference; must not be null
     * @return pair of (name, output index)
     * @throws NullPointerException if {@code str} is null
     */
    public static Pair<String, Integer> parseVariable(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("varName is marked @NonNull but is null");
        }
        int separator = str.lastIndexOf(':');
        if (separator < 0) {
            return Pair.pairOf(str, 0);
        }
        String suffix = str.substring(separator + 1);
        Integer index;
        try {
            index = Integer.valueOf(suffix);
        } catch (NumberFormatException notAnIndex) {
            // Text after the last ':' is not an output index, so the colon is part of the name itself
            return Pair.pairOf(str, 0);
        }
        return Pair.pairOf(str.substring(0, separator), index);
    }

    /**
     * Serializes a single op into a {@code FlatNode} table in {@code flatBufferBuilder}.
     * <p>
     * All child strings and vectors are created before the node table itself, as FlatBuffers requires.
     *
     * @param differentialFunction the op to serialize; must not be null
     * @param flatBufferBuilder    builder receiving the node; must not be null
     * @param list                 variables whose positions ({@link List#indexOf}) become this node's output indices
     * @param map                  variable name to node id; read for inputs and extended with this op's outputs
     * @param map2                 receives ids allocated for not-yet-seen inputs whose names contain "NextIteration"
     * @param map3                 frame name to id for {@code Enter} ops; unseen frames get a fresh id
     * @param atomicInteger        id counter shared across the whole export
     * @param num                  pre-assigned node id, or null to allocate one from {@code atomicInteger}
     * @return offset of the created FlatNode
     * @throws NullPointerException      if {@code differentialFunction} or {@code flatBufferBuilder} is null
     * @throws ND4JIllegalStateException if an input variable has no id and is not a "NextIteration" variable,
     *                                   or if resolving output variables fails unexpectedly
     */
    protected int asFlatNode(@NonNull DifferentialFunction differentialFunction, @NonNull FlatBufferBuilder flatBufferBuilder, List<SDVariable> list, Map<String, Integer> map, Map<String, Integer> map2, Map<String, Integer> map3, AtomicInteger atomicInteger, Integer num) {
        if (differentialFunction == null) {
            throw new NullPointerException("node is marked @NonNull but is null");
        }
        if (flatBufferBuilder == null) {
            throw new NullPointerException("bufferBuilder is marked @NonNull but is null");
        }
        String opName = differentialFunction.opName();
        Op.Type opType = differentialFunction.opType();
        long opNum = FlatBuffersMapper.getOpNum(opName, opType);

        // Floating-point arguments: tArgs for custom ops, otherwise the numeric extra args
        double[] tArgs;
        if (opType == Op.Type.CUSTOM) {
            tArgs = ((CustomOp) differentialFunction).tArgs();
        } else {
            Object[] extraArgs = differentialFunction.getExtraArgs();
            tArgs = new double[extraArgs == null ? 0 : extraArgs.length];
            for (int i = 0; i < tArgs.length; i++) {
                tArgs[i] = ((Number) extraArgs[i]).doubleValue();
            }
        }

        // Integer and boolean arguments
        long[] iArgs;
        boolean[] bArgs = null;
        if (opType == Op.Type.CUSTOM) {
            DynamicCustomOp customOp = (DynamicCustomOp) differentialFunction;
            iArgs = customOp.iArgs();
            bArgs = customOp.bArgs();
        } else if (differentialFunction instanceof Enter) {
            // Enter ops carry the id of their frame as the single integer argument
            String frameName = ((Enter) differentialFunction).getFrameName();
            if (!map3.containsKey(frameName)) {
                map3.put(frameName, atomicInteger.incrementAndGet());
            }
            iArgs = new long[]{map3.get(frameName).longValue()};
        } else {
            iArgs = new long[0];
        }
        if (opType == Op.Type.REDUCE_BOOL || opType == Op.Type.REDUCE_SAME || opType == Op.Type.REDUCE_FLOAT || opType == Op.Type.REDUCE_LONG) {
            bArgs = new boolean[]{((ReduceOp) differentialFunction).isKeepDims(), true};
        } else if (opType == Op.Type.INDEXREDUCE) {
            bArgs = new boolean[]{((IndexAccumulation) differentialFunction).isKeepDims(), true};
        }

        // Output indices into `list`; ops whose outputs cannot be resolved yet get none
        int[] outputIndices;
        try {
            SDVariable[] outputVariables = differentialFunction.outputVariables();
            outputIndices = new int[outputVariables.length];
            for (int i = 0; i < outputIndices.length; i++) {
                outputIndices[i] = list.indexOf(outputVariables[i]);
            }
        } catch (ND4UnresolvedOutputVariables e) {
            outputIndices = new int[0];
        } catch (Exception e) {
            throw new ND4JIllegalStateException(e);
        }

        // Inputs as (node id, output index of the producing op) pairs
        List<Integer> inputPairs = new ArrayList<>();
        SDVariable[] inputs = differentialFunction.args();
        if (inputs != null) {
            for (SDVariable input : inputs) {
                String varName = input.getVarName();
                String producingOp = this.variables.get(varName).getOutputOfOp();
                int outputIdx = producingOp != null
                        ? this.ops.get(this.ops.get(producingOp).getOp().getOwnName()).getOutputsOfOp().indexOf(varName)
                        : 0;
                if (!map.containsKey(varName)) {
                    if (!varName.contains("NextIteration")) {
                        throw new ND4JIllegalStateException("Unknown variable used in input: [" + varName + "]");
                    }
                    // Presumably a loop back-edge whose producer has not been exported yet: reserve an id now
                    int reservedId = atomicInteger.incrementAndGet();
                    map2.put(varName, reservedId);
                    map.put(varName, reservedId);
                }
                inputPairs.add(IntPair.createIntPair(flatBufferBuilder, map.get(varName), outputIdx));
            }
        }

        log.trace("Own Name: {}", differentialFunction.getOwnName());
        int nodeId = num != null ? num : atomicInteger.incrementAndGet();
        for (String outputName : differentialFunction.outputVariablesNames()) {
            if (!map.containsKey(outputName)) {
                map.put(outputName, nodeId);
            }
        }

        // Only reduction-style ops serialize their dimensions
        int[] dimensions = null;
        if (opType == Op.Type.REDUCE_FLOAT || opType == Op.Type.REDUCE_SAME || opType == Op.Type.REDUCE_BOOL || opType == Op.Type.REDUCE_LONG || opType == Op.Type.INDEXREDUCE || opType == Op.Type.REDUCE3) {
            dimensions = differentialFunction.getDimensions();
        }
        if (dimensions == null) {
            dimensions = new int[0];
        }

        int propertiesVector = FlatNode.createPropertiesVector(flatBufferBuilder, FlatBuffersMapper.mapFunctionPropertiesToFlatProperties(flatBufferBuilder, differentialFunction.propertiesForFunction()));
        int inputVector = FlatNode.createInputVector(flatBufferBuilder, new int[0]);
        int inputPairedVector = FlatNode.createInputPairedVector(flatBufferBuilder, Ints.toArray(inputPairs));
        int outputVector = FlatNode.createOutputVector(flatBufferBuilder, outputIndices);
        int extraParamsVector = FlatNode.createExtraParamsVector(flatBufferBuilder, tArgs);
        int extraIntegerVector = FlatNode.createExtraIntegerVector(flatBufferBuilder, iArgs);
        int extraBoolsVector = FlatNode.createExtraBoolsVector(flatBufferBuilder, bArgs != null ? bArgs : new boolean[0]);
        int dimensionsVector = FlatNode.createDimensionsVector(flatBufferBuilder, dimensions);
        int ownNameOffset = flatBufferBuilder.createString(differentialFunction.getOwnName());
        int emptyStringOffset = flatBufferBuilder.createString("");

        int scalarOffset = 0;
        if (differentialFunction instanceof ScalarOp) {
            INDArray scalar = ((ScalarOp) differentialFunction).scalar();
            if (scalar != null) {
                scalarOffset = scalar.toFlatArray(flatBufferBuilder);
            }
        }
        if (opType == null) {
            log.warn("Null-op node: {}", differentialFunction);
        }

        // Output names and data types; a null output list is treated as "no outputs" for both
        List<String> outputsOfOp = differentialFunction.getSameDiff().ops.get(differentialFunction.getOwnName()).getOutputsOfOp();
        int numOutputs = outputsOfOp == null ? 0 : outputsOfOp.size();
        int[] outputNameOffsets = new int[numOutputs];
        for (int i = 0; i < numOutputs; i++) {
            outputNameOffsets[i] = flatBufferBuilder.createString(outputsOfOp.get(i));
        }
        int outputNamesVector = FlatNode.createOutputNamesVector(flatBufferBuilder, outputNameOffsets);
        int opNameOffset = flatBufferBuilder.createString(opName);
        byte[] outputTypes = new byte[numOutputs];
        for (int i = 0; i < numOutputs; i++) {
            outputTypes[i] = FlatBuffersMapper.getDataTypeAsByte(getVariable(outputsOfOp.get(i)).dataType());
        }
        int outputTypesVector = FlatNode.createOutputTypesVector(flatBufferBuilder, outputTypes);

        return FlatNode.createFlatNode(flatBufferBuilder, nodeId, ownNameOffset, FlatBuffersMapper.getFlatOpType(opType), opNum, propertiesVector, inputVector, inputPairedVector, outputVector, extraParamsVector, extraIntegerVector, extraBoolsVector, dimensionsVector, -1, 0, emptyStringOffset, outputNamesVector, opNameOffset, outputTypesVector, scalarOffset);
    }

    /**
     * Serializes this graph with graph id 0.
     *
     * @param executorConfiguration executor configuration to embed; must not be null
     * @return buffer holding the serialized FlatGraph
     * @throws NullPointerException if {@code executorConfiguration} is null
     */
    public ByteBuffer asFlatBuffers(@NonNull ExecutorConfiguration executorConfiguration) {
        if (executorConfiguration != null) {
            return asFlatBuffers(0L, executorConfiguration);
        }
        throw new NullPointerException("configuration is marked @NonNull but is null");
    }

    /**
     * Serializes this graph (variables, ops, and non-"grad" entries of {@code sameDiffFunctionInstances})
     * into a FlatBuffers {@code FlatGraph}.
     * <p>
     * As a side effect, the id assigned to each exported variable name is written back through
     * {@code Variable#setVariableIndex} for names that exist in this instance. Names collected from
     * sub-function instances that are not variables of this instance are skipped; they previously
     * caused a NullPointerException here.
     *
     * @param j                     graph id stored in the FlatGraph
     * @param executorConfiguration executor configuration to embed; must not be null
     * @return the builder's data buffer holding the finished FlatGraph
     * @throws NullPointerException if {@code executorConfiguration} is null
     */
    public ByteBuffer asFlatBuffers(long j, @NonNull ExecutorConfiguration executorConfiguration) {
        if (executorConfiguration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
        // Presumably flushes any queued ops so array contents are current before they are copied
        Nd4j.getExecutioner().commit();
        FlatBufferBuilder flatBufferBuilder = new FlatBufferBuilder(1024);
        AtomicInteger idCounter = new AtomicInteger(0);
        List<Integer> variableOffsets = new ArrayList<>();
        // Never populated: serialized as the graph's (empty) outputs vector
        List<Integer> outputOffsets = new ArrayList<>();
        List<Integer> nodeOffsets = new ArrayList<>();
        List<SDVariable> allVariables = new ArrayList<>(variables());
        Map<String, Integer> variableIds = new LinkedHashMap<>();
        Map<String, Integer> forwardDeclaredIds = new LinkedHashMap<>();
        Map<String, Integer> frameIds = new LinkedHashMap<>();
        // Identity-keyed so that all outputs of the same op share one node id
        Map<DifferentialFunction, Integer> opIds = new IdentityHashMap<>();

        for (SDVariable sDVariable : variables()) {
            INDArray arr = sDVariable.getArr();
            log.trace("Exporting variable: [{}]", sDVariable.getVarName());
            String varName = sDVariable.getVarName();
            int nodeId;
            int outputIndex;
            String producingOp = this.variables.get(varName).getOutputOfOp();
            if (producingOp != null) {
                DifferentialFunction op = this.ops.get(producingOp).getOp();
                Integer existingId = opIds.get(op);
                if (existingId != null) {
                    nodeId = existingId;
                } else {
                    nodeId = idCounter.incrementAndGet();
                    opIds.put(op, nodeId);
                }
                String[] outputNames = op.outputVariablesNames();
                outputIndex = ArrayUtils.indexOf(outputNames, varName);
                Preconditions.checkState(outputIndex >= 0, "Variable name \"%s\" not found in list of outputs: %s", varName, outputNames);
            } else {
                nodeId = idCounter.incrementAndGet();
                outputIndex = 0;
            }
            variableIds.put(varName, nodeId);
            log.trace("Adding [{}] as [{}]", varName, nodeId);

            int nameOffset = flatBufferBuilder.createString(varName);
            int arrayOffset = arr == null ? 0 : arr.toFlatArray(flatBufferBuilder);
            int idPair = IntPair.createIntPair(flatBufferBuilder, nodeId, outputIndex);
            byte variableType = (byte) sDVariable.getVariableType().ordinal();
            int shapeOffset = 0;
            if (sDVariable.getVariableType() == VariableType.PLACEHOLDER) {
                long[] shape = sDVariable.getShape();
                // An unknown placeholder shape is left absent (offset 0) instead of failing the export
                if (shape != null) {
                    shapeOffset = FlatVariable.createShapeVector(flatBufferBuilder, shape);
                }
            }
            variableOffsets.add(FlatVariable.createFlatVariable(flatBufferBuilder, idPair, nameOffset, FlatBuffersMapper.getDataTypeAsByte(sDVariable.dataType()), shapeOffset, arrayOffset, -1, variableType));
        }

        for (SameDiffOp sameDiffOp : this.ops.values()) {
            DifferentialFunction op = sameDiffOp.getOp();
            nodeOffsets.add(asFlatNode(op, flatBufferBuilder, allVariables, variableIds, forwardDeclaredIds, frameIds, idCounter, opIds.get(op)));
        }

        // NOTE(review): sub-function variable ids restart at 1 and are not drawn from idCounter, so they can
        // coincide with ids assigned above. Kept unchanged for format compatibility; confirm intent.
        int subFunctionVarId = 0;
        for (Map.Entry<String, SameDiff> entry : this.sameDiffFunctionInstances.entrySet()) {
            if (entry.getKey().equalsIgnoreCase("grad")) {
                continue;
            }
            SameDiff function = entry.getValue();
            nodeOffsets.add(asFlatNode(entry.getKey(), function, flatBufferBuilder));
            List<SDVariable> functionVariables = new ArrayList<>(function.variables());
            for (SDVariable functionVariable : functionVariables) {
                INDArray arr = functionVariable.getArr();
                if (arr == null) {
                    continue;
                }
                int nameOffset = flatBufferBuilder.createString(functionVariable.getVarName());
                int arrayOffset = arr.toFlatArray(flatBufferBuilder);
                subFunctionVarId++;
                int idPair = IntPair.createIntPair(flatBufferBuilder, subFunctionVarId, 0);
                Pair<String, Integer> parsed = parseVariable(functionVariable.getVarName());
                variableIds.put(parsed.getFirst(), subFunctionVarId);
                log.trace("Adding [{}] as [{}]", parsed.getFirst(), subFunctionVarId);
                variableOffsets.add(FlatVariable.createFlatVariable(flatBufferBuilder, idPair, nameOffset, FlatBuffersMapper.getDataTypeAsByte(arr.dataType()), 0, arrayOffset, -1, (byte) functionVariable.getVariableType().ordinal()));
            }
            for (SameDiffOp functionOp : function.ops.values()) {
                nodeOffsets.add(asFlatNode(functionOp.getOp(), flatBufferBuilder, functionVariables, variableIds, forwardDeclaredIds, frameIds, idCounter, null));
            }
        }

        int outputsVector = FlatGraph.createVariablesVector(flatBufferBuilder, Ints.toArray(outputOffsets));
        int variablesVector = FlatGraph.createVariablesVector(flatBufferBuilder, Ints.toArray(variableOffsets));
        int nodesVector = FlatGraph.createNodesVector(flatBufferBuilder, Ints.toArray(nodeOffsets));

        List<String> placeholderNames = new ArrayList<>();
        for (SDVariable candidate : variables()) {
            if (candidate.isPlaceHolder()) {
                placeholderNames.add(candidate.getVarName());
            }
        }
        int[] placeholderOffsets = new int[placeholderNames.size()];
        for (int i = 0; i < placeholderOffsets.length; i++) {
            placeholderOffsets[i] = flatBufferBuilder.createString(placeholderNames.get(i));
        }
        int placeholdersVector = FlatGraph.createPlaceholdersVector(flatBufferBuilder, placeholderOffsets);

        List<String> lossVariables = getLossVariables();
        int[] lossOffsets = new int[lossVariables == null ? 0 : lossVariables.size()];
        for (int i = 0; i < lossOffsets.length; i++) {
            lossOffsets[i] = flatBufferBuilder.createString(lossVariables.get(i));
        }
        int configurationOffset = executorConfiguration.getFlatConfiguration(flatBufferBuilder);
        int lossVector = FlatGraph.createLossVariablesVector(flatBufferBuilder, lossOffsets);
        flatBufferBuilder.finish(FlatGraph.createFlatGraph(flatBufferBuilder, j, variablesVector, nodesVector, outputsVector, configurationOffset, placeholdersVector, lossVector));

        synchronized (this) {
            for (Map.Entry<String, Integer> idEntry : variableIds.entrySet()) {
                Variable variable = this.variables.get(idEntry.getKey());
                if (variable != null) {
                    variable.setVariableIndex(idEntry.getValue());
                }
            }
        }
        return flatBufferBuilder.dataBuffer();
    }

    /**
     * Serializes this graph with the default configuration and returns the root FlatGraph view of it.
     *
     * @return the serialized graph
     */
    public FlatGraph asFlatGraph() {
        ByteBuffer serialized = asFlatBuffers();
        return FlatGraph.getRootAsFlatGraph(serialized);
    }

    /**
     * Serializes this graph with the given id and configuration and returns the root FlatGraph view of it.
     *
     * @param j                     graph id
     * @param executorConfiguration executor configuration to embed
     * @return the serialized graph
     */
    public FlatGraph asFlatGraph(long j, ExecutorConfiguration executorConfiguration) {
        ByteBuffer serialized = asFlatBuffers(j, executorConfiguration);
        return FlatGraph.getRootAsFlatGraph(serialized);
    }

    /**
     * Serializes this graph using a default executor configuration: variable-space output, sequential
     * execution, profiling disabled and timing gathering enabled.
     *
     * @return buffer holding the serialized FlatGraph
     */
    public ByteBuffer asFlatBuffers() {
        ExecutorConfiguration defaults = ExecutorConfiguration.builder()
                .outputMode(OutputMode.VARIABLE_SPACE)
                .executionMode(ExecutionMode.SEQUENTIAL)
                .profilingMode(OpExecutioner.ProfilingMode.DISABLED)
                .gatherTimings(true)
                .build();
        return asFlatBuffers(defaults);
    }

    /**
     * Writes this graph together with its current training configuration to {@code outputStream}.
     *
     * @param outputStream destination stream (not closed by this method)
     * @throws IllegalStateException if no training configuration has been set
     * @throws IOException           if writing fails
     */
    public void saveWithTrainingConfig(OutputStream outputStream) throws IOException {
        TrainingConfig config = this.trainingConfig;
        if (config != null) {
            saveWithTrainingConfig(config, outputStream);
            return;
        }
        throw new IllegalStateException("No training configuration found!");
    }

    /**
     * Writes this graph together with its current training configuration to {@code file}.
     * <p>
     * The file stream is always closed, even if writing fails.
     *
     * @param file destination file (created or overwritten)
     * @throws IllegalStateException if no training configuration has been set
     * @throws IOException           if the file cannot be opened or written
     */
    public void saveWithTrainingConfig(File file) throws IOException {
        if (this.trainingConfig == null) {
            throw new IllegalStateException("No training configuration found!");
        }
        try (BufferedOutputStream out = new BufferedOutputStream(new FileOutputStream(file))) {
            saveWithTrainingConfig(this.trainingConfig, out);
            out.flush();
        }
    }

    /**
     * Writes a zip archive containing the training configuration (JSON) and this graph (FlatBuffers)
     * to {@code outputStream}.
     * <p>
     * Both payloads are produced before anything is written, so a serialization failure leaves the
     * stream untouched. The JSON is encoded as UTF-8, independent of the platform default charset.
     * {@code outputStream} is shielded from closing: the zip is finished, but the caller's stream stays
     * open.
     *
     * @param trainingConfig training configuration to store
     * @param outputStream   destination stream (not closed by this method)
     * @throws IOException if serialization or writing fails
     */
    public void saveWithTrainingConfig(TrainingConfig trainingConfig, OutputStream outputStream) throws IOException {
        String configJson = ObjectMapperHolder.getJsonMapper().writeValueAsString(trainingConfig);
        byte[] configBytes = configJson.getBytes(java.nio.charset.StandardCharsets.UTF_8);
        ByteBuffer graphBuffer = asFlatBuffers();
        int position = graphBuffer.position();
        byte[] graphBytes = graphBuffer.array();

        // Closing the zip finishes the archive; CloseShieldOutputStream keeps the caller's stream open
        try (ZipOutputStream zipOutputStream = new ZipOutputStream(new CloseShieldOutputStream(outputStream))) {
            zipOutputStream.putNextEntry(new ZipEntry(TRAINING_CONFIG_JSON_ZIP_ENTRY_NAME));
            zipOutputStream.write(configBytes);
            zipOutputStream.putNextEntry(new ZipEntry(SAMEDIFF_FILE_ENTRY_NAME));
            // The FlatBuffers payload runs from the buffer's position to the end of its backing array
            zipOutputStream.write(graphBytes, position, graphBytes.length - position);
        }
    }

    /**
     * Restores a graph and its training configuration from an archive written by
     * {@code saveWithTrainingConfig}, then initializes training on the restored instance.
     * <p>
     * The zip file and both entry streams are always closed; the zip file was previously never closed.
     *
     * @param file archive to read
     * @return the restored instance with its training configuration set and training initialized
     * @throws IOException if the file cannot be read, is missing a required entry, or cannot be parsed
     */
    public static SameDiff restoreFromTrainingConfigZip(File file) throws IOException {
        TrainingConfig trainingConfig;
        SameDiff restored;
        try (ZipFile zipFile = new ZipFile(file)) {
            ZipEntry configEntry = zipFile.getEntry(TRAINING_CONFIG_JSON_ZIP_ENTRY_NAME);
            if (configEntry == null) {
                throw new IOException("File " + file + " does not contain required zip entry \"" + TRAINING_CONFIG_JSON_ZIP_ENTRY_NAME + "\"");
            }
            ZipEntry graphEntry = zipFile.getEntry(SAMEDIFF_FILE_ENTRY_NAME);
            if (graphEntry == null) {
                throw new IOException("File " + file + " does not contain required zip entry \"" + SAMEDIFF_FILE_ENTRY_NAME + "\"");
            }
            try (InputStream configStream = zipFile.getInputStream(configEntry)) {
                trainingConfig = ObjectMapperHolder.getJsonMapper().readValue(IOUtils.toByteArray(configStream), TrainingConfig.class);
            }
            try (InputStream graphStream = zipFile.getInputStream(graphEntry)) {
                restored = fromFlatBuffers(ByteBuffer.wrap(IOUtils.toByteArray(graphStream)));
            }
        }
        restored.setTrainingConfig(trainingConfig);
        restored.initializeTraining();
        return restored;
    }

    /* JADX WARN: Failed to calculate best type for var: r13v1 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.calculateFromBounds(FixTypesVisitor.java:156)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.setBestType(FixTypesVisitor.java:133)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.deduceType(FixTypesVisitor.java:238)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.tryDeduceTypes(FixTypesVisitor.java:221)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Failed to calculate best type for var: r13v1 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.calculateFromBounds(TypeInferenceVisitor.java:145)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.setBestType(TypeInferenceVisitor.java:123)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.lambda$runTypePropagation$2(TypeInferenceVisitor.java:101)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.runTypePropagation(TypeInferenceVisitor.java:101)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.visit(TypeInferenceVisitor.java:75)
     */
    /* JADX WARN: Failed to calculate best type for var: r14v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.calculateFromBounds(FixTypesVisitor.java:156)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.setBestType(FixTypesVisitor.java:133)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.deduceType(FixTypesVisitor.java:238)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.tryDeduceTypes(FixTypesVisitor.java:221)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Failed to calculate best type for var: r14v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.calculateFromBounds(TypeInferenceVisitor.java:145)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.setBestType(TypeInferenceVisitor.java:123)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.lambda$runTypePropagation$2(TypeInferenceVisitor.java:101)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.runTypePropagation(TypeInferenceVisitor.java:101)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.visit(TypeInferenceVisitor.java:75)
     */
    /* JADX WARN: Multi-variable type inference failed. Error: java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.RegisterArg.getSVar()" because the return value of "jadx.core.dex.nodes.InsnNode.getResult()" is null
    	at jadx.core.dex.visitors.typeinference.AbstractTypeConstraint.collectRelatedVars(AbstractTypeConstraint.java:31)
    	at jadx.core.dex.visitors.typeinference.AbstractTypeConstraint.<init>(AbstractTypeConstraint.java:19)
    	at jadx.core.dex.visitors.typeinference.TypeSearch$1.<init>(TypeSearch.java:376)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.makeMoveConstraint(TypeSearch.java:376)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.makeConstraint(TypeSearch.java:361)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.collectConstraints(TypeSearch.java:341)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.run(TypeSearch.java:60)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.runMultiVariableSearch(FixTypesVisitor.java:116)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Not initialized variable reg: 13, insn: 0x00dd: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r13 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) A[TRY_LEAVE], block:B:81:0x00dd */
    /* JADX WARN: Not initialized variable reg: 14, insn: 0x00e2: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r14 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]), block:B:83:0x00e2 */
    /* JADX WARN: Type inference failed for: r13v1, types: [java.io.BufferedOutputStream] */
    /* JADX WARN: Type inference failed for: r14v0, types: [java.lang.Throwable] */
    public void asFlatFile(@NonNull File file) throws IOException {
        ?? r13;
        ?? r14;
        if (file == null) {
            throw new NullPointerException("file is marked @NonNull but is null");
        }
        ByteBuffer asFlatBuffers = asFlatBuffers();
        int position = asFlatBuffers.position();
        byte[] array = asFlatBuffers.array();
        FileOutputStream fileOutputStream = new FileOutputStream(file);
        Throwable th = null;
        try {
            try {
                BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(fileOutputStream);
                Throwable th2 = null;
                DataOutputStream dataOutputStream = new DataOutputStream(bufferedOutputStream);
                Throwable th3 = null;
                try {
                    try {
                        dataOutputStream.write(array, position, array.length - position);
                        if (dataOutputStream != null) {
                            if (0 != 0) {
                                try {
                                    dataOutputStream.close();
                                } catch (Throwable th4) {
                                    th3.addSuppressed(th4);
                                }
                            } else {
                                dataOutputStream.close();
                            }
                        }
                        if (bufferedOutputStream != null) {
                            if (0 != 0) {
                                try {
                                    bufferedOutputStream.close();
                                } catch (Throwable th5) {
                                    th2.addSuppressed(th5);
                                }
                            } else {
                                bufferedOutputStream.close();
                            }
                        }
                        if (fileOutputStream != null) {
                            if (0 == 0) {
                                fileOutputStream.close();
                                return;
                            }
                            try {
                                fileOutputStream.close();
                            } catch (Throwable th6) {
                                th.addSuppressed(th6);
                            }
                        }
                    } catch (Throwable th7) {
                        th3 = th7;
                        throw th7;
                    }
                } catch (Throwable th8) {
                    if (dataOutputStream != null) {
                        if (th3 != null) {
                            try {
                                dataOutputStream.close();
                            } catch (Throwable th9) {
                                th3.addSuppressed(th9);
                            }
                        } else {
                            dataOutputStream.close();
                        }
                    }
                    throw th8;
                }
            } catch (Throwable th10) {
                if (fileOutputStream != null) {
                    if (0 != 0) {
                        try {
                            fileOutputStream.close();
                        } catch (Throwable th11) {
                            th.addSuppressed(th11);
                        }
                    } else {
                        fileOutputStream.close();
                    }
                }
                throw th10;
            }
        } catch (Throwable th12) {
            if (r13 != 0) {
                if (r14 != 0) {
                    try {
                        r13.close();
                    } catch (Throwable th13) {
                        r14.addSuppressed(th13);
                    }
                } else {
                    r13.close();
                }
            }
            throw th12;
        }
    }

    /* JADX WARN: Failed to calculate best type for var: r14v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.calculateFromBounds(FixTypesVisitor.java:156)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.setBestType(FixTypesVisitor.java:133)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.deduceType(FixTypesVisitor.java:238)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.tryDeduceTypes(FixTypesVisitor.java:221)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Failed to calculate best type for var: r14v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.calculateFromBounds(TypeInferenceVisitor.java:145)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.setBestType(TypeInferenceVisitor.java:123)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.lambda$runTypePropagation$2(TypeInferenceVisitor.java:101)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.runTypePropagation(TypeInferenceVisitor.java:101)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.visit(TypeInferenceVisitor.java:75)
     */
    /* JADX WARN: Failed to calculate best type for var: r15v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.calculateFromBounds(FixTypesVisitor.java:156)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.setBestType(FixTypesVisitor.java:133)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.deduceType(FixTypesVisitor.java:238)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.tryDeduceTypes(FixTypesVisitor.java:221)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Failed to calculate best type for var: r15v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.calculateFromBounds(TypeInferenceVisitor.java:145)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.setBestType(TypeInferenceVisitor.java:123)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.lambda$runTypePropagation$2(TypeInferenceVisitor.java:101)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.runTypePropagation(TypeInferenceVisitor.java:101)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.visit(TypeInferenceVisitor.java:75)
     */
    /* JADX WARN: Multi-variable type inference failed. Error: java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.RegisterArg.getSVar()" because the return value of "jadx.core.dex.nodes.InsnNode.getResult()" is null
    	at jadx.core.dex.visitors.typeinference.AbstractTypeConstraint.collectRelatedVars(AbstractTypeConstraint.java:31)
    	at jadx.core.dex.visitors.typeinference.AbstractTypeConstraint.<init>(AbstractTypeConstraint.java:19)
    	at jadx.core.dex.visitors.typeinference.TypeSearch$1.<init>(TypeSearch.java:376)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.makeMoveConstraint(TypeSearch.java:376)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.makeConstraint(TypeSearch.java:361)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.collectConstraints(TypeSearch.java:341)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.run(TypeSearch.java:60)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.runMultiVariableSearch(FixTypesVisitor.java:116)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Not initialized variable reg: 14, insn: 0x00f0: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r14 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) A[TRY_LEAVE], block:B:76:0x00f0 */
    /* JADX WARN: Not initialized variable reg: 15, insn: 0x00f5: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r15 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]), block:B:78:0x00f5 */
    /* JADX WARN: Type inference failed for: r14v0, types: [java.io.BufferedOutputStream] */
    /* JADX WARN: Type inference failed for: r15v0, types: [java.lang.Throwable] */
    public void asFlatFile(@NonNull File file, @NonNull ExecutorConfiguration executorConfiguration) throws IOException {
        ?? r14;
        ?? r15;
        if (file == null) {
            throw new NullPointerException("file is marked @NonNull but is null");
        }
        if (executorConfiguration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
        ByteBuffer asFlatBuffers = asFlatBuffers(executorConfiguration);
        int position = asFlatBuffers.position();
        byte[] array = asFlatBuffers.array();
        FileOutputStream fileOutputStream = new FileOutputStream(file);
        Throwable th = null;
        try {
            try {
                BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(fileOutputStream);
                Throwable th2 = null;
                DataOutputStream dataOutputStream = new DataOutputStream(bufferedOutputStream);
                Throwable th3 = null;
                try {
                    try {
                        dataOutputStream.write(array, position, array.length - position);
                        if (dataOutputStream != null) {
                            if (0 != 0) {
                                try {
                                    dataOutputStream.close();
                                } catch (Throwable th4) {
                                    th3.addSuppressed(th4);
                                }
                            } else {
                                dataOutputStream.close();
                            }
                        }
                        if (bufferedOutputStream != null) {
                            if (0 != 0) {
                                try {
                                    bufferedOutputStream.close();
                                } catch (Throwable th5) {
                                    th2.addSuppressed(th5);
                                }
                            } else {
                                bufferedOutputStream.close();
                            }
                        }
                        if (fileOutputStream != null) {
                            if (0 == 0) {
                                fileOutputStream.close();
                                return;
                            }
                            try {
                                fileOutputStream.close();
                            } catch (Throwable th6) {
                                th.addSuppressed(th6);
                            }
                        }
                    } catch (Throwable th7) {
                        th3 = th7;
                        throw th7;
                    }
                } catch (Throwable th8) {
                    if (dataOutputStream != null) {
                        if (th3 != null) {
                            try {
                                dataOutputStream.close();
                            } catch (Throwable th9) {
                                th3.addSuppressed(th9);
                            }
                        } else {
                            dataOutputStream.close();
                        }
                    }
                    throw th8;
                }
            } catch (Throwable th10) {
                if (r14 != 0) {
                    if (r15 != 0) {
                        try {
                            r14.close();
                        } catch (Throwable th11) {
                            r15.addSuppressed(th11);
                        }
                    } else {
                        r14.close();
                    }
                }
                throw th10;
            }
        } catch (Throwable th12) {
            if (fileOutputStream != null) {
                if (0 != 0) {
                    try {
                        fileOutputStream.close();
                    } catch (Throwable th13) {
                        th.addSuppressed(th13);
                    }
                } else {
                    fileOutputStream.close();
                }
            }
            throw th12;
        }
    }

    /**
     * Loads a graph from a file containing a FlatBuffers-serialized graph.
     *
     * <p>The whole file is read into memory and decoded with {@link #fromFlatBuffers(ByteBuffer)}.
     *
     * @param file source file; must not be null
     * @return the deserialized graph
     * @throws NullPointerException if {@code file} is null
     * @throws IOException if the file cannot be read or decoded
     */
    public static SameDiff fromFlatFile(@NonNull File file) throws IOException {
        if (file == null) {
            throw new NullPointerException("file is marked @NonNull but is null");
        }
        final byte[] contents;
        // try-with-resources closes the stream exactly once and, unlike the previous hand-rolled
        // version, keeps the read failure as the primary exception if close() also throws.
        try (InputStream in = new BufferedInputStream(new FileInputStream(file))) {
            contents = IOUtils.toByteArray(in);
        }
        return fromFlatBuffers(ByteBuffer.wrap(contents));
    }

    public static SameDiff fromFlatBuffers(ByteBuffer byteBuffer) throws IOException {
        String[] strArr;
        FlatGraph rootAsFlatGraph = FlatGraph.getRootAsFlatGraph(byteBuffer);
        int nodesLength = rootAsFlatGraph.nodesLength();
        int variablesLength = rootAsFlatGraph.variablesLength();
        ArrayList<FlatNode> arrayList = new ArrayList(nodesLength);
        for (int i = 0; i < nodesLength; i++) {
            arrayList.add(rootAsFlatGraph.nodes(i));
        }
        ArrayList<FlatVariable> arrayList2 = new ArrayList(variablesLength);
        for (int i2 = 0; i2 < variablesLength; i2++) {
            arrayList2.add(rootAsFlatGraph.variables(i2));
        }
        rootAsFlatGraph.configuration();
        SameDiff create = create();
        int placeholdersLength = rootAsFlatGraph.placeholdersLength();
        LinkedHashSet linkedHashSet = new LinkedHashSet();
        for (int i3 = 0; i3 < placeholdersLength; i3++) {
            linkedHashSet.add(rootAsFlatGraph.placeholders(i3));
        }
        new HashMap();
        HashMap hashMap = new HashMap();
        HashMap hashMap2 = new HashMap();
        for (FlatVariable flatVariable : arrayList2) {
            int shapeLength = flatVariable.shapeLength();
            long[] jArr = new long[shapeLength];
            for (int i4 = 0; i4 < shapeLength; i4++) {
                jArr[i4] = flatVariable.shape(i4);
            }
            String name = flatVariable.name();
            DataType dataTypeFromByte = FlatBuffersMapper.getDataTypeFromByte(flatVariable.dtype());
            VariableType variableType = VariableType.values()[flatVariable.variabletype()];
            SDVariable sDVariable = new SDVariable(name, variableType, create, jArr, dataTypeFromByte, null);
            create.variables.put(name, Variable.builder().name(name).variable(sDVariable).build());
            create.variableNameToShape.put(name, jArr);
            FlatArray ndarray = flatVariable.ndarray();
            if (ndarray != null && variableType != VariableType.ARRAY) {
                MemoryWorkspace scopeOutOfWorkspaces = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();
                Throwable th = null;
                try {
                    try {
                        INDArray createFromFlatArray = Nd4j.createFromFlatArray(ndarray);
                        if (scopeOutOfWorkspaces != null) {
                            if (0 != 0) {
                                try {
                                    scopeOutOfWorkspaces.close();
                                } catch (Throwable th2) {
                                    th.addSuppressed(th2);
                                }
                            } else {
                                scopeOutOfWorkspaces.close();
                            }
                        }
                        create.setArrayForVariable(name, createFromFlatArray);
                    } finally {
                    }
                } catch (Throwable th3) {
                    if (scopeOutOfWorkspaces != null) {
                        if (th != null) {
                            try {
                                scopeOutOfWorkspaces.close();
                            } catch (Throwable th4) {
                                th.addSuppressed(th4);
                            }
                        } else {
                            scopeOutOfWorkspaces.close();
                        }
                    }
                    throw th3;
                }
            }
            IntPair id = flatVariable.id();
            hashMap.put(new Pair(Integer.valueOf(id.first()), Integer.valueOf(id.second())), sDVariable);
            if (!hashMap2.containsKey(name)) {
                hashMap2.put(name, new ArrayList());
            }
            ((List) hashMap2.get(name)).add(sDVariable);
        }
        for (FlatNode flatNode : arrayList) {
            DifferentialFunction fromFlatNode = FlatBuffersMapper.fromFlatNode(flatNode);
            String name2 = flatNode.name();
            fromFlatNode.setSameDiff(create);
            fromFlatNode.setOwnName(name2);
            if (create.ops.containsKey(name2)) {
                create.ops.get(name2).setOp(fromFlatNode);
            } else {
                create.ops.put(name2, SameDiffOp.builder().name(name2).op(fromFlatNode).build());
            }
            int outputLength = flatNode.outputLength();
            int[] iArr = new int[outputLength];
            for (int i5 = 0; i5 < outputLength; i5++) {
                iArr[i5] = flatNode.output(i5);
            }
            int id2 = flatNode.id();
            int[] iArr2 = new int[flatNode.outputLength()];
            for (int i6 = 0; i6 < iArr2.length; i6++) {
                iArr2[i6] = flatNode.output(i6);
            }
            int[] iArr3 = new int[flatNode.inputLength()];
            for (int i7 = 0; i7 < iArr3.length; i7++) {
                iArr3[i7] = flatNode.input(i7);
            }
            IntPair[] intPairArr = new IntPair[flatNode.inputPairedLength()];
            ArrayList arrayList3 = new ArrayList();
            for (int i8 = 0; i8 < intPairArr.length; i8++) {
                intPairArr[i8] = flatNode.inputPaired(i8);
                arrayList3.add(new Pair(Integer.valueOf(intPairArr[i8].first()), Integer.valueOf(intPairArr[i8].second())));
            }
            String[] strArr2 = new String[intPairArr.length];
            for (int i9 = 0; i9 < intPairArr.length; i9++) {
                SDVariable sDVariable2 = (SDVariable) hashMap.get(new Pair(Integer.valueOf(intPairArr[i9].first()), Integer.valueOf(intPairArr[i9].second())));
                if (sDVariable2 == null) {
                }
                strArr2[i9] = sDVariable2.getVarName();
            }
            create.ops.get(fromFlatNode.getOwnName()).setInputsToOp(Arrays.asList(strArr2));
            for (String str : strArr2) {
                Variable variable = create.getVariables().get(str);
                if (variable.getInputsForOp() == null) {
                    variable.setInputsForOp(new ArrayList());
                }
                if (!variable.getInputsForOp().contains(fromFlatNode.getOwnName())) {
                    variable.getInputsForOp().add(fromFlatNode.getOwnName());
                }
            }
            List list = (List) hashMap2.get(name2);
            int numOutputs = fromFlatNode.getNumOutputs();
            if (numOutputs <= 0) {
                numOutputs = flatNode.outputLength();
            }
            if (list == null || list.size() != numOutputs) {
                int outputNamesLength = flatNode.outputNamesLength();
                strArr = new String[outputNamesLength];
                for (int i10 = 0; i10 < outputNamesLength; i10++) {
                    String outputNames = flatNode.outputNames(i10);
                    strArr[i10] = outputNames;
                    if (!create.variables.containsKey(outputNames)) {
                        SDVariable sDVariable3 = new SDVariable(outputNames, VariableType.VARIABLE, create, null, null, null);
                        create.variables.put(outputNames, Variable.builder().name(outputNames).variable(sDVariable3).build());
                        hashMap.put(new Pair(Integer.valueOf(id2), Integer.valueOf(i10)), sDVariable3);
                    }
                    create.getVariables().get(strArr[i10]).setOutputOfOp(fromFlatNode.getOwnName());
                }
                create.ops.get(fromFlatNode.getOwnName()).setOutputsOfOp(Arrays.asList(strArr));
            } else {
                strArr = new String[list.size()];
                for (int i11 = 0; i11 < strArr.length; i11++) {
                    strArr[i11] = ((SDVariable) list.get(i11)).getVarName();
                    create.getVariables().get(strArr[i11]).setOutputOfOp(fromFlatNode.getOwnName());
                }
                create.ops.get(fromFlatNode.getOwnName()).setOutputsOfOp(Arrays.asList(strArr));
            }
            for (int i12 = 0; i12 < strArr.length; i12++) {
                Pair pair = new Pair(Integer.valueOf(id2), Integer.valueOf(i12));
                if (!hashMap.containsKey(pair)) {
                    hashMap.put(pair, create.getVariable(strArr[i12]));
                }
            }
        }
        if (rootAsFlatGraph.lossVariablesLength() > 0) {
            for (int i13 = 0; i13 < rootAsFlatGraph.lossVariablesLength(); i13++) {
                create.addLossVariable(rootAsFlatGraph.lossVariables(i13));
            }
        }
        return create;
    }

    /**
     * Builds a human-readable text dump of this graph as produced by {@link #asFlatBuffers()}.
     *
     * <p>The output has two sections. "External variables" lists each serialized variable:
     * id, name, shape-info buffer, and either a placeholder ("&lt;no array&gt;",
     * "&lt;empty array&gt;", "&lt;string array&gt;") or up to the first 50 values as floats.
     * "Ops sequence" lists each node: id, name, op type, op name (CUSTOM ops, resolved by hash)
     * or op number, the paired inputs, and the op number. Each node's id and name are also
     * logged at INFO level.
     *
     * @return multi-line description of the serialized graph
     */
    public String asFlatPrint() {
        final int maxValuesToPrint = 50;
        StringBuilder sb = new StringBuilder();
        FlatGraph fg = FlatGraph.getRootAsFlatGraph(asFlatBuffers());

        sb.append("\nExternal variables:\n\n");
        for (int i = 0; i < fg.variablesLength(); i++) {
            FlatVariable flatVar = fg.variables(i);
            INDArray arr = null;
            // Deserialize outside any active workspace. try-with-resources closes the scope
            // exactly once (and skips close() if null was returned). The previous hand-expanded
            // form also wrapped the printing code below, so an exception while printing would
            // have closed the already-closed scope a second time.
            try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                FlatArray fa = flatVar.ndarray();
                if (fa != null) {
                    arr = Nd4j.createFromFlatArray(fa);
                }
            }

            sb.append(flatVar.id().first()).append(":<").append(flatVar.name()).append("> ");
            if (arr == null) {
                sb.append("<no array>").append("; Values: ").append("<no array>").append(";\n");
                continue;
            }
            sb.append(Arrays.toString(arr.shapeInfoDataBuffer().asInt())).append("; Values: ");
            if (arr.data() == null) {
                sb.append("<empty array>");
            } else if (arr.dataType() == DataType.UTF8) {
                sb.append("<string array>");
            } else if (arr.length() < maxValuesToPrint) {
                // Literal replacement; no regex needed.
                sb.append(Arrays.toString(arr.data().asFloat()).replace(" ", ""));
            } else {
                sb.append("[");
                for (int j = 0; j < maxValuesToPrint; j++) {
                    if (j > 0) {
                        sb.append(",");
                    }
                    sb.append(arr.data().getFloat(j));
                }
                sb.append("]");
            }
            sb.append(";\n");
        }

        Map<String, CustomOpDescriptor> customOps = Nd4j.getExecutioner().getCustomOperations();
        // Reverse index hash -> op name, built once instead of rescanning all custom ops for
        // every CUSTOM node. Later entries overwrite earlier ones, which matches the old scan's
        // "last match wins" result when several names share a hash.
        Map<Long, String> customOpNameByHash = new HashMap<>();
        if (customOps != null) {
            for (Map.Entry<String, CustomOpDescriptor> e : customOps.entrySet()) {
                customOpNameByHash.put(e.getValue().getHash(), e.getKey());
            }
        }

        sb.append("\nOps sequence:\n\n");
        for (int i = 0; i < fg.nodesLength(); i++) {
            FlatNode node = fg.nodes(i);
            log.info("{}:<{}>", node.id(), node.name());
            Op.Type opType = FlatBuffersMapper.getTypeFromByte(node.opType());
            sb.append(node.id()).append(":<").append(node.name()).append("> ").append(opType);
            if (opType != Op.Type.CUSTOM) {
                sb.append(": ").append(node.opNum());
            } else {
                String opName = customOpNameByHash.get(node.opNum());
                sb.append(": ").append(opName == null ? "unknown" : opName);
            }

            sb.append("; Inputs: {");
            int numInputs = node.inputPairedLength();
            for (int j = 0; j < numInputs; j++) {
                IntPair in = node.inputPaired(j);
                sb.append("[").append(in.first()).append(":").append(in.second()).append("]");
                if (j < numInputs - 1) {
                    sb.append(", ");
                }
            }
            sb.append("};");
            sb.append(" OpNum: {").append(node.opNum()).append("};");
            sb.append("\n");
        }
        return sb.toString();
    }

    /**
     * Builds a multi-section, column-aligned text summary of this graph.
     *
     * <p>Sections:
     * <ul>
     *   <li>Summary: variable count (and how many have arrays), function count, number of
     *       SameDiff function instances, and the loss variables.</li>
     *   <li>Variables: name, array shape ("-" if none), variable type, data type, the op that
     *       outputs the variable ("&lt;none&gt;" if no op lists it as an output), and the ops it
     *       is input to.</li>
     *   <li>Functions: index, own name (or op name when own name is null), op class, inputs
     *       and outputs.</li>
     *   <li>SameDiff Defined Functions (only if any exist): name and sizes of each instance.</li>
     * </ul>
     *
     * @return the formatted summary
     */
    public String summary() {
        Map<String, SDVariable> varMap = variableMap();
        DifferentialFunction[] functions = functions();

        int numVarsWithArrays = 0;
        for (String name : varMap.keySet()) {
            if (getArrForVarName(name) != null) {
                numVarsWithArrays++;
            }
        }

        StringBuilder sb = new StringBuilder();
        sb.append("--- Summary ---\n");
        sb.append(String.format("%-25s%-20s", "Variables:", varMap.size()))
                .append(" (").append(numVarsWithArrays).append(" with arrays)").append("\n")
                .append(String.format("%-25s%-20s", "Functions:", functions.length)).append("\n")
                .append(String.format("%-25s%-20s", "SameDiff Function Defs:", this.sameDiffFunctionInstances.size())).append("\n")
                .append("Loss function variables: ").append(getLossVariables()).append("\n\n");

        sb.append("--- Variables ---\n");
        // Map each variable to the FIRST op (in op iteration order) listing it as an output.
        // Built once up front; the previous per-variable scan over every op was O(V * O).
        // containsKey (not putIfAbsent) keeps a first match even if that op's name is null.
        Map<String, String> producingOpByVar = new HashMap<>();
        for (SameDiffOp op : this.ops.values()) {
            List<String> outputs = op.getOutputsOfOp();
            if (outputs == null) {
                continue;
            }
            for (String out : outputs) {
                if (!producingOpByVar.containsKey(out)) {
                    producingOpByVar.put(out, op.getName());
                }
            }
        }

        Map<String, String> outputOfFnByVar = new HashMap<>();
        int outputOfFnWidth = 22;
        int nameWidth = 8;
        for (String name : varMap.keySet()) {
            String opName = producingOpByVar.get(name);
            String outputOfFn;
            if (opName == null) {
                outputOfFn = "<none>";
            } else {
                DifferentialFunction fn = getFunctionById(opName);
                outputOfFn = fn.getOwnName() + "(" + fn.opName() + ")";
            }
            outputOfFnByVar.put(name, outputOfFn);
            outputOfFnWidth = Math.max(outputOfFnWidth, outputOfFn.length());
            nameWidth = Math.max(nameWidth, name.length());
        }

        String varFormat = "%-" + (nameWidth + 2) + "s%-20s%-20s%-20s%-" + (outputOfFnWidth + 2) + "s%-20s";
        sb.append(String.format(varFormat, "- Name -", "- Array Shape -", "- Variable Type -", "- Data Type-",
                "- Output Of Function -", "- Inputs To Functions -")).append("\n");
        for (String name : varMap.keySet()) {
            INDArray arr = getArrForVarName(name);
            String shape = arr != null ? Arrays.toString(arr.shape()) : "-";
            SDVariable sdVar = getVariable(name);
            String varType = sdVar.getVariableType().toString();
            String dtype = sdVar.dataType().toString();
            List<String> inputsFor = this.variables.get(name).getInputsForOp();
            String inputsStr = inputsFor == null ? "" : inputsFor.toString();
            sb.append(String.format(varFormat, name, shape, varType, dtype, outputOfFnByVar.get(name), inputsStr))
                    .append("\n");
        }

        sb.append("\n\n--- Functions ---\n");
        List<String> fnInputs = new ArrayList<>(functions.length);
        List<String> fnOutputs = new ArrayList<>(functions.length);
        int inputsWidth = 10;
        int outputsWidth = 11;
        int fnNameWidth = 17;
        int opWidth = 10;
        for (DifferentialFunction df : functions) {
            String[] argNames = df.argNames();
            String[] outNames = df.outputVariablesNames();
            String inStr = Arrays.toString(argNames);
            String outStr = Arrays.toString(outNames);
            inputsWidth = Math.max(inputsWidth, inStr.length());
            outputsWidth = Math.max(outputsWidth, outStr.length());
            fnInputs.add(inStr);
            fnOutputs.add(outStr);
            String displayName = df.getOwnName() == null ? df.opName() : df.getOwnName();
            fnNameWidth = Math.max(fnNameWidth, displayName.length());
            opWidth = Math.max(opWidth, df.getClass().getSimpleName().length());
        }
        String fnFormat = "%-5s%-" + (fnNameWidth + 2) + "s%-" + (opWidth + 2) + "s%-" + (inputsWidth + 2)
                + "s%-" + (outputsWidth + 2) + "s";
        sb.append(String.format(fnFormat, "", "- Function Name -", "- Op -", "- Inputs -", "- Outputs -")).append("\n");
        for (int i = 0; i < functions.length; i++) {
            DifferentialFunction df = functions[i];
            String displayName = df.getOwnName() == null ? df.opName() : df.getOwnName();
            sb.append(String.format(fnFormat, String.valueOf(i), displayName, df.getClass().getSimpleName(),
                    fnInputs.get(i), fnOutputs.get(i))).append("\n");
        }

        if (this.sameDiffFunctionInstances.size() > 0) {
            sb.append("\n\n--- SameDiff Defined Functions ---\n");
            sb.append(String.format("%-20s%-15s%-15s%-15s", "- Name -", "- Variables -", "- Functions -", "- Fn Defs -"))
                    .append("\n");
            for (Map.Entry<String, SameDiff> entry : this.sameDiffFunctionInstances.entrySet()) {
                SameDiff sd = entry.getValue();
                sb.append(String.format("%-20s%-15s%-15s%-15s", entry.getKey(),
                        String.valueOf(sd.variableMap().size()),
                        String.valueOf(sd.functions() == null ? 0 : sd.functions().length),
                        String.valueOf(sd.definedFunctionNames().size()))).append("\n");
            }
        }
        return sb.toString();
    }

    /**
     * Infers the data type of every variable in the graph by running a {@link DataTypesSession}
     * over all variable names. Every placeholder variable's declared data type is used as the
     * seed value for that placeholder.
     *
     * @return map of variable name to inferred data type, as returned by the session
     * @throws RuntimeException (via {@code Preconditions.checkNotNull}) if any placeholder has a
     *         null data type
     */
    public Map<String, DataType> calculateOutputDataTypes() {
        List<String> allVariableNames = new ArrayList<>(variables.keySet());
        DataTypesSession session = new DataTypesSession(this);

        Map<String, DataType> placeholderTypes = new HashMap<>();
        for (Variable v : variables.values()) {
            SDVariable sdVar = v.getVariable();
            if (!sdVar.isPlaceHolder()) {
                continue;
            }
            DataType declared = sdVar.dataType();
            Preconditions.checkNotNull(declared, "Placeholder variable %s has null datatype", v.getName());
            placeholderTypes.put(v.getName(), declared);
        }
        return session.output(allVariableNames, placeholderTypes);
    }

    /**
     * Creates a new {@link SameDiffBuilder}.
     *
     * @return a freshly constructed builder; each call returns a new instance
     */
    public static SameDiffBuilder builder() {
        final SameDiffBuilder fresh = new SameDiffBuilder();
        return fresh;
    }

    /**
     * All-arguments constructor. Each field that has a matching parameter is assigned from that
     * argument. The remaining collection and helper fields (variables, ops, sessions,
     * constantArrays, variablesArrays, placeholdersPerThread, lossVariables, math, random, nn,
     * cnn, rnn, loss) are always freshly created for this instance.
     * NOTE(review): the shape of this signature suggests it is generated (e.g. an all-args
     * constructor backing {@link #builder()}); confirm before editing by hand.
     *
     * @param trainingConfig               stored as {@code trainingConfig}
     * @param z                            stored as {@code initializedTraining}
     * @param iNDArray                     stored as {@code updaterState}
     * @param map                          stored as {@code updaterViews}
     * @param map2                         stored as {@code updaterMap}
     * @param map3                         stored as {@code baseNameForFunctionInstanceId}
     * @param differentialFunctionFactory  stored as {@code functionFactory}
     * @param map4                         stored as {@code variableNameToShape}
     * @param map5                         stored as {@code forwardVarForGrad}
     * @param i                            stored as {@code variableId}
     * @param map6                         stored as {@code propertiesToResolve}
     * @param map7                         stored as {@code propertiesForFunction}
     * @param map8                         stored as {@code placeHolderOriginalShapes}
     * @param map9                         stored as {@code sameDiffFunctionDefinitionMap}
     * @param map10                        stored as {@code sameDiffFunctionInstances}
     * @param set                          stored as {@code placeHolderFunctions}
     * @param table                        stored as {@code fieldVariableResolutionMapping}
     * @param atomicBoolean                stored as {@code wasRegistered}
     * @param z2                           stored as {@code debugMode}
     * @param map11                        stored as {@code opsForResult}
     * @param z3                           stored as {@code resolvedVariables}
     * @param z4                           stored as {@code logExecution}
     * @param sameDiff                     stored as {@code parent}
     * @param sameDiff2                    stored as {@code child}
     */
    public SameDiff(TrainingConfig trainingConfig, boolean z, INDArray iNDArray, Map<String, INDArray> map, Map<String, GradientUpdater> map2, Map<String, String> map3, DifferentialFunctionFactory differentialFunctionFactory, Map<String, long[]> map4, Map<String, SDVariable> map5, int i, Map<String, List<String>> map6, Map<String, Map<String, Object>> map7, Map<String, long[]> map8, Map<String, SameDiffFunctionDefinition> map9, Map<String, SameDiff> map10, Set<String> set, Table<String, String, String> table, AtomicBoolean atomicBoolean, boolean z2, Map<int[], Op> map11, boolean z3, boolean z4, SameDiff sameDiff, SameDiff sameDiff2) {
        // Per-instance defaults (these look like inlined field initializers).
        // NOTE: variableId, wasRegistered, resolvedVariables and logExecution get defaults here
        // but are unconditionally overwritten from the arguments further down.
        this.variables = new LinkedHashMap();
        this.ops = new LinkedHashMap();
        this.sessions = new ConcurrentHashMap();
        this.constantArrays = new ConcurrentHashMap();
        this.variablesArrays = new ConcurrentHashMap();
        this.placeholdersPerThread = new ConcurrentHashMap();
        this.lossVariables = new ArrayList();
        this.variableId = 0;
        // Op namespaces receive a reference to this (still partially initialized) instance.
        this.math = new SDMath(this);
        this.random = new SDRandom(this);
        this.nn = new SDNN(this);
        this.cnn = new SDCNN(this);
        this.rnn = new SDRNN(this);
        this.loss = new SDLoss(this);
        this.wasRegistered = new AtomicBoolean(false);
        this.resolvedVariables = false;
        this.logExecution = true;
        // Argument-supplied state; references are stored as-is (no defensive copies).
        this.trainingConfig = trainingConfig;
        this.initializedTraining = z;
        this.updaterState = iNDArray;
        this.updaterViews = map;
        this.updaterMap = map2;
        this.baseNameForFunctionInstanceId = map3;
        this.functionFactory = differentialFunctionFactory;
        this.variableNameToShape = map4;
        this.forwardVarForGrad = map5;
        this.variableId = i;
        this.propertiesToResolve = map6;
        this.propertiesForFunction = map7;
        this.placeHolderOriginalShapes = map8;
        this.sameDiffFunctionDefinitionMap = map9;
        this.sameDiffFunctionInstances = map10;
        this.placeHolderFunctions = set;
        this.fieldVariableResolutionMapping = table;
        this.wasRegistered = atomicBoolean;
        this.debugMode = z2;
        this.opsForResult = map11;
        this.resolvedVariables = z3;
        this.logExecution = z4;
        this.parent = sameDiff;
        this.child = sameDiff2;
    }

    /**
     * Returns the backing map of variable name to {@link Variable} metadata.
     *
     * @return the live internal map (not a copy)
     */
    public Map<String, Variable> getVariables() {
        return variables;
    }

    /**
     * Returns the backing map of op name to {@link SameDiffOp}.
     *
     * @return the live internal map (not a copy)
     */
    public Map<String, SameDiffOp> getOps() {
        return ops;
    }

    /**
     * Returns the backing map of inference sessions.
     * NOTE(review): keys are presumably thread ids; confirm against where sessions are created.
     *
     * @return the live internal map (not a copy)
     */
    public Map<Long, InferenceSession> getSessions() {
        return sessions;
    }

    /**
     * Returns the training configuration currently attached to this instance.
     *
     * @return the stored training config (may be null if none was set)
     */
    public TrainingConfig getTrainingConfig() {
        return trainingConfig;
    }

    /**
     * Returns the value of the {@code initializedTraining} flag.
     *
     * @return the stored flag
     */
    public boolean isInitializedTraining() {
        return initializedTraining;
    }

    /**
     * Returns the stored updater state array.
     *
     * @return the stored array reference (not a copy; may be null)
     */
    public INDArray getUpdaterState() {
        return updaterState;
    }

    /**
     * Returns the stored map of name to updater view array.
     *
     * @return the stored map reference (not a copy; may be null)
     */
    public Map<String, INDArray> getUpdaterViews() {
        return updaterViews;
    }

    /**
     * Returns the stored map of name to {@link GradientUpdater}.
     *
     * @return the stored map reference (not a copy; may be null)
     */
    public Map<String, GradientUpdater> getUpdaterMap() {
        return updaterMap;
    }

    /**
     * Returns whether debug mode is enabled for this instance.
     *
     * @return the stored {@code debugMode} flag
     */
    public boolean isDebugMode() {
        return debugMode;
    }

    /**
     * Returns the current value of the {@code logExecution} flag.
     *
     * @return the stored flag
     */
    public boolean isLogExecution() {
        return logExecution;
    }

    /**
     * Sets the {@code logExecution} flag.
     *
     * @param enabled new value for the flag
     */
    public void setLogExecution(boolean enabled) {
        logExecution = enabled;
    }

    /**
     * Returns the stored parent {@link SameDiff} reference.
     *
     * @return the parent instance, or null if none was set
     */
    public SameDiff getParent() {
        return parent;
    }

    /**
     * Returns the stored child {@link SameDiff} reference.
     *
     * @return the child instance, or null if none was set
     */
    public SameDiff getChild() {
        return child;
    }

    // Registers every method declared directly on SameDiff whose return type is SDVariable in
    // opMethods, keyed by simple method name. getDeclaredMethods() has no defined order, so for
    // overloaded names whichever overload it reports last is the one kept.
    static {
        Method[] declared = SameDiff.class.getDeclaredMethods();
        for (int idx = 0; idx < declared.length; idx++) {
            Method candidate = declared[idx];
            if (SDVariable.class.equals(candidate.getReturnType())) {
                opMethods.put(candidate.getName(), candidate);
            }
        }
    }
}
