package org.tensorflow;

import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.Graph;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.internal.c_api.TF_Buffer;
import org.tensorflow.internal.c_api.TF_Graph;
import org.tensorflow.internal.c_api.TF_Operation;
import org.tensorflow.internal.c_api.TF_Output;
import org.tensorflow.internal.c_api.TF_Session;
import org.tensorflow.internal.c_api.TF_SessionOptions;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.TF_Tensor;
import org.tensorflow.internal.c_api.global.tensorflow;
import org.tensorflow.op.Op;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.RunMetadata;
import org.tensorflow.proto.framework.RunOptions;
import org.tensorflow.proto.util.SaverDef;
import org.tensorflow.types.TString;

/* loaded from: input_file:org/tensorflow/Session.class */
public final class Session implements AutoCloseable {
    /** Graph whose operations this session runs; also used to look operations up by name. */
    private final Graph graph;
    /** Reference on {@link #graph} taken in the constructors and released by {@link #close()}. */
    private final Graph.Reference graphRef;
    /**
     * Monitor held while {@link #nativeHandle} is checked or cleared and while
     * {@link #numActiveRuns} is updated; {@link #close()} also waits on it.
     */
    private final Object nativeHandleLock;
    /** Native session handle; set to {@code null} by {@link #close()} after it has been deleted. */
    private TF_Session nativeHandle;
    /** Count of runs currently in flight; {@link #close()} blocks until it drops to zero. */
    private int numActiveRuns;

    /* loaded from: input_file:org/tensorflow/Session$Run.class */
    /** Plain holder for the results of {@link Runner#runAndFetchMetadata()}. */
    public static final class Run {
        /** Tensors produced for the fetched outputs, in the order the fetches were added. */
        public List<Tensor> outputs;
        /** Metadata parsed from the native run, or {@code null} when metadata was not requested. */
        public RunMetadata metadata;
    }

    /* loaded from: input_file:org/tensorflow/Session$Runner.class */
    public final class Runner {
        /** Graph outputs to be replaced by fed tensors; parallel to {@link #inputTensors}. */
        private final ArrayList<Output<?>> inputs = new ArrayList<>();
        /** Tensors fed for the corresponding entries of {@link #inputs}. */
        private final ArrayList<Tensor> inputTensors = new ArrayList<>();
        /** Graph outputs whose values are fetched by the run. */
        private final ArrayList<Output<?>> outputs = new ArrayList<>();
        /** Operations passed to the run as targets. */
        private final ArrayList<GraphOperation> targets = new ArrayList<>();
        /** Optional run options forwarded to the native call; {@code null} until set. */
        private RunOptions runOptions = null;

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:org/tensorflow/Session$Runner$Reference.class */
        /**
         * Marks one in-flight run on the enclosing session. Creating it bumps the session's
         * active-run counter; closing it decrements the counter and wakes any thread waiting
         * in {@code Session.close()} once the counter reaches zero.
         */
        public class Reference implements AutoCloseable {
            /**
             * Registers a new active run.
             *
             * @throws IllegalStateException if the session's native handle is already gone
             */
            public Reference() {
                synchronized (Session.this.nativeHandleLock) {
                    TF_Session handle = Session.this.nativeHandle;
                    boolean closed = handle == null || handle.isNull();
                    if (closed) {
                        throw new IllegalStateException("run() cannot be called on the Session after close()");
                    }
                    ++Session.this.numActiveRuns;
                }
            }

            /** Unregisters the run; does nothing when the session handle is already gone. */
            @Override
            public void close() {
                synchronized (Session.this.nativeHandleLock) {
                    TF_Session handle = Session.this.nativeHandle;
                    boolean open = handle != null && !handle.isNull();
                    // The counter is only touched while the session is still open.
                    if (open && --Session.this.numActiveRuns == 0) {
                        Session.this.nativeHandleLock.notifyAll();
                    }
                }
            }
        }

        /** Creates a runner with no feeds, fetches, targets or run options. */
        public Runner() {
        }

        /**
         * Feeds {@code tensor} for the output described by {@code str}, which is either an
         * operation name or {@code "name:index"}.
         *
         * @return this runner, for chaining
         */
        public Runner feed(String str, Tensor tensor) {
            Output<?> resolved = parseOutput(str);
            return feed(resolved, tensor);
        }

        /**
         * Feeds {@code tensor} for output number {@code i} of the operation named {@code str}.
         *
         * @return this runner, for chaining
         * @throws IllegalArgumentException if the graph has no operation with that name
         */
        public Runner feed(String str, int i, Tensor tensor) {
            // operationByName throws rather than returning null, so the result is always usable.
            GraphOperation operation = operationByName(str);
            inputs.add(operation.output(i));
            inputTensors.add(tensor);
            return this;
        }

        /**
         * Feeds {@code tensor} for the output that {@code operand} resolves to.
         *
         * @return this runner, for chaining
         */
        public Runner feed(Operand<?> operand, Tensor tensor) {
            Output<?> fedOutput = operand.asOutput();
            inputs.add(fedOutput);
            inputTensors.add(tensor);
            return this;
        }

        /**
         * Requests the value of the output described by {@code str}, which is either an
         * operation name or {@code "name:index"}.
         *
         * @return this runner, for chaining
         */
        public Runner fetch(String str) {
            Output<?> resolved = parseOutput(str);
            return fetch(resolved);
        }

        /**
         * Requests the value of output number {@code i} of the operation named {@code str}.
         *
         * @return this runner, for chaining
         * @throws IllegalArgumentException if the graph has no operation with that name
         */
        public Runner fetch(String str, int i) {
            // operationByName throws rather than returning null, so the result is always usable.
            GraphOperation operation = operationByName(str);
            outputs.add(operation.output(i));
            return this;
        }

        /**
         * Requests the value of {@code output}.
         *
         * @return this runner, for chaining
         */
        public Runner fetch(Output<?> output) {
            outputs.add(output);
            return this;
        }

        /**
         * Requests the value of the output that {@code operand} resolves to.
         *
         * @return this runner, for chaining
         */
        public Runner fetch(Operand<?> operand) {
            Output<?> resolved = operand.asOutput();
            return fetch(resolved);
        }

        /**
         * Adds the operation named {@code str} as a target of the run.
         *
         * @return this runner, for chaining
         * @throws IllegalArgumentException if the graph has no operation with that name
         */
        public Runner addTarget(String str) {
            // operationByName throws rather than returning null, so the result is always usable.
            targets.add(operationByName(str));
            return this;
        }

        /**
         * Adds {@code operation} as a target of the run.
         *
         * @return this runner, for chaining
         * @throws IllegalArgumentException if {@code operation} is not a {@code GraphOperation}
         */
        public Runner addTarget(Operation operation) {
            if (operation instanceof GraphOperation) {
                targets.add((GraphOperation) operation);
                return this;
            }
            throw new IllegalArgumentException("Operation of type " + operation.getClass().getName() + " is not supported in graph sessions");
        }

        /**
         * Adds the operation behind {@code op} as a target of the run.
         *
         * @return this runner, for chaining
         */
        public Runner addTarget(Op op) {
            Operation underlying = op.op();
            return addTarget(underlying);
        }

        /**
         * Stores the run options that will be passed to the native run call.
         *
         * @return this runner, for chaining
         */
        public Runner setOptions(RunOptions runOptions) {
            this.runOptions = runOptions;
            return this;
        }

        /**
         * Executes the run without collecting metadata.
         *
         * @return the fetched tensors
         */
        public List<Tensor> run() {
            Run result = runHelper(false);
            return result.outputs;
        }

        /**
         * Executes the run and also collects run metadata.
         *
         * @return the fetched tensors together with the parsed metadata
         */
        public Run runAndFetchMetadata() {
            return runHelper(true);
        }

        /**
         * Collects the native handles for all feeds, fetches and targets, performs the native
         * run while holding a {@link Reference} on the session, and packages the result.
         *
         * <p>If the native run throws an {@code Exception}, every tensor already placed in the
         * output list is closed before the exception is rethrown.
         *
         * @param z whether run metadata should be requested
         * @return the fetched tensors and, when requested, the run metadata
         */
        private Run runHelper(boolean z) {
            final int feedCount = this.inputs.size();
            final int fetchCount = this.outputs.size();
            final int targetCount = this.targets.size();

            TF_Tensor[] feedTensorHandles = new TF_Tensor[this.inputTensors.size()];
            TF_Operation[] feedOps = new TF_Operation[feedCount];
            int[] feedIndices = new int[feedCount];
            TF_Operation[] fetchOps = new TF_Operation[fetchCount];
            int[] fetchIndices = new int[fetchCount];
            TF_Operation[] targetOps = new TF_Operation[targetCount];

            int pos = 0;
            for (Tensor fed : this.inputTensors) {
                feedTensorHandles[pos++] = fed.asRawTensor().nativeHandle();
            }
            for (pos = 0; pos < feedCount; pos++) {
                Output<?> feed = this.inputs.get(pos);
                feedOps[pos] = (TF_Operation) feed.getUnsafeNativeHandle();
                feedIndices[pos] = feed.index();
            }
            for (pos = 0; pos < fetchCount; pos++) {
                Output<?> fetch = this.outputs.get(pos);
                fetchOps[pos] = (TF_Operation) fetch.getUnsafeNativeHandle();
                fetchIndices[pos] = fetch.index();
            }
            for (pos = 0; pos < targetCount; pos++) {
                targetOps[pos] = this.targets.get(pos).getUnsafeNativeHandle();
            }

            // Registering the run makes Session.close() wait until it has finished.
            Reference runRef = new Reference();
            List<Tensor> fetched = new ArrayList<>();
            final RunMetadata metadata;
            try {
                metadata = Session.run(Session.this.nativeHandle, this.runOptions, feedTensorHandles, feedOps, feedIndices, fetchOps, fetchIndices, targetOps, z, fetched);
            } catch (Exception e) {
                // Do not hand partially collected tensors to anyone; release them here.
                for (Tensor partial : fetched) {
                    partial.close();
                }
                fetched.clear();
                throw e;
            } finally {
                runRef.close();
            }

            Run result = new Run();
            result.outputs = fetched;
            result.metadata = metadata;
            return result;
        }

        /**
         * Looks up an operation of the session's graph by name.
         *
         * @return the operation; never {@code null}
         * @throws IllegalArgumentException if the graph has no operation with that name
         */
        private GraphOperation operationByName(String str) {
            GraphOperation found = Session.this.graph.operation(str);
            if (found != null) {
                return found;
            }
            throw new IllegalArgumentException("No Operation named [" + str + "] in the Graph");
        }

        /**
         * Resolves {@code "name"} or {@code "name:index"} to an output of the graph.
         *
         * <p>When there is no colon, the colon is the last character, or the text after the
         * last colon is not an integer, the whole string is treated as an operation name and
         * output 0 is used.
         *
         * @throws IllegalArgumentException if the referenced operation does not exist
         */
        private Output<?> parseOutput(String str) {
            int colon = str.lastIndexOf(':');
            boolean hasSuffix = colon != -1 && colon != str.length() - 1;
            if (hasSuffix) {
                try {
                    // The operation lookup must happen before the index is parsed so that a
                    // missing operation is reported for the prefix, as before.
                    GraphOperation operation = operationByName(str.substring(0, colon));
                    int index = Integer.parseInt(str.substring(colon + 1));
                    return new Output<>(operation, index);
                } catch (NumberFormatException notAnIndex) {
                    // Suffix is not numeric: fall through and use the full string as the name.
                }
            }
            return new Output<>(operationByName(str), 0);
        }
    }

    /** Creates a session for {@code graph} without a session configuration. */
    public Session(Graph graph) {
        // The cast selects the (Graph, ConfigProto) overload over (Graph, TF_Session).
        this(graph, (ConfigProto) null);
    }

    /**
     * Creates a session for {@code graph}, optionally configured by {@code configProto}.
     *
     * <p>A temporary graph reference is held while the native session is allocated; a second,
     * long-lived reference is stored and released by {@link #close()}.
     *
     * @param configProto session configuration, or {@code null} for none
     */
    public Session(Graph graph, ConfigProto configProto) {
        this.nativeHandleLock = new Object();
        this.graph = graph;
        Graph.Reference allocationRef = graph.ref();
        try {
            this.nativeHandle = allocate(allocationRef.nativeHandle(), null, configProto);
            this.graphRef = graph.ref();
        } finally {
            allocationRef.close();
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Wraps an already created native session for {@code graph} and takes a reference on the
     * graph that {@link #close()} releases.
     */
    public Session(Graph graph, TF_Session tF_Session) {
        this.nativeHandleLock = new Object();
        this.graph = graph;
        this.nativeHandle = tF_Session;
        this.graphRef = graph.ref();
    }

    /**
     * Releases the graph reference, waits for all active runs to finish and then deletes the
     * native session. Calling it when the native handle is already gone only releases the
     * graph reference again.
     *
     * <p>If the calling thread is interrupted while waiting, the interrupt flag is restored
     * and the method returns without deleting the native session.
     */
    @Override
    public void close() {
        graphRef.close();
        synchronized (nativeHandleLock) {
            if (nativeHandle == null || nativeHandle.isNull()) {
                return;
            }
            try {
                while (numActiveRuns > 0) {
                    nativeHandleLock.wait();
                }
            } catch (InterruptedException interrupted) {
                Thread.currentThread().interrupt();
                return;
            }
            delete(nativeHandle);
            nativeHandle = null;
        }
    }

    /** Returns a new, empty {@link Runner} bound to this session. */
    public Runner runner() {
        return new Runner();
    }

    /**
     * Runs the operation named {@code str} as the only target; the result list is discarded.
     *
     * @throws IllegalArgumentException if the graph has no operation with that name
     */
    public void run(String str) {
        GraphOperation target = graph.operation(str);
        if (target == null) {
            throw new IllegalArgumentException("Operation named '" + str + "' cannot be found in the graph");
        }
        Runner single = runner();
        single.addTarget(target);
        single.run();
    }

    /** Runs the operation behind {@code op} as the only target; the result list is discarded. */
    public void run(Op op) {
        Runner single = runner();
        single.addTarget(op.op());
        single.run();
    }

    /** Runs every initializer registered on the graph as a target of one single run. */
    public void runInit() {
        Runner initRunner = runner();
        for (Op initializer : graph.initializers()) {
            initRunner.addTarget(initializer);
        }
        initRunner.run();
    }

    /**
     * Runs the graph's save operation, feeding {@code str} as the filename tensor.
     *
     * <p>The tensor created for the filename is closed once the run has finished (or failed)
     * instead of being left for garbage collection.
     *
     * @param str value fed to the saver's filename tensor
     */
    public void save(String str) {
        SaverDef saverDef = this.graph.saverDef();
        Tensor filenameTensor = TString.scalarOf(str);
        try {
            runner()
                .addTarget(saverDef.getSaveTensorName())
                .feed(saverDef.getFilenameTensorName(), filenameTensor)
                .run();
        } finally {
            // The run is synchronous, so the fed tensor is no longer needed at this point.
            filenameTensor.close();
        }
    }

    /**
     * Runs the graph's restore operation, feeding {@code str} as the filename tensor.
     *
     * <p>The tensor created for the filename is closed once the run has finished (or failed)
     * instead of being left for garbage collection.
     *
     * @param str value fed to the saver's filename tensor
     */
    public void restore(String str) {
        SaverDef saverDef = this.graph.saverDef();
        Tensor filenameTensor = TString.scalarOf(str);
        try {
            runner()
                .addTarget(saverDef.getRestoreOpName())
                .feed(saverDef.getFilenameTensorName(), filenameTensor)
                .run();
        } finally {
            // The run is synchronous, so the fed tensor is no longer needed at this point.
            filenameTensor.close();
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /** Returns the graph this session was created for. */
    public Graph graph() {
        return this.graph;
    }

    /**
     * Verifies that {@code pointer} is a usable native handle.
     *
     * @throws IllegalStateException if {@code pointer} is {@code null} or a null pointer
     */
    private static void requireHandle(Pointer pointer) {
        boolean usable = pointer != null && !pointer.isNull();
        if (!usable) {
            throw new IllegalStateException("close() has been called on the Session");
        }
    }

    /**
     * Copies {@code pointerArr} into {@code pointerPointer}, validating each handle on the way.
     *
     * @param str human-readable description of the handles, used in error messages
     * @param i expected number of handles
     * @throws IllegalArgumentException if the array length differs from {@code i}
     * @throws IllegalStateException if any handle is {@code null} or a null pointer
     */
    private static void resolveHandles(String str, Pointer[] pointerArr, PointerPointer pointerPointer, int i) {
        final int actual = pointerArr.length;
        if (actual != i) {
            throw new IllegalArgumentException("expected " + i + ", got " + actual + " " + str);
        }
        int slot = 0;
        for (Pointer handle : pointerArr) {
            boolean usable = handle != null && !handle.isNull();
            if (!usable) {
                throw new IllegalStateException("invalid " + str + " (#" + slot + " of " + i + ")");
            }
            pointerPointer.put(slot, handle);
            slot++;
        }
    }

    /**
     * Creates a native session for {@code tF_Graph}.
     *
     * <p>All temporary native objects are created inside a {@code PointerScope} that is closed
     * on every exit path. The returned session has {@code retainReference()} called on it
     * before the scope closes (presumably so the handle outlives the scope; confirm against
     * the JavaCPP documentation).
     *
     * @param tF_Graph native graph handle
     * @param str unused by this method
     * @param configProto optional session configuration; ignored when {@code null}
     * @return the new native session handle
     * @throws IllegalStateException if {@code tF_Graph} is {@code null} or a null pointer
     */
    private static TF_Session allocate(TF_Graph tF_Graph, String str, ConfigProto configProto) {
        if (tF_Graph == null || tF_Graph.isNull()) {
            throw new IllegalStateException("Graph has been close()d");
        }
        // try-with-resources closes the scope on success and failure alike, and records a
        // failure of close() as a suppressed exception instead of masking the primary one.
        try (PointerScope scope = new PointerScope()) {
            TF_Status status = TF_Status.newStatus();
            TF_SessionOptions options = TF_SessionOptions.newSessionOptions();
            if (configProto != null) {
                BytePointer serializedConfig = new BytePointer(configProto.toByteArray());
                tensorflow.TF_SetConfig(options, serializedConfig, serializedConfig.capacity(), status);
                status.throwExceptionIfNotOK();
            }
            TF_Session session = TF_Session.newSession(tF_Graph, options, status);
            status.throwExceptionIfNotOK();
            return (TF_Session) session.retainReference();
        }
    }

    /**
     * Releases the reference held on the native session.
     *
     * @throws IllegalStateException if {@code tF_Session} is {@code null} or a null pointer
     */
    private static void delete(TF_Session tF_Session) {
        requireHandle(tF_Session);
        tF_Session.releaseReference();
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Performs one native session run.
     *
     * <p>All temporary native objects are created inside a {@code PointerScope} that is closed
     * on every exit path.
     *
     * @param tF_Session native session handle
     * @param runOptions options forwarded to the native call; may be {@code null}
     * @param tF_TensorArr handles of the fed tensors
     * @param tF_OperationArr operations owning the fed outputs
     * @param iArr output indices of the fed outputs
     * @param tF_OperationArr2 operations owning the fetched outputs
     * @param iArr2 output indices of the fetched outputs
     * @param tF_OperationArr3 target operations
     * @param z whether run metadata should be collected
     * @param list receives one tensor per fetched output, in fetch order
     * @return the parsed run metadata, or {@code null} when {@code z} is {@code false}
     * @throws IllegalStateException if the session handle or any fed/target handle is invalid
     * @throws TensorFlowException if the returned metadata cannot be parsed
     */
    public static RunMetadata run(TF_Session tF_Session, RunOptions runOptions, TF_Tensor[] tF_TensorArr, TF_Operation[] tF_OperationArr, int[] iArr, TF_Operation[] tF_OperationArr2, int[] iArr2, TF_Operation[] tF_OperationArr3, boolean z, List<Tensor> list) {
        requireHandle(tF_Session);
        int numInputs = tF_TensorArr.length;
        int numOutputs = tF_OperationArr2.length;
        int numTargets = tF_OperationArr3.length;
        // try-with-resources closes the scope on success and failure alike, and records a
        // failure of close() as a suppressed exception instead of masking the primary one.
        try (PointerScope scope = new PointerScope()) {
            TF_Output inputs = new TF_Output(numInputs);
            PointerPointer inputValues = new PointerPointer(numInputs);
            TF_Output outputs = new TF_Output(numOutputs);
            PointerPointer outputValues = new PointerPointer(numOutputs);
            PointerPointer targets = new PointerPointer(numTargets);
            TF_Buffer metadataBuffer = z ? TF_Buffer.newBuffer() : null;

            resolveHandles("input Tensors", tF_TensorArr, inputValues, numInputs);
            Graph.resolveOutputs("input", tF_OperationArr, iArr, inputs, numInputs);
            Graph.resolveOutputs("output", tF_OperationArr2, iArr2, outputs, numOutputs);
            resolveHandles("target Operations", tF_OperationArr3, targets, numTargets);

            TF_Status status = TF_Status.newStatus();
            tensorflow.TF_SessionRun(tF_Session, TF_Buffer.newBufferFromString((Message) runOptions), inputs, inputValues, numInputs, outputs, outputValues, numOutputs, targets, numTargets, metadataBuffer, status);
            status.throwExceptionIfNotOK();

            // Wrap every produced native tensor and hand it to the caller's list.
            for (int i = 0; i < numOutputs; i++) {
                TF_Tensor produced = (TF_Tensor) outputValues.get(TF_Tensor.class, i);
                list.add(RawTensor.fromHandle(produced.withDeallocator()).asTypedTensor());
            }

            if (metadataBuffer == null) {
                return null;
            }
            try {
                return RunMetadata.parseFrom(metadataBuffer.dataAsByteBuffer());
            } catch (InvalidProtocolBufferException e) {
                throw new TensorFlowException("Cannot parse RunMetadata protocol buffer", e);
            }
        }
    }

    /** Compiler-generated accessor: pre-increments {@code numActiveRuns} and returns the new value. */
    static /* synthetic */ int access$304(Session session) {
        return ++session.numActiveRuns;
    }

    /** Compiler-generated accessor: pre-decrements {@code numActiveRuns} and returns the new value. */
    static /* synthetic */ int access$306(Session session) {
        return --session.numActiveRuns;
    }
}
