package org.tensorflow;

import com.google.protobuf.InvalidProtocolBufferException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.internal.c_api.TF_Buffer;
import org.tensorflow.internal.c_api.TF_Graph;
import org.tensorflow.internal.c_api.TF_Session;
import org.tensorflow.internal.c_api.TF_SessionOptions;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.global.tensorflow;
import org.tensorflow.proto.framework.CollectionDef;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.RunOptions;
import org.tensorflow.proto.framework.SavedModel;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.proto.util.SaverDef;

/* loaded from: input_file:WEB-INF/lib/tensorflow-api-0.19.0.jar:org/tensorflow/SavedModelBundle.class */
/**
 * A model loaded from, or exported to, a SavedModel directory.
 *
 * <p>A loaded bundle owns a {@link Graph} and a {@link Session}, keeps the {@link MetaGraphDef}
 * read from disk, and exposes one {@link SessionFunction} per signature found in that meta-graph.
 * Closing the bundle closes the session and then the graph.
 */
public class SavedModelBundle implements AutoCloseable {
    /** Tag used by {@link Loader} and {@link Exporter} when no tags are supplied explicitly. */
    public static final String DEFAULT_TAG = "serve";
    /** Signature key under which the exporter records the Java-side init op. */
    private static final String JAVA_INIT_OP_SIGNATURE_KEY = "__saved_model_java_init_op_tracker";
    /** Signature key checked first when looking for the model's init op on load. */
    private static final String INIT_OP_SIGNATURE_KEY = "__saved_model_init_op";
    /** Collection key checked for the init op when no init-op signature exists. */
    private static final String MAIN_OP_COLLECTION_KEY = "saved_model_main_op";
    /** Fallback collection key for the init op when the main-op collection is absent. */
    private static final String LEGACY_INIT_OP_COLLECTION_KEY = "legacy_init_op";
    /** Collection key whose node names are registered as additional init ops on load. */
    private static final String TABLE_INITIALIZERS_COLLECTION_KEY = "table_initializer";

    private final Graph graph;
    private final Session session;
    private final MetaGraphDef metaGraphDef;
    /** Callable functions keyed by signature key, all bound to {@link #session}. */
    private final Map<String, SessionFunction> functions;

    /**
     * Builder that writes a session's graph, variables and function signatures as a SavedModel.
     *
     * <p>Every function added to one exporter must belong to the same {@link Session}.
     */
    public static final class Exporter {
        private final String exportDir;
        private Session session;
        private String[] tags = {SavedModelBundle.DEFAULT_TAG};
        private final MetaGraphDef.Builder metaGraphDefBuilder = MetaGraphDef.newBuilder();
        /** Functions to export, in insertion order, keyed by signature key. */
        private final Map<String, SessionFunction> functions = new LinkedHashMap<>();

        /**
         * Sets the tags that identify the exported meta-graph.
         *
         * @param tags non-null tags replacing the default {@link SavedModelBundle#DEFAULT_TAG}
         * @return this exporter
         * @throws IllegalArgumentException if {@code tags} is null or contains a null element
         */
        public Exporter withTags(String... tags) {
            SavedModelBundle.validateTags(tags);
            // Copy so later mutation of the caller's array cannot change what gets exported.
            this.tags = tags.clone();
            return this;
        }

        /**
         * Sets the session whose graph and variables will be exported.
         *
         * @param session the session to export
         * @return this exporter
         * @throws IllegalStateException if a different session was already set
         */
        public Exporter withSession(Session session) {
            if (this.session != null && this.session != session) {
                throw new IllegalStateException("This exporter already has a session that differs from the passed session");
            }
            this.session = session;
            return this;
        }

        /**
         * Adds a function to the export; its session becomes the exporter's session.
         *
         * @param function the function to export
         * @return this exporter
         * @throws IllegalArgumentException if a function with the same signature key was already added
         * @throws UnsupportedOperationException if the function belongs to a different session
         */
        public Exporter withFunction(SessionFunction function) {
            Signature signature = function.signature();
            String key = signature.key();
            if (functions.containsKey(key)) {
                throw new IllegalArgumentException("Function \"" + key + "\" was already added to the model");
            }
            if (session != null && session != function.session()) {
                throw new UnsupportedOperationException("This exporter already has a session that differs from the passed function's session");
            }
            session = function.session();
            functions.put(key, function);
            metaGraphDefBuilder.putSignatureDef(key, signature.asSignatureDef());
            return this;
        }

        /**
         * Adds each function in order, as if by {@link #withFunction}.
         *
         * @param functions the functions to export
         * @return this exporter
         */
        public Exporter withFunctions(SessionFunction... functions) {
            for (SessionFunction function : functions) {
                withFunction(function);
            }
            return this;
        }

        /**
         * Wraps {@code signature} in a function bound to the current session and adds it.
         *
         * @param signature the signature to export
         * @return this exporter
         * @throws IllegalStateException if no session has been set yet
         */
        public Exporter withSignature(Signature signature) {
            if (session == null) {
                throw new IllegalStateException("Session has not been set yet, you must call withSession or withFunction first.");
            }
            return withFunction(session.function(signature));
        }

        /**
         * Adds each signature in order, as if by {@link #withSignature}.
         *
         * @param signatures the signatures to export
         * @return this exporter
         */
        public Exporter withSignatures(Signature... signatures) {
            for (Signature signature : signatures) {
                withSignature(signature);
            }
            return this;
        }

        /**
         * Writes variables under {@code <exportDir>/variables} and the meta-graph to
         * {@code <exportDir>/saved_model.pb}.
         *
         * @throws IllegalStateException if no function was added
         * @throws IOException if a directory cannot be created or the model file cannot be written
         */
        public void export() throws IOException {
            if (functions.isEmpty()) {
                throw new IllegalStateException("Model should contain at least one valid function");
            }
            Graph graph = session.graph();
            SaverDef saverDef = graph.saverDef();

            // Unless the caller already supplied one, add an init op and track it in a reserved
            // signature so the loader can find and register it again.
            boolean trackJavaInitOp = !functions.containsKey(JAVA_INIT_OP_SIGNATURE_KEY);
            GraphOperation javaInitOp = null;
            if (trackJavaInitOp) {
                javaInitOp = graph.addInitOp(true);
            }

            // The graph is serialized after the init op was added so the op is part of the GraphDef.
            MetaGraphDef.Builder metaGraph = metaGraphDefBuilder
                    .setSaverDef(saverDef)
                    .setGraphDef(graph.toGraphDef())
                    .setMetaInfoDef(MetaGraphDef.MetaInfoDef.newBuilder().addAllTags(Arrays.asList(tags)));
            functions.forEach((key, function) -> metaGraph.putSignatureDef(key, function.signature().asSignatureDef()));
            if (trackJavaInitOp) {
                TensorInfo initOutput = TensorInfo.newBuilder().setName(javaInitOp.name() + ":0").build();
                metaGraph.putSignatureDef(JAVA_INIT_OP_SIGNATURE_KEY,
                        SignatureDef.newBuilder().putOutputs(JAVA_INIT_OP_SIGNATURE_KEY, initOutput).build());
            }

            Path variablesDir = Paths.get(exportDir, "variables");
            // Unlike File.mkdirs(), this reports failure instead of silently returning false.
            Files.createDirectories(variablesDir);
            session.save(variablesDir.resolve("variables").toString());

            SavedModel savedModel = SavedModel.newBuilder().addMetaGraphs(metaGraph).build();
            // try-with-resources keeps a write failure as the primary exception if close() also fails.
            try (FileOutputStream out = new FileOutputStream(Paths.get(exportDir, "saved_model.pb").toString())) {
                savedModel.writeTo(out);
            }
        }

        Exporter(String exportDir) {
            this.exportDir = exportDir;
        }
    }

    /** Builder collecting the options used to load a SavedModel from a directory. */
    public static final class Loader {
        private final String exportDir;
        private String[] tags = {SavedModelBundle.DEFAULT_TAG};
        private ConfigProto configProto;
        private RunOptions runOptions;

        /**
         * Loads the model with the options configured so far.
         *
         * @return the loaded bundle; the caller is responsible for closing it
         */
        public SavedModelBundle load() {
            return SavedModelBundle.load(exportDir, tags, configProto, runOptions);
        }

        /**
         * Sets the run options passed to the native loader; {@code null} means none.
         *
         * @param runOptions the run options
         * @return this loader
         */
        public Loader withRunOptions(RunOptions runOptions) {
            this.runOptions = runOptions;
            return this;
        }

        /**
         * Sets the session configuration; {@code null} means it is not applied.
         *
         * @param configProto the session configuration
         * @return this loader
         */
        public Loader withConfigProto(ConfigProto configProto) {
            this.configProto = configProto;
            return this;
        }

        /**
         * Sets the tags selecting the meta-graph to load.
         *
         * @param tags non-null tags replacing the default {@link SavedModelBundle#DEFAULT_TAG}
         * @return this loader
         * @throws IllegalArgumentException if {@code tags} is null or contains a null element
         */
        public Loader withTags(String... tags) {
            SavedModelBundle.validateTags(tags);
            // Copy so later mutation of the caller's array cannot change what gets loaded.
            this.tags = tags.clone();
            return this;
        }

        private Loader(String exportDir) {
            this.exportDir = exportDir;
        }
    }

    /**
     * Loads a SavedModel using the given tags, or {@link #DEFAULT_TAG} if none are given.
     *
     * @param exportDir directory containing the SavedModel
     * @param tags tags selecting the meta-graph; null or empty means the default tag
     * @return the loaded bundle; the caller is responsible for closing it
     */
    public static SavedModelBundle load(String exportDir, String... tags) {
        Loader loader = loader(exportDir);
        if (tags != null && tags.length > 0) {
            loader.withTags(tags);
        }
        return loader.load();
    }

    /**
     * Returns a new loader for the SavedModel in {@code exportDir}.
     *
     * @param exportDir directory containing the SavedModel
     * @return a loader with default options
     */
    public static Loader loader(String exportDir) {
        return new Loader(exportDir);
    }

    /**
     * Returns a new exporter writing to {@code exportDir}.
     *
     * @param exportDir destination directory
     * @return an exporter with default tags and no functions
     */
    public static Exporter exporter(String exportDir) {
        return new Exporter(exportDir);
    }

    /** Returns the meta-graph this bundle was loaded from. */
    public MetaGraphDef metaGraphDef() {
        return metaGraphDef;
    }

    /** Returns the graph owned by this bundle. */
    public Graph graph() {
        return graph;
    }

    /** Returns the session owned by this bundle. */
    public Session session() {
        return session;
    }

    /**
     * Returns the signatures of this model, excluding the internal init-op signatures.
     *
     * @return a new list of signatures
     */
    public List<Signature> signatures() {
        return functions.values().stream()
                .map(SessionFunction::signature)
                .filter(signature -> !INIT_OP_SIGNATURE_KEY.equals(signature.key())
                        && !JAVA_INIT_OP_SIGNATURE_KEY.equals(signature.key()))
                .collect(Collectors.toList());
    }

    /**
     * Returns the function registered under the given signature key.
     *
     * @param signatureKey the signature key
     * @return the matching function
     * @throws IllegalArgumentException if no function has that key
     */
    public SessionFunction function(String signatureKey) {
        SessionFunction function = functions.get(signatureKey);
        if (function == null) {
            throw new IllegalArgumentException(String.format("Function with signature [%s] not found", signatureKey));
        }
        return function;
    }

    /** Returns a new list containing every function of this model, including init-op signatures. */
    public List<SessionFunction> functions() {
        return new ArrayList<>(functions.values());
    }

    /**
     * Calls the model's default function: the only function if there is exactly one, otherwise
     * the function registered under {@link Signature#DEFAULT_KEY}.
     *
     * @param arguments input tensors keyed by input name
     * @return the function's outputs
     * @throws IllegalArgumentException if no default function can be chosen
     */
    public Map<String, Tensor> call(Map<String, Tensor> arguments) {
        SessionFunction function = functions.size() == 1
                ? functions.values().iterator().next()
                : functions.get(Signature.DEFAULT_KEY);
        if (function == null) {
            throw new IllegalArgumentException("Cannot elect a default function for this model");
        }
        return function.call(arguments);
    }

    /** Closes the session, then the graph; the graph is closed even if closing the session fails. */
    @Override
    public void close() {
        try {
            session.close();
        } finally {
            graph.close();
        }
    }

    private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef, Map<String, Signature> signatures) {
        this.graph = graph;
        this.session = session;
        this.metaGraphDef = metaGraphDef;
        this.functions = signatures.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, entry -> new SessionFunction(entry.getValue(), session)));
    }

    /**
     * Locates the model's init op. The init-op signature is checked first. If it is absent, the
     * main-op collection is checked, falling back to the legacy init-op collection.
     *
     * @return the init op, or {@code null} if none is declared
     * @throws IllegalArgumentException if the chosen collection does not hold exactly one node
     */
    private static GraphOperation findInitOp(Graph graph, Map<String, Signature> signatures, Map<String, CollectionDef> collections) {
        Signature initSignature = signatures.get(INIT_OP_SIGNATURE_KEY);
        if (initSignature != null) {
            return (GraphOperation) graph.outputOrThrow(initSignature.getOutputs().get(INIT_OP_SIGNATURE_KEY).name).op();
        }
        CollectionDef initCollection = collections.containsKey(MAIN_OP_COLLECTION_KEY)
                ? collections.get(MAIN_OP_COLLECTION_KEY)
                : collections.get(LEGACY_INIT_OP_COLLECTION_KEY);
        if (initCollection == null) {
            return null;
        }
        CollectionDef.NodeList nodes = initCollection.getNodeList();
        if (nodes.getValueCount() != 1) {
            throw new IllegalArgumentException("Expected exactly one main op in saved model.");
        }
        return (GraphOperation) graph.outputOrThrow(nodes.getValue(0)).op();
    }

    /**
     * Wraps native graph and session handles into a bundle.
     *
     * <p>Builds one signature per signature key and registers any declared init ops, the Java
     * init-op tracker, and table initializers with the graph.
     */
    private static SavedModelBundle fromHandle(TF_Graph graphHandle, TF_Session sessionHandle, MetaGraphDef metaGraphDef) {
        Graph graph = new Graph(graphHandle, metaGraphDef.getSaverDef());
        Session session = new Session(graph, sessionHandle);

        // Proto map keys are unique, so each signature key yields exactly one Signature.
        Map<String, Signature> signatures = new HashMap<>(metaGraphDef.getSignatureDefCount());
        metaGraphDef.getSignatureDefMap().forEach((key, signatureDef) -> signatures.put(key, new Signature(key, signatureDef)));

        GraphOperation initOp = findInitOp(graph, signatures, metaGraphDef.getCollectionDefMap());
        if (initOp != null) {
            graph.registerInitOp(initOp);
        }
        Signature javaInitSignature = signatures.get(JAVA_INIT_OP_SIGNATURE_KEY);
        if (javaInitSignature != null) {
            graph.registerInitOp(graph.outputOrThrow(javaInitSignature.getOutputs().get(JAVA_INIT_OP_SIGNATURE_KEY).name).op());
        }
        session.setInitialized();
        if (metaGraphDef.containsCollectionDef(TABLE_INITIALIZERS_COLLECTION_KEY)) {
            for (String opName : metaGraphDef.getCollectionDefMap().get(TABLE_INITIALIZERS_COLLECTION_KEY).getNodeList().getValueList()) {
                graph.registerInitOp(graph.operationOrThrow(opName));
            }
        }
        return new SavedModelBundle(graph, session, metaGraphDef, signatures);
    }

    /**
     * Loads a SavedModel through the native C API and initializes its session. This is the
     * internal implementation behind {@link Loader#load()}.
     *
     * <p>NOTE(review): the decompiled source marks this method as originally private. It is left
     * public here so existing callers keep compiling.
     *
     * @param exportDir directory containing the SavedModel
     * @param tags tags selecting the meta-graph
     * @param configProto session configuration, or {@code null} to skip setting one
     * @param runOptions run options passed to the native loader, possibly {@code null}
     * @return the loaded, initialized bundle
     * @throws TensorFlowException if the MetaGraphDef returned by the loader cannot be parsed
     */
    public static SavedModelBundle load(String exportDir, String[] tags, ConfigProto configProto, RunOptions runOptions) {
        // Native allocations made inside this scope are released when it closes, unless retained.
        try (PointerScope scope = new PointerScope()) {
            TF_Status status = TF_Status.newStatus();
            TF_SessionOptions sessionOptions = TF_SessionOptions.newSessionOptions();
            if (configProto != null) {
                BytePointer configBytes = new BytePointer(configProto.toByteArray());
                tensorflow.TF_SetConfig(sessionOptions, configBytes, configBytes.capacity(), status);
                status.throwExceptionIfNotOK();
            }
            TF_Buffer runOptionsBuffer = TF_Buffer.newBufferFromString(runOptions);

            TF_Graph graphHandle = tensorflow.TF_NewGraph();
            TF_Buffer metaGraphBuffer = TF_Buffer.newBuffer();
            TF_Session sessionHandle = TF_Session.loadSessionFromSavedModel(
                    sessionOptions, runOptionsBuffer, exportDir, tags, graphHandle, metaGraphBuffer, status);
            status.throwExceptionIfNotOK();

            SavedModelBundle bundle;
            try {
                bundle = fromHandle(graphHandle, sessionHandle, MetaGraphDef.parseFrom(metaGraphBuffer.dataAsByteBuffer()));
            } catch (InvalidProtocolBufferException e) {
                throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e);
            }
            // Retain the handles only once the bundle exists; otherwise the scope frees them.
            graphHandle.retainReference();
            sessionHandle.retainReference();
            try {
                bundle.session.initialize();
            } catch (RuntimeException e) {
                // The handles are retained now; close the bundle so they do not leak.
                try {
                    bundle.close();
                } catch (RuntimeException closeFailure) {
                    e.addSuppressed(closeFailure);
                }
                throw e;
            }
            return bundle;
        }
    }

    /**
     * Rejects a null tag array or one containing a null element.
     *
     * <p>NOTE(review): the decompiled source marks this method as originally private. It is left
     * public here so existing callers keep compiling.
     *
     * @throws IllegalArgumentException if {@code tags} is null or contains a null element
     */
    public static void validateTags(String[] tags) {
        if (tags == null || Arrays.stream(tags).anyMatch(Objects::isNull)) {
            throw new IllegalArgumentException("Invalid tags: " + Arrays.toString(tags));
        }
    }

    // Loads org.tensorflow.TensorFlow before this class is used. Presumably this is so that its
    // static initialization (e.g. native library loading — TODO confirm) runs first.
    static {
        try {
            Class.forName("org.tensorflow.TensorFlow");
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }
}
