package org.nd4j.autodiff.samediff.internal;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.AbstractSession;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;

/* loaded from: input_file:org/nd4j/autodiff/samediff/internal/DataTypesSession.class */
public class DataTypesSession extends AbstractSession<DataType, DataTypeCalc> {

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:org/nd4j/autodiff/samediff/internal/DataTypesSession$DataTypeCalc.class */
    public static class DataTypeCalc {
        protected final DifferentialFunction fn;
        protected final List<DataType> inputTypes;

        public DataTypeCalc(DifferentialFunction differentialFunction, List<DataType> list) {
            this.fn = differentialFunction;
            this.inputTypes = list;
        }

        public DifferentialFunction getFn() {
            return this.fn;
        }

        public List<DataType> getInputTypes() {
            return this.inputTypes;
        }

        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof DataTypeCalc)) {
                return false;
            }
            DataTypeCalc dataTypeCalc = (DataTypeCalc) obj;
            if (!dataTypeCalc.canEqual(this)) {
                return false;
            }
            DifferentialFunction fn = getFn();
            DifferentialFunction fn2 = dataTypeCalc.getFn();
            if (fn == null) {
                if (fn2 != null) {
                    return false;
                }
            } else if (!fn.equals(fn2)) {
                return false;
            }
            List<DataType> inputTypes = getInputTypes();
            List<DataType> inputTypes2 = dataTypeCalc.getInputTypes();
            return inputTypes == null ? inputTypes2 == null : inputTypes.equals(inputTypes2);
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof DataTypeCalc;
        }

        public int hashCode() {
            DifferentialFunction fn = getFn();
            int hashCode = (1 * 59) + (fn == null ? 43 : fn.hashCode());
            List<DataType> inputTypes = getInputTypes();
            return (hashCode * 59) + (inputTypes == null ? 43 : inputTypes.hashCode());
        }

        public String toString() {
            return "DataTypesSession.DataTypeCalc(fn=" + getFn() + ", inputTypes=" + getInputTypes() + ")";
        }
    }

    public DataTypesSession(SameDiff sameDiff) {
        super(sameDiff);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    public DataType getConstantOrVariable(String str) {
        DataType dataType = this.sameDiff.getVariable(str).dataType();
        Preconditions.checkNotNull(dataType, "No datatype available for variable %s", str);
        return dataType;
    }

    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    public DataTypeCalc getAndParameterizeOp(String str, AbstractSession.FrameIter frameIter, Set<AbstractSession.VarId> set, Set<AbstractSession.VarId> set2, Set<String> set3, Map<String, DataType> map) {
        DifferentialFunction functionById = this.sameDiff.getFunctionById(str);
        ArrayList arrayList = new ArrayList();
        for (SDVariable sDVariable : functionById.args()) {
            DataType dataType = sDVariable.dataType();
            if (dataType != null) {
                arrayList.add(dataType);
            } else {
                String varName = sDVariable.getVarName();
                for (AbstractSession.VarId varId : set) {
                    if (varId.getVariable().equals(varName)) {
                        DataType dataType2 = (DataType) this.nodeOutputs.get(varId);
                        Preconditions.checkNotNull(dataType2, "No datatype for %s", varId);
                        arrayList.add(dataType2);
                    }
                }
            }
        }
        return new DataTypeCalc(functionById, arrayList);
    }

    /* renamed from: getOutputs, reason: avoid collision after fix types in other method */
    public DataType[] getOutputs2(DataTypeCalc dataTypeCalc, AbstractSession.FrameIter frameIter, Set<AbstractSession.VarId> set, Set<AbstractSession.VarId> set2, Set<String> set3) {
        List<DataType> calculateOutputDataTypes = dataTypeCalc.getFn().calculateOutputDataTypes(dataTypeCalc.getInputTypes());
        return (DataType[]) calculateOutputDataTypes.toArray(new DataType[calculateOutputDataTypes.size()]);
    }

    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    public /* bridge */ /* synthetic */ DataType[] getOutputs(DataTypeCalc dataTypeCalc, AbstractSession.FrameIter frameIter, Set set, Set set2, Set set3) {
        return getOutputs2(dataTypeCalc, frameIter, (Set<AbstractSession.VarId>) set, (Set<AbstractSession.VarId>) set2, (Set<String>) set3);
    }

    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    public /* bridge */ /* synthetic */ DataTypeCalc getAndParameterizeOp(String str, AbstractSession.FrameIter frameIter, Set set, Set set2, Set set3, Map<String, DataType> map) {
        return getAndParameterizeOp(str, frameIter, (Set<AbstractSession.VarId>) set, (Set<AbstractSession.VarId>) set2, (Set<String>) set3, map);
    }
}
