package cc.mallet.grmm.util;

import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.inference.JunctionTree;
import cc.mallet.grmm.inference.JunctionTreeInferencer;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.ConstantFactor;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.UndirectedModel;
import gnu.trove.THashSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;

/* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/grmm/util/Models.class */
/**
 * Static helpers for {@link FactorGraph} models:
 * - slicing every factor by observed evidence;
 * - decoding an assignment from per-variable marginals;
 * - computing entropy and KL divergence with exact junction-tree inference.
 */
public class Models {

    /**
     * Builds a new factor graph whose factors are the factors of {@code factorGraph},
     * each sliced by {@code assignment}. The input graph is not modified.
     *
     * @param factorGraph the source model
     * @param assignment  the evidence passed to {@link Factor#slice} for every factor
     * @return a newly created {@link FactorGraph} holding the sliced factors
     */
    public static FactorGraph addEvidence(FactorGraph factorGraph, Assignment assignment) {
        return addEvidence(factorGraph, assignment, null);
    }

    /**
     * Builds a new factor graph whose factors are the factors of {@code factorGraph},
     * each sliced by {@code assignment}, and optionally records which sliced factor
     * came from which original factor.
     *
     * @param factorGraph the source model; not modified
     * @param assignment  the evidence passed to {@link Factor#slice} for every factor
     * @param map         if non-null, receives one entry per original factor, mapping
     *                    the original factor to its sliced counterpart
     * @return a newly created {@link FactorGraph} holding the sliced factors
     */
    public static FactorGraph addEvidence(FactorGraph factorGraph, Assignment assignment, Map map) {
        FactorGraph sliced = new FactorGraph(factorGraph.numVariables());
        addSlicedPotentials(factorGraph, sliced, assignment, map);
        return sliced;
    }

    /**
     * {@link UndirectedModel} variant of {@link #addEvidence(FactorGraph, Assignment)}:
     * the result is a new {@link UndirectedModel} holding every factor of
     * {@code undirectedModel} sliced by {@code assignment}.
     *
     * @param undirectedModel the source model; not modified
     * @param assignment      the evidence passed to {@link Factor#slice} for every factor
     * @return a newly created {@link UndirectedModel} holding the sliced factors
     */
    public static UndirectedModel addEvidence(UndirectedModel undirectedModel, Assignment assignment) {
        UndirectedModel sliced = new UndirectedModel(undirectedModel.numVariables());
        addSlicedPotentials(undirectedModel, sliced, assignment, null);
        return sliced;
    }

    /**
     * Adds {@code factor.slice(assignment)} to {@code target} for every factor of
     * {@code source}. If {@code map} is non-null, it records original -> sliced.
     */
    private static void addSlicedPotentials(FactorGraph source, FactorGraph target,
                                            Assignment assignment, Map map) {
        for (Iterator it = source.factorsIterator(); it.hasNext(); ) {
            Factor factor = (Factor) it.next();
            Factor slice = factor.slice(assignment);
            target.addFactor(slice);
            if (map != null) {
                map.put(factor, slice);
            }
        }
    }

    /**
     * Runs {@code inferencer} on {@code factorGraph}. Each variable is then set to the
     * argmax of its own single-variable marginal.
     *
     * <p>Variables are decoded independently. Whether the result is a joint MAP assignment
     * therefore depends on the kind of marginals the inferencer produces (for example,
     * max-product versus sum-product).
     *
     * @param factorGraph the model to decode
     * @param inferencer  the inferencer used to compute marginals (it is run on {@code factorGraph})
     * @return an assignment giving, for variable index {@code i}, the argmax of its marginal
     */
    public static Assignment bestAssignment(FactorGraph factorGraph, Inferencer inferencer) {
        inferencer.computeMarginals(factorGraph);
        int[] values = new int[factorGraph.numVariables()];
        for (int i = 0; i < values.length; i++) {
            values[i] = inferencer.lookupMarginal(factorGraph.get(i)).argmax();
        }
        return new Assignment(factorGraph, values);
    }

    /**
     * Computes the entropy of the model's distribution using exact junction-tree inference.
     *
     * @param factorGraph the model
     * @return the entropy reported by the calibrated junction tree
     */
    public static double entropy(FactorGraph factorGraph) {
        JunctionTreeInferencer inferencer = new JunctionTreeInferencer();
        inferencer.computeMarginals(factorGraph);
        return inferencer.lookupJunctionTree().entropy();
    }

    /**
     * Computes KL(p || q) using exact junction-tree inference on both models:
     * <pre>
     *   KL(p || q) = -H(p) - E_p[log q(x)]
     * </pre>
     * Here p is defined by {@code factorGraph} and q by {@code factorGraph2}. The term
     * log q(x) is expanded over q's junction tree as the sum of cluster log-potentials
     * minus the sum of sepset log-potentials.
     *
     * <p>This expansion assumes that, after {@code computeMarginals}, q's cluster and sepset
     * potentials are its calibrated marginals — TODO confirm against JunctionTree.
     *
     * @param factorGraph  the model defining p
     * @param factorGraph2 the model defining q
     * @return the divergence KL(p || q)
     */
    public static double KL(FactorGraph factorGraph, FactorGraph factorGraph2) {
        JunctionTreeInferencer pInferencer = new JunctionTreeInferencer();
        pInferencer.computeMarginals(factorGraph);
        JunctionTree pTree = pInferencer.lookupJunctionTree();

        JunctionTreeInferencer qInferencer = new JunctionTreeInferencer();
        qInferencer.computeMarginals(factorGraph2);
        JunctionTree qTree = qInferencer.lookupJunctionTree();

        double entropyP = pTree.entropy();
        // E_p[log q] = sum over clusters - sum over separators (junction-tree factorization).
        double expectedLogQ = expectedLogValue(qTree.clusterPotentials(), pInferencer)
                - expectedLogValue(qTree.sepsetPotentials(), pInferencer);
        return -entropyP - expectedLogQ;
    }

    /**
     * Returns the sum, over all potentials and all assignments x of each potential's
     * variables, of p(x) * log potential(x). Here p(x) is the marginal that
     * {@code pInferencer} reports over that potential's variable set.
     *
     * <p>Accepts raw or generic collections, so elements are cast to {@link Factor} explicitly.
     */
    private static double expectedLogValue(Iterable<?> potentials, JunctionTreeInferencer pInferencer) {
        double total = 0.0d;
        for (Object element : potentials) {
            Factor potential = (Factor) element;
            Factor marginal = pInferencer.lookupMarginal(potential.varSet());
            // NOTE(review): one iterator indexes both factors. This presumes the marginal
            // enumerates assignments in the same order as the potential — confirm.
            for (AssignmentIterator it = potential.assignmentIterator(); it.hasNext(); it.advance()) {
                total += marginal.value(it) * potential.logValue(it);
            }
        }
        return total;
    }

    /**
     * Divides every {@link ConstantFactor} out of {@code factorGraph}, modifying it in place.
     *
     * @param factorGraph the model to modify
     */
    public static void removeConstantFactors(FactorGraph factorGraph) {
        // Iterate over a snapshot: divideBy presumably alters the graph's factor collection,
        // so iterating the live collection could fail mid-loop.
        for (Object element : new ArrayList(factorGraph.factors())) {
            if (element instanceof ConstantFactor) {
                factorGraph.divideBy((Factor) element);
            }
        }
    }
}
