package cc.mallet.grmm.inference;

import cc.mallet.grmm.types.AbstractTableFactor;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.BitVarSet;
import cc.mallet.grmm.types.ConstantFactor;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.LogTableFactor;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.grmm.util.Graphs;
import cc.mallet.types.Alphabet;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;
import java.util.logging.Level;
import java.util.logging.Logger;
import org._3pq.jgrapht.GraphHelper;
import org._3pq.jgrapht.UndirectedGraph;
import org._3pq.jgrapht.alg.ConnectivityInspector;
import org._3pq.jgrapht.graph.ListenableUndirectedGraph;
import org._3pq.jgrapht.graph.SimpleGraph;
import org._3pq.jgrapht.traverse.BreadthFirstIterator;

/* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/grmm/inference/JunctionTreeInferencer.class */
/**
 * Exact inference via the junction tree algorithm.
 *
 * <p>The model's graph is triangulated with a greedy elimination heuristic
 * ({@link #pickVertexToRemove}). The resulting elimination cliques are joined into a
 * spanning tree that maximizes separator size (Kruskal's algorithm), and message passing
 * is delegated to a {@link JunctionTreePropagation}. The junction tree structure is
 * cached on the {@link FactorGraph} via its inference cache and reused on later calls.
 *
 * <p>Instances keep per-call state in fields ({@code jtCurrent}, {@code cliques}), so the
 * lookup methods refer to the most recent {@code computeMarginals} call.
 */
public class JunctionTreeInferencer extends AbstractInferencer {
    private static final Logger logger = MalletLogger.getLogger(JunctionTreeInferencer.class.getName());

    /**
     * Orders candidate sepsets (pairs of cliques) for Kruskal's algorithm.
     * Larger intersections come first; among equal intersections, smaller combined
     * weight comes first.
     *
     * <p>There is deliberately no further tie-break. Candidates are sorted with a
     * stable sort, so equal elements keep their insertion order and none are discarded.
     * (The former {@code hashCode()} tie-break let a {@code TreeSet} silently drop
     * distinct pairs whose identity hashes collided.)
     */
    private static final Comparator<BitVarSet[]> sepsetChooser = new Comparator<BitVarSet[]>() {
        @Override
        public int compare(BitVarSet[] a, BitVarSet[] b) {
            if (a == b) {
                return 0;
            }
            // Descending by separator size: larger sepsets are preferred.
            int result = cmp(sepsetSize(b), sepsetSize(a));
            if (result == 0) {
                // Ascending by cost: cheaper clique pairs are preferred.
                result = cmp(sepsetCost(a), sepsetCost(b));
            }
            return result;
        }
    };

    private static final long serialVersionUID = 1;

    /** True when the current factor graph's factors are {@link LogTableFactor}s. */
    private boolean inLogSpace;
    private JunctionTreePropagation propagator;
    protected transient JunctionTree jtCurrent;
    /** Maximal elimination cliques produced by the most recent triangulation. */
    private transient List<BitVarSet> cliques;
    private transient int totalMessagesSent;

    /** Creates a sum-product junction tree inferencer. */
    public JunctionTreeInferencer() {
        this(JunctionTreePropagation.createSumProductInferencer());
    }

    /**
     * Creates an inferencer that uses the given propagation strategy.
     *
     * @param junctionTreePropagation message-passing implementation run on the junction tree
     */
    public JunctionTreeInferencer(JunctionTreePropagation junctionTreePropagation) {
        this.totalMessagesSent = 0;
        this.propagator = junctionTreePropagation;
    }

    /** Creates a max-product (MAP) junction tree inferencer. */
    public static JunctionTreeInferencer createForMaxProduct() {
        return new JunctionTreeInferencer(JunctionTreePropagation.createMaxProductInferencer());
    }

    private boolean isAdjacent(UndirectedGraph graph, Variable a, Variable b) {
        return graph.getEdge(a, b) != null;
    }

    /**
     * Counts the fill-in edges that eliminating {@code v} would add, that is, the
     * unordered pairs of neighbors of {@code v} that are not yet adjacent.
     * The result is only used for comparisons between candidate vertices.
     */
    private int newEdgesRequired(UndirectedGraph graph, Variable v) {
        List neighbors = GraphHelper.neighborListOf(graph, v);
        int missing = 0;
        for (int i = 0; i < neighbors.size(); i++) {
            Variable a = (Variable) neighbors.get(i);
            for (int j = i + 1; j < neighbors.size(); j++) {
                if (!isAdjacent(graph, a, (Variable) neighbors.get(j))) {
                    missing++;
                }
            }
        }
        return missing;
    }

    /**
     * Returns the product of the outcome counts of {@code v}'s neighbors.
     * Uses {@code long} and saturates at {@link Long#MAX_VALUE} rather than wrapping,
     * because an overflowed (zero or negative) weight would mislead the elimination
     * heuristic into choosing the most expensive vertex.
     */
    private long weightRequired(UndirectedGraph graph, Variable v) {
        long weight = 1;
        Iterator it = neighborsIterator(graph, v);
        while (it.hasNext()) {
            int outcomes = ((Variable) it.next()).getNumOutcomes();
            if (outcomes > 0 && weight > Long.MAX_VALUE / outcomes) {
                return Long.MAX_VALUE;
            }
            weight *= outcomes;
        }
        return weight;
    }

    /** Adds an edge between every pair of non-adjacent neighbors of {@code v}. */
    private void connectNeighbors(UndirectedGraph graph, Variable v) {
        List neighbors = GraphHelper.neighborListOf(graph, v);
        for (int i = 0; i < neighbors.size(); i++) {
            Variable a = (Variable) neighbors.get(i);
            for (int j = i + 1; j < neighbors.size(); j++) {
                Variable b = (Variable) neighbors.get(j);
                if (!isAdjacent(graph, a, b)) {
                    try {
                        graph.addEdge(a, b);
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }
            }
        }
    }

    /** Returns true if some clique in {@code list} contains every variable of {@code varSet}. */
    private boolean findSuperClique(List list, VarSet varSet) {
        for (Object clique : list) {
            if (((VarSet) clique).containsAll(varSet)) {
                return true;
            }
        }
        return false;
    }

    /** Three-way integer comparison returning -1, 0 or 1. */
    public static int cmp(int i, int i2) {
        if (i < i2) {
            return -1;
        }
        return i > i2 ? 1 : 0;
    }

    /**
     * Greedy min-fill elimination heuristic. Returns the vertex in {@code arrayList} whose
     * elimination adds the fewest fill-in edges. Ties are broken by the smallest product of
     * neighbor outcome counts; on a complete tie the earliest vertex in the list wins.
     *
     * @throws java.util.NoSuchElementException if {@code arrayList} is empty
     */
    public Variable pickVertexToRemove(UndirectedGraph undirectedGraph, ArrayList arrayList) {
        Iterator it = arrayList.iterator();
        Variable best = (Variable) it.next();
        int bestNewEdges = newEdgesRequired(undirectedGraph, best);
        long bestWeight = weightRequired(undirectedGraph, best);
        while (it.hasNext()) {
            Variable candidate = (Variable) it.next();
            int newEdges = newEdgesRequired(undirectedGraph, candidate);
            if (newEdges < bestNewEdges) {
                best = candidate;
                bestNewEdges = newEdges;
                bestWeight = weightRequired(undirectedGraph, candidate);
            } else if (newEdges == bestNewEdges) {
                long weight = weightRequired(undirectedGraph, candidate);
                if (weight < bestWeight) {
                    best = candidate;
                    bestWeight = weight;
                }
            }
        }
        return best;
    }

    /**
     * Eliminates every vertex of (a copy of) {@code mdlGraph} and records each elimination
     * clique not already contained in a previously recorded one into {@link #cliques}.
     * The input graph itself is not modified.
     */
    private void triangulate(UndirectedGraph mdlGraph) {
        UndirectedGraph work = dupGraph(mdlGraph);
        ArrayList remaining = new ArrayList(mdlGraph.vertexSet());
        this.cliques = new ArrayList<BitVarSet>();

        if (logger.isLoggable(Level.FINER)) {
            logger.finer("Triangulating model: " + mdlGraph);
            StringBuilder vars = new StringBuilder();
            for (int i = 0; i < remaining.size(); i++) {
                vars.append(((Variable) remaining.get(i)).toString()).append("\n");
            }
            logger.finer(vars.toString());
        }

        while (!remaining.isEmpty()) {
            Variable v = pickVertexToRemove(work, remaining);
            if (logger.isLoggable(Level.FINER)) {
                logger.finer("Triangulating vertex " + v);
            }
            // The elimination clique is v together with its current neighbors.
            BitVarSet clique = new BitVarSet(v.getUniverse(), GraphHelper.neighborListOf(work, v));
            clique.add(v);
            if (!findSuperClique(this.cliques, clique)) {
                this.cliques.add(clique);
                if (logger.isLoggable(Level.FINER)) {
                    logger.finer("  Elim clique " + clique + " size " + clique.size() + " weight " + clique.weight());
                }
            }
            connectNeighbors(work, v);
            remaining.remove(v);
            work.removeVertex(v);
        }

        if (logger.isLoggable(Level.FINE)) {
            logCliqueStatistics();
        }
    }

    /**
     * Logs each clique (at FINER) and summary size and weight statistics (at FINE).
     * Totals are accumulated in {@code long} to avoid overflow on large models.
     */
    private void logCliqueStatistics() {
        logger.fine("Triangulation done. Cliques are: ");
        long totalSize = 0;
        int maxSize = 0;
        long totalWeight = 0;
        int maxWeight = 0;
        for (BitVarSet clique : this.cliques) {
            logger.finer(clique.toString());
            totalSize += clique.size();
            maxSize = Math.max(clique.size(), maxSize);
            totalWeight += clique.weight();
            maxWeight = Math.max(clique.weight(), maxWeight);
        }
        double count = this.cliques.size();
        logger.fine("Jt created " + count + " cliques. Size: avg " + (totalSize / count) + " max " + maxSize + " Weight: avg " + (totalWeight / count) + " max " + maxWeight);
    }

    /**
     * Builds an {@link Alphabet} over the given variables.
     * Not called by {@link #triangulate}, which discarded its result.
     */
    private Alphabet makeVertexMap(ArrayList arrayList) {
        Alphabet alphabet = new Alphabet(arrayList.size(), Variable.class);
        alphabet.lookupIndices(arrayList.toArray(), true);
        return alphabet;
    }

    /** Returns the number of variables shared by the two cliques of a candidate sepset. */
    public static int sepsetSize(BitVarSet[] bitVarSetArr) {
        assert bitVarSetArr.length == 2;
        return bitVarSetArr[0].intersectionSize(bitVarSetArr[1]);
    }

    /** Returns the combined weight of the two cliques of a candidate sepset. */
    public static int sepsetCost(VarSet[] varSetArr) {
        assert varSetArr.length == 2;
        return varSetArr[0].weight() + varSetArr[1].weight();
    }

    /**
     * Converts an undirected tree over cliques into a {@link JunctionTree}. The tree is
     * rooted at an arbitrary vertex and, in breadth-first order, every neighbor of a node
     * other than its parent is attached beneath it.
     */
    private JunctionTree graphToJt(UndirectedGraph undirectedGraph) {
        JunctionTree junctionTree = new JunctionTree(undirectedGraph.vertexSet().size());
        Object root = undirectedGraph.vertexSet().iterator().next();
        junctionTree.add(root);
        BreadthFirstIterator bfs = new BreadthFirstIterator(undirectedGraph, root);
        while (bfs.hasNext()) {
            Object node = bfs.next();
            for (Object neighbor : GraphHelper.neighborListOf(undirectedGraph, node)) {
                if (junctionTree.getParent(node) != neighbor) {
                    junctionTree.addNode(node, neighbor);
                }
            }
        }
        return junctionTree;
    }

    /**
     * Connects {@link #cliques} into a spanning tree with Kruskal's algorithm.
     *
     * <p>Every clique pair is a candidate edge, ordered by {@link #sepsetChooser}. An edge
     * is accepted when its endpoints are not yet connected. A stably sorted list is used
     * instead of a {@code TreeSet}: the set treated comparator-equal pairs as duplicates
     * and could discard an edge the maximum spanning tree required.
     */
    private JunctionTree buildJtStructure() {
        int n = this.cliques.size();
        List<BitVarSet[]> candidates = new ArrayList<BitVarSet[]>();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < i; j++) {
                candidates.add(new BitVarSet[]{this.cliques.get(i), this.cliques.get(j)});
            }
        }
        Collections.sort(candidates, sepsetChooser);

        ListenableUndirectedGraph tree = new ListenableUndirectedGraph(new SimpleGraph());
        for (BitVarSet clique : this.cliques) {
            tree.addVertex(clique);
        }
        ConnectivityInspector connectivity = new ConnectivityInspector(tree);
        tree.addGraphListener(connectivity);

        int edgesNeeded = n - 1;
        int edgesAdded = 0;
        for (BitVarSet[] pair : candidates) {
            if (edgesAdded >= edgesNeeded) {
                break;
            }
            if (!connectivity.pathExists(pair[0], pair[1])) {
                tree.addEdge(pair[0], pair[1]);
                edgesAdded++;
            }
        }

        JunctionTree jt = graphToJt(tree);
        if (logger.isLoggable(Level.FINER)) {
            logger.finer("  jt structure was " + jt);
        }
        return jt;
    }

    /**
     * Resets every cluster potential to the constant 1, then multiplies each factor of
     * {@code factorGraph} into the cluster returned by {@code findParentCluster}.
     */
    private void initJtCpts(FactorGraph factorGraph, JunctionTree junctionTree) {
        Iterator vertices = junctionTree.getVerticesIterator();
        while (vertices.hasNext()) {
            junctionTree.setCPF((VarSet) vertices.next(), new ConstantFactor(1.0d));
        }
        for (Factor factor : factorGraph.factors()) {
            VarSet parent = junctionTree.findParentCluster(factor.varSet());
            assert parent != null : "Unable to find parent cluster for ptl " + factor + " in jt " + junctionTree;
            junctionTree.setCPF(parent, junctionTree.getCPF(parent).multiply(factor));
        }
    }

    /**
     * Creates an empty table factor in the current space. Not called within this class
     * as written.
     */
    private AbstractTableFactor createBlankFactor(VarSet varSet) {
        return this.inLogSpace ? new LogTableFactor(varSet) : new TableFactor(varSet);
    }

    /**
     * Builds (or reuses the cached) junction tree for {@code factorGraph} and runs the
     * propagator on it. Log space is inferred from the type of the graph's first factor.
     */
    @Override
    public void computeMarginals(FactorGraph factorGraph) {
        this.inLogSpace = factorGraph.getFactor(0) instanceof LogTableFactor;
        buildJunctionTree(factorGraph);
        this.propagator.computeMarginals(this.jtCurrent);
        this.totalMessagesSent += this.propagator.getTotalMessagesSent();
    }

    /** Runs the propagator directly on an already-built junction tree. */
    public void computeMarginals(JunctionTree junctionTree) {
        this.inLogSpace = false;
        this.jtCurrent = junctionTree;
        this.propagator.computeMarginals(this.jtCurrent);
        this.totalMessagesSent += this.propagator.getTotalMessagesSent();
    }

    /**
     * Returns a junction tree for {@code factorGraph} with potentials initialized from its
     * factors. The structure is taken from the graph's inference cache when present (its
     * CPFs are cleared first); otherwise it is built and stored in that cache.
     */
    public JunctionTree buildJunctionTree(FactorGraph factorGraph) {
        this.jtCurrent = (JunctionTree) factorGraph.getInferenceCache(JunctionTreeInferencer.class);
        if (this.jtCurrent != null) {
            this.jtCurrent.clearCPFs();
        } else {
            triangulate(Graphs.mdlToGraph(factorGraph));
            this.jtCurrent = buildJtStructure();
            factorGraph.setInferenceCache(JunctionTreeInferencer.class, this.jtCurrent);
        }
        initJtCpts(factorGraph, this.jtCurrent);
        return this.jtCurrent;
    }

    /** Returns a mutable {@link SimpleGraph} copy of {@code undirectedGraph}. */
    private UndirectedGraph dupGraph(UndirectedGraph undirectedGraph) {
        SimpleGraph copy = new SimpleGraph();
        GraphHelper.addGraph(copy, undirectedGraph);
        return copy;
    }

    @Override
    public Factor lookupMarginal(Variable variable) {
        return this.propagator.lookupMarginal(this.jtCurrent, variable);
    }

    @Override
    public Factor lookupMarginal(VarSet varSet) {
        return this.propagator.lookupMarginal(this.jtCurrent, varSet);
    }

    @Override
    public double lookupLogJoint(Assignment assignment) {
        return this.jtCurrent.lookupLogJoint(assignment);
    }

    public double dumpLogJoint(Assignment assignment) {
        return this.jtCurrent.dumpLogJoint(assignment);
    }

    /** Returns the junction tree from the most recent computation, or null if none. */
    public JunctionTree lookupJunctionTree() {
        return this.jtCurrent;
    }

    private Iterator neighborsIterator(UndirectedGraph undirectedGraph, Variable variable) {
        return GraphHelper.neighborListOf(undirectedGraph, variable).iterator();
    }

    @Override
    public void dump() {
        if (this.jtCurrent == null) {
            System.out.println("NO current junction tree");
        } else {
            System.out.println("Current junction tree");
            this.jtCurrent.dump();
        }
    }

    /** Returns the cumulative number of messages sent across all computeMarginals calls. */
    public int getTotalMessagesSent() {
        return this.totalMessagesSent;
    }

    // Default serialization; transient per-call state (jtCurrent, cliques,
    // totalMessagesSent) is not restored.
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.defaultWriteObject();
    }

    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
    }
}
