package cc.mallet.grmm.types;

import cc.mallet.grmm.util.Flops;
import cc.mallet.types.Matrix;
import cc.mallet.types.Matrixn;
import cc.mallet.types.SparseMatrixn;
import cc.mallet.util.Maths;
import java.util.Collection;
import java.util.Iterator;

/* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/grmm/types/LogTableFactor.class */
/**
 * A table factor whose entries are kept in log space.
 *
 * <p>Every entry of {@code probs} is the natural logarithm of the factor value:
 * the {@code value(...)} accessors return {@code Math.exp} of the stored entry and
 * the {@code setValue(...)} mutators store {@code Math.log} of their argument.
 * Products and quotients therefore become additions and subtractions, while sums
 * go through {@link Maths#sumLogProb(double, double)}.
 *
 * <p>NOTE(review): the {@code Flops} calls throughout look like operation-count
 * bookkeeping only; confirm against {@code cc.mallet.grmm.util.Flops} before
 * relying on that.
 */
public class LogTableFactor extends AbstractTableFactor {

    /**
     * Log-norm below which {@link #normalize()} reports the factor as all-zero.
     * The value is the threshold the original implementation used.
     */
    private static final double ALL_ZERO_LOG_THRESHOLD = -500.0d;

    /**
     * Creates a log-space copy of another table factor.
     *
     * @param abstractTableFactor the factor to copy; its log-value matrix is cloned,
     *                            so the new factor does not share storage with it
     */
    public LogTableFactor(AbstractTableFactor abstractTableFactor) {
        super(abstractTableFactor);
        this.probs = (Matrix) abstractTableFactor.getLogValueMatrix().cloneMatrix();
    }

    /**
     * Creates a factor over a single variable; initialization is delegated to the superclass.
     *
     * @param variable the only variable of the factor
     */
    public LogTableFactor(Variable variable) {
        super(variable);
    }

    /**
     * Creates a factor over the given variables; initialization is delegated to the superclass.
     *
     * @param variableArr the variables of the factor
     */
    public LogTableFactor(Variable[] variableArr) {
        super(variableArr);
    }

    /**
     * Creates a factor over a collection of variables; initialization is delegated to the superclass.
     *
     * @param collection the variables of the factor
     */
    public LogTableFactor(Collection collection) {
        super(collection);
    }

    /** Internal constructor: the values passed in are already logarithms. */
    private LogTableFactor(Variable[] variableArr, double[] logValues) {
        super(variableArr, logValues);
    }

    /** Internal constructor: the matrix passed in already holds logarithms. */
    private LogTableFactor(Variable[] variableArr, Matrix logMatrix) {
        super(variableArr, logMatrix);
    }

    /**
     * Builds a factor from values given in ordinary (non-log) space.
     *
     * @param variableArr the variables of the factor
     * @param dArr        the factor values; the array is not modified
     * @return a new factor storing {@code Math.log} of each value
     */
    public static LogTableFactor makeFromValues(Variable[] variableArr, double[] dArr) {
        final double[] logValues = new double[dArr.length];
        for (int idx = 0; idx < logValues.length; idx++) {
            logValues[idx] = Math.log(dArr[idx]);
        }
        return makeFromLogValues(variableArr, logValues);
    }

    /**
     * Builds a factor from values that are already logarithms.
     *
     * @param variableArr the variables of the factor
     * @param dArr        the log values, handed to the superclass constructor as-is
     * @return a new factor over {@code variableArr}
     */
    public static LogTableFactor makeFromLogValues(Variable[] variableArr, double[] dArr) {
        return new LogTableFactor(variableArr, dArr);
    }

    /** Sets every entry to {@code 0.0}, i.e. the log of the multiplicative identity 1. */
    @Override
    void setAsIdentity() {
        setAll(0.0d);
    }

    /**
     * Returns a copy of this factor made with the copy constructor.
     *
     * @return a new {@code LogTableFactor} copied from this one
     */
    @Override
    public Factor duplicate() {
        return new LogTableFactor(this);
    }

    /**
     * Creates a fresh factor over a subset of variables.
     *
     * @param variableArr the variables of the new factor
     * @return a new {@code LogTableFactor} over {@code variableArr}
     */
    @Override
    protected AbstractTableFactor createBlankSubset(Variable[] variableArr) {
        return new LogTableFactor(variableArr);
    }

    /**
     * Normalizes this factor in place by subtracting the log of its total mass from
     * every stored entry.
     *
     * <p>If the log-norm is below {@link #ALL_ZERO_LOG_THRESHOLD} a warning is printed
     * to {@code System.err}; the subtraction is still carried out afterwards.
     *
     * @return this factor
     */
    @Override
    public Factor normalize() {
        final double logNorm = logspaceOneNorm();
        if (logNorm < ALL_ZERO_LOG_THRESHOLD) {
            System.err.println("Attempt to normalize all-0 factor " + dumpToString());
        }
        final int locationCount = this.probs.numLocations();
        for (int loc = 0; loc < locationCount; loc++) {
            this.probs.setValueAtLocation(loc, this.probs.valueAtLocation(loc) - logNorm);
        }
        return this;
    }

    /**
     * Folds all stored entries together with {@link Maths#sumLogProb(double, double)},
     * starting from negative infinity (the log of zero).
     *
     * @return the log of the sum of the factor's values
     */
    private double logspaceOneNorm() {
        final int locationCount = this.probs.numLocations();
        double logTotal = Double.NEGATIVE_INFINITY;
        for (int loc = 0; loc < locationCount; loc++) {
            logTotal = Maths.sumLogProb(logTotal, this.probs.valueAtLocation(loc));
        }
        Flops.sumLogProb(locationCount);
        return logTotal;
    }

    /**
     * Returns the sum of the factor's values in ordinary space.
     *
     * @return {@code Math.exp} of the log-space norm
     */
    @Override
    public double sum() {
        Flops.exp();
        return Math.exp(logspaceOneNorm());
    }

    /**
     * Returns the logarithm of the sum of the factor's values.
     *
     * @return the log-space norm
     */
    @Override
    public double logsum() {
        return logspaceOneNorm();
    }

    /**
     * Multiplies this factor by another in place: the other factor's log value is
     * added to each stored entry.
     *
     * <p>NOTE(review): {@code largeIdxToSmall} presumably maps each location of this
     * factor to the matching index in {@code discreteFactor}; confirm in the superclass.
     *
     * @param discreteFactor the factor to multiply by
     */
    @Override
    protected void multiplyByInternal(DiscreteFactor discreteFactor) {
        final int[] otherIndex = largeIdxToSmall(discreteFactor);
        final int locationCount = this.probs.numLocations();
        for (int loc = 0; loc < locationCount; loc++) {
            final double product = this.probs.valueAtLocation(loc) + discreteFactor.logValue(otherIndex[loc]);
            this.probs.setValueAtLocation(loc, product);
        }
        Flops.increment(locationCount);
    }

    /**
     * Divides this factor by another in place: the other factor's log value is
     * subtracted from each stored entry.
     *
     * <p>Whenever the divisor's log value is infinite the result is forced to negative
     * infinity, so dividing by a zero entry yields zero rather than NaN or +infinity.
     *
     * @param discreteFactor the factor to divide by
     */
    @Override
    protected void divideByInternal(DiscreteFactor discreteFactor) {
        final int[] otherIndex = largeIdxToSmall(discreteFactor);
        final int locationCount = this.probs.numLocations();
        for (int loc = 0; loc < locationCount; loc++) {
            final double numerator = this.probs.valueAtLocation(loc);
            final double divisor = discreteFactor.logValue(otherIndex[loc]);
            final double quotient = Double.isInfinite(divisor)
                    ? Double.NEGATIVE_INFINITY
                    : numerator - divisor;
            this.probs.setValueAtLocation(loc, quotient);
        }
        Flops.increment(locationCount);
    }

    /**
     * Adds another factor to this one in place, combining each pair of log values
     * with {@link Maths#sumLogProb(double, double)}.
     *
     * @param discreteFactor the factor to add
     */
    @Override
    protected void plusEqualsInternal(DiscreteFactor discreteFactor) {
        final int[] otherIndex = largeIdxToSmall(discreteFactor);
        final int locationCount = this.probs.numLocations();
        for (int loc = 0; loc < locationCount; loc++) {
            final double mine = this.probs.valueAtLocation(loc);
            final double theirs = discreteFactor.logValue(otherIndex[loc]);
            this.probs.setValueAtLocation(loc, Maths.sumLogProb(mine, theirs));
        }
        Flops.sumLogProb(locationCount);
    }

    /**
     * Returns the ordinary-space value for an assignment.
     *
     * @param assignment the assignment to look up
     * @return {@code 1.0} if this factor has no variables, otherwise the exponential
     *         of the stored log value
     */
    @Override
    public double value(Assignment assignment) {
        Flops.exp();
        if (getNumVars() == 0) {
            return 1.0d;
        }
        return Math.exp(rawValue(assignment));
    }

    /**
     * Returns the ordinary-space value at the iterator's current assignment.
     *
     * @param assignmentIterator iterator whose current assignment index is looked up
     * @return the exponential of the stored log value
     */
    @Override
    public double value(AssignmentIterator assignmentIterator) {
        Flops.exp();
        return Math.exp(rawValue(assignmentIterator.indexOfCurrentAssn()));
    }

    /**
     * Returns the ordinary-space value at a single index.
     *
     * @param i the single index of the entry
     * @return the exponential of the stored log value
     */
    @Override
    public double value(int i) {
        Flops.exp();
        return Math.exp(rawValue(i));
    }

    /**
     * Returns the stored log value at the iterator's current assignment.
     *
     * @param assignmentIterator iterator whose current assignment index is looked up
     * @return the stored log value
     */
    @Override
    public double logValue(AssignmentIterator assignmentIterator) {
        return rawValue(assignmentIterator.indexOfCurrentAssn());
    }

    /**
     * Returns the stored log value at a single index.
     *
     * @param i the single index of the entry
     * @return the stored log value
     */
    @Override
    public double logValue(int i) {
        return rawValue(i);
    }

    /**
     * Returns the stored log value for an assignment.
     *
     * @param assignment the assignment to look up
     * @return the stored log value
     */
    @Override
    public double logValue(Assignment assignment) {
        return rawValue(assignment);
    }

    /**
     * Marginalizes this factor into {@code abstractTableFactor}.
     *
     * <p>The target is first filled with negative infinity (log of zero); every stored
     * entry of this factor is then log-summed into the target entry it maps to.
     *
     * <p>NOTE(review): the target is read with {@code singleValue(idx)} but written
     * with {@code setValueAtLocation(idx, ...)}; these only agree when index and
     * location coincide in the target matrix. Confirm that holds for all targets.
     *
     * @param abstractTableFactor the factor that receives the marginal; it is overwritten
     * @return {@code abstractTableFactor}
     */
    @Override
    protected Factor marginalizeInternal(AbstractTableFactor abstractTableFactor) {
        abstractTableFactor.setAll(Double.NEGATIVE_INFINITY);
        final int[] targetIndex = largeIdxToSmall(abstractTableFactor);
        final int locationCount = this.probs.numLocations();
        for (int loc = 0; loc < locationCount; loc++) {
            final int target = targetIndex[loc];
            final double accumulated = Maths.sumLogProb(
                    this.probs.valueAtLocation(loc),
                    abstractTableFactor.probs.singleValue(target));
            abstractTableFactor.probs.setValueAtLocation(target, accumulated);
        }
        Flops.sumLogProb(locationCount);
        return abstractTableFactor;
    }

    /**
     * Looks up the stored log value for an assignment by reading the assignment's
     * value for each of this factor's variables, in variable order.
     *
     * @param assignment the assignment to look up
     * @return the stored log value
     */
    protected double rawValue(Assignment assignment) {
        final int varCount = getNumVars();
        final int[] outcomes = new int[varCount];
        for (int v = 0; v < varCount; v++) {
            outcomes[v] = assignment.get(getVariable(v));
        }
        return rawValue(outcomes);
    }

    /** Looks up the stored log value for one outcome index per variable. */
    private double rawValue(int[] outcomes) {
        return rawValue(this.probs.singleIndex(outcomes));
    }

    /**
     * Looks up the stored log value at a single index.
     *
     * @param i the single index of the entry
     * @return the stored log value, or negative infinity when the matrix reports a
     *         negative location for {@code i} (presumably an entry that is not stored)
     */
    @Override
    protected double rawValue(int i) {
        final int location = this.probs.location(i);
        return (location < 0) ? Double.NEGATIVE_INFINITY : this.probs.valueAtLocation(location);
    }

    /**
     * Raises the factor to a power in place, which in log space is a multiplication
     * of every stored entry by {@code d}.
     *
     * @param d the exponent
     */
    @Override
    public void exponentiate(double d) {
        Flops.increment(this.probs.numLocations());
        this.probs.timesEquals(d);
    }

    /**
     * Stores a log value for an assignment.
     *
     * @param assignment the assignment to update
     * @param d          the log value to store
     */
    @Override
    public void setLogValue(Assignment assignment, double d) {
        setRawValue(assignment, d);
    }

    /**
     * Stores a log value at the iterator's current assignment.
     *
     * @param assignmentIterator iterator identifying the entry to update
     * @param d                  the log value to store
     */
    @Override
    public void setLogValue(AssignmentIterator assignmentIterator, double d) {
        setRawValue(assignmentIterator, d);
    }

    /**
     * Stores an ordinary-space value at the iterator's current assignment.
     *
     * @param assignmentIterator iterator identifying the entry to update
     * @param d                  the value; its logarithm is what gets stored
     */
    @Override
    public void setValue(AssignmentIterator assignmentIterator, double d) {
        Flops.log();
        setRawValue(assignmentIterator, Math.log(d));
    }

    /**
     * Stores log values for indices {@code 0 .. dArr.length - 1}.
     *
     * @param dArr the log values, one per index
     */
    @Override
    public void setLogValues(double[] dArr) {
        for (int idx = 0; idx < dArr.length; idx++) {
            setRawValue(idx, dArr[idx]);
        }
    }

    /**
     * Stores ordinary-space values for indices {@code 0 .. dArr.length - 1}.
     *
     * @param dArr the values, one per index; the logarithm of each is stored
     */
    @Override
    public void setValues(double[] dArr) {
        Flops.log(dArr.length);
        for (int idx = 0; idx < dArr.length; idx++) {
            setRawValue(idx, Math.log(dArr[idx]));
        }
    }

    /**
     * Multiplies every value of the factor by a constant in place.
     *
     * @param d the ordinary-space multiplier
     */
    @Override
    public void timesEquals(double d) {
        timesEqualsLog(Math.log(d));
    }

    /**
     * Adds {@code logScale} to every stored entry in place.
     *
     * <p>This updates the entries directly rather than building a constant-filled
     * clone of the matrix and adding it, which avoids allocating a temporary matrix
     * on every call and needs nothing beyond per-location access.
     */
    private void timesEqualsLog(double logScale) {
        final int locationCount = this.probs.numLocations();
        Flops.increment(locationCount);
        for (int loc = 0; loc < locationCount; loc++) {
            this.probs.setValueAtLocation(loc, this.probs.valueAtLocation(loc) + logScale);
        }
    }

    /**
     * Adds an ordinary-space amount to one entry in place.
     *
     * <p>Despite the method name, {@code i} is passed to {@link #logValue(int)} and
     * {@code setRawValue}, i.e. it is treated the same way as in those methods.
     *
     * @param i the entry to update
     * @param d the ordinary-space amount to add
     */
    @Override
    protected void plusEqualsAtLocation(int i, double d) {
        Flops.log();
        Flops.sumLogProb(1);
        final double updated = Maths.sumLogProb(logValue(i), Math.log(d));
        setRawValue(i, updated);
    }

    /**
     * Builds a single-variable factor from ordinary-space values.
     *
     * @param variable the only variable of the factor
     * @param dArr     the factor values
     * @return a new factor storing the logarithm of each value
     */
    public static LogTableFactor makeFromValues(Variable variable, double[] dArr) {
        return makeFromValues(new Variable[]{variable}, dArr);
    }

    /**
     * Builds a factor from a sparse matrix of ordinary-space values.
     *
     * @param variableArr   the variables of the factor
     * @param sparseMatrixn the values; it is cloned and left unmodified
     * @return a new factor backed by a clone whose stored entries have been replaced
     *         by their logarithms
     */
    public static LogTableFactor makeFromMatrix(Variable[] variableArr, SparseMatrixn sparseMatrixn) {
        final SparseMatrixn logMatrix = (SparseMatrixn) sparseMatrixn.cloneMatrix();
        final int locationCount = logMatrix.numLocations();
        for (int loc = 0; loc < locationCount; loc++) {
            logMatrix.setValueAtLocation(loc, Math.log(logMatrix.valueAtLocation(loc)));
        }
        Flops.log(logMatrix.numLocations());
        return new LogTableFactor(variableArr, logMatrix);
    }

    /**
     * Builds a factor from a matrix that already holds log values.
     *
     * @param variableArr the variables of the factor
     * @param matrix      the log values; it is cloned, so the factor does not share it
     * @return a new factor backed by the clone
     */
    public static LogTableFactor makeFromLogMatrix(Variable[] variableArr, Matrix matrix) {
        return new LogTableFactor(variableArr, (Matrix) matrix.cloneMatrix());
    }

    /**
     * Builds a single-variable factor from log values.
     *
     * @param variable the only variable of the factor
     * @param dArr     the log values
     * @return a new factor over {@code variable}
     */
    public static LogTableFactor makeFromLogValues(Variable variable, double[] dArr) {
        return makeFromLogValues(new Variable[]{variable}, dArr);
    }

    /**
     * Returns the factor's values in ordinary space.
     *
     * @return a clone of the internal matrix with each stored entry exponentiated
     */
    @Override
    public Matrix getValueMatrix() {
        final Matrix values = (Matrix) this.probs.cloneMatrix();
        final int locationCount = this.probs.numLocations();
        for (int loc = 0; loc < locationCount; loc++) {
            values.setValueAtLocation(loc, Math.exp(values.valueAtLocation(loc)));
        }
        Flops.exp(this.probs.numLocations());
        return values;
    }

    /**
     * Returns the internal log-value matrix itself, not a copy; changes made through
     * the returned reference change this factor.
     *
     * @return the backing matrix of log values
     */
    @Override
    public Matrix getLogValueMatrix() {
        return this.probs;
    }

    /**
     * Returns the ordinary-space value stored at a matrix location.
     *
     * @param i the location in the backing matrix
     * @return the exponential of the stored log value
     */
    @Override
    public double valueAtLocation(int i) {
        Flops.exp();
        return Math.exp(this.probs.valueAtLocation(i));
    }

    /**
     * Slices this factor down to one free variable.
     *
     * @param variable   the variable that stays free
     * @param assignment values for the other variables; it is duplicated, not modified
     * @return a new single-variable factor holding this factor's log value for each
     *         outcome of {@code variable}
     */
    @Override
    protected Factor slice_onevar(Variable variable, Assignment assignment) {
        final Assignment probe = (Assignment) assignment.duplicate();
        final double[] logValues = new double[variable.getNumOutcomes()];
        for (int outcome = 0; outcome < variable.getNumOutcomes(); outcome++) {
            probe.setValue(variable, outcome);
            logValues[outcome] = logValue(probe);
        }
        return makeFromLogValues(variable, logValues);
    }

    /**
     * Slices this factor down to two free variables.
     *
     * @param variable   the first free variable
     * @param variable2  the second free variable
     * @param assignment values for the other variables; it is duplicated, not modified
     * @return a new factor over {@code (variable, variable2)} holding this factor's
     *         log value for each pair of outcomes
     */
    @Override
    protected Factor slice_twovar(Variable variable, Variable variable2, Assignment assignment) {
        final Assignment probe = (Assignment) assignment.duplicate();
        final int firstSize = variable.getNumOutcomes();
        final int secondSize = variable2.getNumOutcomes();
        final int[] sizes = {firstSize, secondSize};
        final int[] outcomePair = new int[2];
        final double[] logValues = new double[firstSize * secondSize];
        for (int first = 0; first < firstSize; first++) {
            probe.setValue(variable, first);
            outcomePair[0] = first;
            for (int second = 0; second < secondSize; second++) {
                probe.setValue(variable2, second);
                outcomePair[1] = second;
                logValues[Matrixn.singleIndex(sizes, outcomePair)] = logValue(probe);
            }
        }
        return makeFromLogValues(new Variable[]{variable, variable2}, logValues);
    }

    /**
     * Slices this factor for an arbitrary set of variables.
     *
     * <p>The free variables are those in {@code variableArr} that {@code assignment}
     * does not already cover. For every joint assignment to them, the log value of
     * this factor at the union with {@code assignment} is recorded.
     *
     * @param variableArr the candidate variables
     * @param assignment  the fixed part of the assignment
     * @return a new factor over the free variables
     */
    @Override
    protected Factor slice_general(Variable[] variableArr, Assignment assignment) {
        final HashVarSet freeVars = new HashVarSet(variableArr);
        freeVars.removeAll(assignment.varSet());
        final double[] logValues = new double[freeVars.weight()];
        for (AssignmentIterator it = freeVars.assignmentIterator(); it.hasNext(); it.advance()) {
            final Assignment full = Assignment.union(assignment, it.assignment());
            logValues[it.indexOfCurrentAssn()] = logValue(full);
        }
        return makeFromLogValues(freeVars.toVariableArray(), logValues);
    }

    /**
     * Multiplies a collection of factors together.
     *
     * @param collection the factors to multiply; each element is cast to {@code Factor}
     * @return a new factor over the union of all the factors' variables, multiplied by
     *         each element of {@code collection} in iteration order
     */
    public static LogTableFactor multiplyAll(Collection collection) {
        final HashVarSet allVars = new HashVarSet();
        for (Object element : collection) {
            allVars.addAll(((Factor) element).varSet());
        }
        final LogTableFactor product = new LogTableFactor(allVars);
        for (Object element : collection) {
            product.multiplyBy((Factor) element);
        }
        return product;
    }

    /**
     * Shifts all stored log values in place by subtracting the entry found at the
     * location returned by {@code argmax()}.
     *
     * @return this factor
     */
    @Override
    public AbstractTableFactor recenter() {
        final double peak = this.probs.valueAtLocation(argmax());
        timesEqualsLog(-peak);
        return this;
    }
}
