package cc.mallet.grmm.inference;

import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.DiscreteFactor;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.LogTableFactor;
import cc.mallet.grmm.types.Variable;
import cc.mallet.util.Randoms;
import cc.mallet.util.Timing;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/grmm/inference/GibbsSampler.class */
/**
 * {@link Sampler} that draws assignments from a {@link FactorGraph} by sweeping over its
 * variables and resampling each one from a conditional table built from the factors that
 * contain it.
 *
 * <p>A call to {@link #sample(FactorGraph, int)} first finds a starting assignment, then runs
 * {@code burnin} sweeps whose results are discarded, then runs one sweep per requested sample
 * and records the assignment after each sweep.
 */
public class GibbsSampler implements Sampler {
    /** Seed of the generator used when the caller does not supply one. */
    private static final int DEFAULT_SEED = 324231;

    /** Factor values at or below this threshold are rejected when searching for a starting assignment. */
    private static final double MIN_FEASIBLE_VALUE = 1.0E-50d;

    /** Number of sweeps run (and discarded) before samples are recorded. */
    private int burnin;

    // NOTE(review): never read or written in this class; kept so the class's field layout is
    // unchanged — confirm it is safe to delete.
    private Factor[] allCpts;

    /** Source of randomness passed to {@code sampleLocation} when resampling a variable. */
    private Randoms r;

    /** Creates a sampler with no burn-in and a generator seeded with the default seed. */
    public GibbsSampler() {
        this(0);
    }

    /**
     * Creates a sampler with a generator seeded with the default seed.
     *
     * @param burnin number of sweeps to discard before recording samples
     */
    public GibbsSampler(int burnin) {
        this(new Randoms(DEFAULT_SEED), burnin);
    }

    /**
     * Creates a sampler that uses the supplied generator.
     *
     * @param randoms generator used for all sampling decisions
     * @param burnin  number of sweeps to discard before recording samples
     */
    public GibbsSampler(Randoms randoms, int burnin) {
        // Assign directly: no throwaway default generator is allocated on this path.
        this.r = randoms;
        this.burnin = burnin;
    }

    /**
     * Sets the number of sweeps discarded at the start of each {@link #sample} call.
     *
     * @param burnin number of burn-in sweeps; values {@code <= 0} result in no burn-in sweeps
     */
    public void setBurnin(int burnin) {
        this.burnin = burnin;
    }

    /**
     * Replaces the generator used for sampling.
     *
     * @param randoms generator used by subsequent calls to {@link #sample}
     */
    @Override // cc.mallet.grmm.inference.Sampler
    public void setRandom(Randoms randoms) {
        this.r = randoms;
    }

    /**
     * Runs the sampler on a factor graph.
     *
     * @param factorGraph model to sample from
     * @param numSamples  number of sweeps to record; values {@code <= 0} yield an assignment
     *                    with no rows added
     * @return an {@link Assignment} to which one row was added per recorded sweep
     * @throws IllegalArgumentException if no starting assignment could be found for the model
     */
    @Override // cc.mallet.grmm.inference.Sampler
    public Assignment sample(FactorGraph factorGraph, int numSamples) {
        Assignment current = initialAssignment(factorGraph);
        if (current == null) {
            throw new IllegalArgumentException("GibbsSampler: Could not find feasible assignment for model " + factorGraph);
        }

        Timing timing = new Timing();
        for (int pass = 0; pass < this.burnin; pass++) {
            current = doOnePass(factorGraph, current);
        }
        timing.tick("Burnin");

        Assignment samples = new Assignment();
        for (int sampleIdx = 0; sampleIdx < numSamples; sampleIdx++) {
            current = doOnePass(factorGraph, current);
            samples.addRow(current);
        }
        timing.tick("Sampling");
        return samples;
    }

    /**
     * Finds an assignment to start sampling from.
     *
     * <p>The all-zeros assignment is used when the graph gives it a log value above negative
     * infinity; otherwise a factor-by-factor search is performed.
     *
     * @return a starting assignment, or {@code null} if the search found none
     */
    private Assignment initialAssignment(FactorGraph factorGraph) {
        Assignment allZeros = new Assignment(factorGraph, new int[factorGraph.numVariables()]);
        if (factorGraph.logValue(allZeros) > Double.NEGATIVE_INFINITY) {
            return allZeros;
        }
        return initialAssignmentRec(factorGraph, new Assignment(), 0);
    }

    /**
     * Depth-first search that extends {@code assignment} one factor at a time, accepting only
     * extensions whose factor value exceeds {@link #MIN_FEASIBLE_VALUE}.
     *
     * @param factorGraph model being searched
     * @param assignment  partial assignment built from the factors before {@code factorIdx}
     * @param factorIdx   index of the next factor to satisfy
     * @return an assignment accepted by every factor from {@code factorIdx} on, or {@code null}
     *         if this branch has no such extension
     */
    private Assignment initialAssignmentRec(FactorGraph factorGraph, Assignment assignment, int factorIdx) {
        if (factorIdx >= factorGraph.factors().size()) {
            return assignment;
        }

        Factor factor = factorGraph.getFactor(factorIdx);
        Factor slice = factor.slice(assignment);

        // Nothing left to choose for this factor: it either accepts the partial assignment or
        // the whole branch is dead.
        if (slice.varSet().isEmpty()) {
            if (factor.value(assignment) > MIN_FEASIBLE_VALUE) {
                return initialAssignmentRec(factorGraph, assignment, factorIdx + 1);
            }
            return null;
        }

        AssignmentIterator candidates = slice.assignmentIterator();
        while (candidates.hasNext()) {
            if (slice.value(candidates) > MIN_FEASIBLE_VALUE) {
                Assignment extended = Assignment.union(assignment, candidates.assignment());
                Assignment found = initialAssignmentRec(factorGraph, extended, factorIdx + 1);
                if (found != null) {
                    return found;
                }
            }
            candidates.advance();
        }
        return null;
    }

    /**
     * Performs one sweep: works on a duplicate of {@code assignment} and, for each index below
     * the duplicate's {@code size()}, resamples the graph variable at that index.
     *
     * @return the updated duplicate; {@code assignment} itself is not passed to any mutator here
     */
    private Assignment doOnePass(FactorGraph factorGraph, Assignment assignment) {
        Assignment next = (Assignment) assignment.duplicate();
        for (int varIdx = 0; varIdx < next.size(); varIdx++) {
            Variable variable = factorGraph.get(varIdx);
            DiscreteFactor conditional = constructConditionalCpt(factorGraph, variable, next);
            int sampledValue = conditional.sampleLocation(this.r);
            next.setValue(variable, sampledValue);
        }
        return next;
    }

    /**
     * Builds a normalized table over {@code variable} whose raw entry for each value is the sum
     * of the log values of all factors containing {@code variable}, evaluated with the variable
     * set to that value.
     *
     * <p>Side effect: {@code assignment} is modified — {@code variable} is overwritten with each
     * candidate value in turn and is left at the last one visited.
     */
    private DiscreteFactor constructConditionalCpt(FactorGraph factorGraph, Variable variable, Assignment assignment) {
        List<?> factors = factorGraph.allFactorsContaining(variable);
        LogTableFactor conditional = new LogTableFactor(variable);
        AssignmentIterator values = conditional.assignmentIterator();
        while (values.hasNext()) {
            Assignment candidate = values.assignment();
            assignment.setValue(variable, candidate.get(variable));
            conditional.setRawValue(candidate, sumValues(factors, assignment));
            values.advance();
        }
        conditional.normalize();
        return conditional;
    }

    /**
     * Sums {@code logValue(assignment)} over every element of {@code factors}.
     *
     * @param factors    list whose elements are all {@link Factor} instances
     * @param assignment assignment at which each factor is evaluated
     * @return the sum, or {@code 0.0} for an empty list
     */
    private double sumValues(List<?> factors, Assignment assignment) {
        double total = 0.0d;
        for (Object factor : factors) {
            total += ((Factor) factor).logValue(assignment);
        }
        return total;
    }
}
