package org.tensorflow;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.tensorflow.Session;

/* loaded from: input_file:WEB-INF/lib/tensorflow-api-0.19.0.jar:org/tensorflow/SessionFunction.class */
/**
 * A {@link TensorFunction} that evaluates a {@link Signature} by running it in a {@link Session}.
 *
 * <p>Calling the function feeds each signature input into the session, fetches each signature
 * output, runs the session, and returns the fetched tensors keyed by output name.
 */
public class SessionFunction implements TensorFunction {
    private final Signature signature;
    private final Session session;

    /**
     * Creates a function that runs {@code signature} in {@code session}.
     *
     * <p>Every input and every output description of the signature is checked against the
     * session's graph with {@link TensorFunction#validateDescription}. Any exception it raises
     * propagates out of this constructor.
     *
     * @param signature the signature describing the function's inputs and outputs
     * @param session the session whose graph the signature's tensors belong to
     */
    public SessionFunction(Signature signature, Session session) {
        this.signature = signature;
        this.session = session;
        signature.getInputs().forEach((name, description) ->
                TensorFunction.validateDescription(description, session.graph(), name, "Input"));
        // Outputs must be validated against the graph as well; validating the inputs a second
        // time under the "Output" label would leave the output descriptions unchecked.
        signature.getOutputs().forEach((name, description) ->
                TensorFunction.validateDescription(description, session.graph(), name, "Output"));
    }

    /**
     * Static factory equivalent to {@link #SessionFunction(Signature, Session)}.
     *
     * @param signature the signature describing the function's inputs and outputs
     * @param session the session to run the signature in
     * @return a new {@code SessionFunction}
     */
    public static SessionFunction create(Signature signature, Session session) {
        return new SessionFunction(signature, session);
    }

    /**
     * Exports this function through {@link SavedModelBundle#exporter}.
     *
     * @param str the location passed to {@code SavedModelBundle.exporter} (presumably the export
     *     directory; confirm against {@code SavedModelBundle})
     * @throws IOException if the export fails
     */
    public void save(String str) throws IOException {
        SavedModelBundle.exporter(str).withFunction(this).export();
    }

    /** Returns the signature this function evaluates. */
    @Override // org.tensorflow.TensorFunction
    public Signature signature() {
        return this.signature;
    }

    /** Returns the session this function runs in. */
    public Session session() {
        return this.session;
    }

    /**
     * Returns a new function with the same signature bound to {@code session}.
     *
     * <p>This goes through the constructor, so the signature is re-validated against the new
     * session's graph.
     *
     * @param session the session the new function should run in
     * @return a new {@code SessionFunction}
     */
    public SessionFunction withNewSession(Session session) {
        return new SessionFunction(this.signature, session);
    }

    /**
     * Runs the function.
     *
     * <p>Each signature input is fed from {@code map}. Entries of {@code map} that do not name a
     * signature input are ignored.
     *
     * @param map arguments keyed by signature input name
     * @return the fetched tensors keyed by output name, in the iteration order of
     *     {@code signature().getOutputs()}. NOTE(review): the caller presumably owns (and must
     *     close) the returned tensors; confirm.
     * @throws IllegalArgumentException if a signature input has no entry in {@code map}, or if
     *     its entry is {@code null}
     */
    @Override // org.tensorflow.TensorFunction
    public Map<String, Tensor> call(Map<String, Tensor> map) {
        Session.Runner runner = this.session.runner();
        this.signature.getInputs().forEach((name, description) -> {
            if (!map.containsKey(name)) {
                throw new IllegalArgumentException("No argument found for parameter \"" + name + "\"");
            }
            Tensor tensor = map.get(name);
            if (tensor == null) {
                throw new IllegalArgumentException("Can't pass null as an argument to a function.  Argument \"" + name + "\" was null.");
            }
            runner.feed(description.name, tensor);
        });

        // Record each output name in the same pass that issues its fetch. The i-th result is then
        // guaranteed to belong to the i-th recorded name, regardless of how other views of the
        // signature (e.g. outputNames()) happen to be ordered.
        List<String> fetchedNames = new ArrayList<>();
        this.signature.getOutputs().forEach((name, description) -> {
            fetchedNames.add(name);
            runner.fetch(description.name);
        });

        List<Tensor> results = runner.run();
        Map<String, Tensor> outputs = new LinkedHashMap<>(results.size());
        for (int i = 0; i < fetchedNames.size(); i++) {
            outputs.put(fetchedNames.get(i), results.get(i));
        }
        return outputs;
    }
}
