package cc.mallet.grmm.types;

import cc.mallet.grmm.util.Matrices;
import cc.mallet.types.Matrix;
import cc.mallet.types.SparseMatrixn;
import cc.mallet.util.Randoms;
import com.hp.hpl.jena.sparql.sse.Tags;
import java.util.Objects;

/* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/grmm/types/PottsTableFactor.class */
/**
 * A Potts-style factor over a set of discrete variables {@code xs}, parameterized by a
 * continuous variable {@code alpha} whose value is read from the assignment at evaluation time.
 *
 * <p>For a given alpha, the log-potential table is built as
 * {@code Matrices.diag(sizes, alpha) + Matrices.constant(sizes, -alpha)}. Assuming
 * {@code Matrices.diag} places alpha on the entries where every xs variable takes the same
 * index (TODO confirm against Matrices), this gives log-potential 0 when all xs agree and
 * {@code -alpha} otherwise. {@link #sumGradLog} and {@link #secondDerivative} are written
 * against exactly that form.
 *
 * <p>Max-extraction, marginalization, table lookup by index, normalization and sampling are not
 * supported and throw {@link UnsupportedOperationException}.
 */
public class PottsTableFactor extends AbstractFactor implements ParameterizedFactor {
    /** Continuous parameter; its value is looked up in each assignment. */
    private final Variable alpha;
    /** The discrete variables whose agreement the factor scores. */
    private final VarSet xs;

    /**
     * Creates a Potts factor over {@code varSet} parameterized by {@code variable}.
     *
     * @param varSet the discrete variables scored by the factor
     * @param variable the continuous alpha parameter
     * @throws IllegalArgumentException if {@code variable} is not continuous
     */
    public PottsTableFactor(VarSet varSet, Variable variable) {
        super(combineVariables(variable, varSet));
        this.alpha = variable;
        this.xs = varSet;
        if (!variable.isContinuous()) {
            throw new IllegalArgumentException("alpha must be continuous");
        }
    }

    /**
     * Creates a pairwise Potts factor over {@code variable} and {@code variable2}.
     *
     * @param variable first discrete variable
     * @param variable2 second discrete variable
     * @param variable3 the continuous alpha parameter
     * @throws IllegalArgumentException if {@code variable3} is not continuous
     */
    public PottsTableFactor(Variable variable, Variable variable2, Variable variable3) {
        super(new HashVarSet(new Variable[]{variable, variable2, variable3}));
        this.alpha = variable3;
        this.xs = new HashVarSet(new Variable[]{variable, variable2});
        if (!variable3.isContinuous()) {
            throw new IllegalArgumentException("alpha must be continuous");
        }
    }

    /** Returns a new var set holding every variable of {@code varSet} plus {@code variable}. */
    private static VarSet combineVariables(Variable variable, VarSet varSet) {
        HashVarSet hashVarSet = new HashVarSet(varSet);
        hashVarSet.add(variable);
        return hashVarSet;
    }

    @Override // cc.mallet.grmm.types.AbstractFactor
    protected Factor extractMaxInternal(VarSet varSet) {
        throw new UnsupportedOperationException();
    }

    @Override // cc.mallet.grmm.types.AbstractFactor
    protected double lookupValueInternal(int i) {
        throw new UnsupportedOperationException();
    }

    @Override // cc.mallet.grmm.types.AbstractFactor
    protected Factor marginalizeInternal(VarSet varSet) {
        throw new UnsupportedOperationException();
    }

    /**
     * Evaluates the factor at the iterator's current assignment, which must also assign alpha.
     */
    @Override // cc.mallet.grmm.types.AbstractFactor, cc.mallet.grmm.types.Factor
    public double value(AssignmentIterator assignmentIterator) {
        Assignment assignment = assignmentIterator.assignment();
        return sliceForAlpha(assignment).value(assignment);
    }

    /**
     * Builds the table factor over {@code xs} for the alpha value in {@code assignment}.
     * The table is diag(alpha) + constant(-alpha), interpreted as log-space values.
     */
    private Factor sliceForAlpha(Assignment assignment) {
        double alphaValue = assignment.getDouble(this.alpha);
        int[] sizes = sizesFromVarSet(this.xs);
        Matrix diag = Matrices.diag(sizes, alphaValue);
        Matrix logValues = Matrices.constant(sizes, -alphaValue);
        logValues.plusEquals(diag);
        // The cast mirrors the original code: the constant matrix is expected to be a SparseMatrixn.
        return LogTableFactor.makeFromLogMatrix(this.xs.toVariableArray(), (SparseMatrixn) logValues);
    }

    /** Returns the number of outcomes of each variable in {@code varSet}, in var-set order. */
    private static int[] sizesFromVarSet(VarSet varSet) {
        int[] sizes = new int[varSet.size()];
        for (int i = 0; i < varSet.size(); i++) {
            sizes[i] = varSet.get(i).getNumOutcomes();
        }
        return sizes;
    }

    @Override // cc.mallet.grmm.types.Factor
    public Factor normalize() {
        throw new UnsupportedOperationException();
    }

    @Override // cc.mallet.grmm.types.AbstractFactor, cc.mallet.grmm.types.Factor
    public Assignment sample(Randoms randoms) {
        throw new UnsupportedOperationException();
    }

    /** Returns the natural log of {@link #value(AssignmentIterator)}. */
    @Override // cc.mallet.grmm.types.AbstractFactor, cc.mallet.grmm.types.Factor
    public double logValue(AssignmentIterator assignmentIterator) {
        return Math.log(value(assignmentIterator));
    }

    /**
     * Fixes alpha from {@code assignment} and then slices the resulting table factor by the
     * same assignment.
     */
    @Override // cc.mallet.grmm.types.Factor
    public Factor slice(Assignment assignment) {
        return sliceForAlpha(assignment).slice(assignment);
    }

    /** Returns a short human-readable description naming alpha and the scored variables. */
    @Override // cc.mallet.grmm.types.Factor
    public String dumpToString() {
        StringBuilder sb = new StringBuilder();
        sb.append("[Potts: alpha:");
        sb.append(this.alpha);
        sb.append(" xs:");
        sb.append(this.xs);
        sb.append("]");
        return sb.toString();
    }

    /**
     * Returns the expected gradient of the log-potential with respect to alpha under
     * {@code factor}: the negated marginal mass of all xs assignments that are not all equal.
     *
     * @param factor distribution whose marginal over xs is used
     * @param variable must be this factor's alpha variable (compared by identity)
     * @param assignment unused
     * @throws IllegalArgumentException if {@code variable} is not alpha
     */
    @Override // cc.mallet.grmm.types.ParameterizedFactor
    public double sumGradLog(Factor factor, Variable variable, Assignment assignment) {
        checkIsAlpha(variable);
        return -massOfUnequalAssignments(factor);
    }

    /**
     * Returns E[g^2] - E[g]^2 for g = d(log-potential)/d(alpha) under {@code factor}.
     * Because g is -1 on unequal assignments and 0 otherwise, this is p - p^2, where p is the
     * marginal mass of the unequal assignments.
     *
     * @param factor distribution whose marginal over xs is used
     * @param variable must be this factor's alpha variable (compared by identity)
     * @param assignment unused
     * @throws IllegalArgumentException if {@code variable} is not alpha
     */
    public double secondDerivative(Factor factor, Variable variable, Assignment assignment) {
        checkIsAlpha(variable);
        // Marginalize once and reuse the mass for both moments. The original code
        // marginalized and iterated twice (once inside sumGradLog).
        double unequalMass = massOfUnequalAssignments(factor);
        double firstMoment = -unequalMass;
        return unequalMass - (firstMoment * firstMoment);
    }

    /** Rejects any variable other than this factor's alpha (identity comparison). */
    private void checkIsAlpha(Variable variable) {
        if (variable != this.alpha) {
            throw new IllegalArgumentException();
        }
    }

    /**
     * Sums the marginal of {@code factor} over xs across every xs assignment whose values are
     * not all equal.
     */
    private double massOfUnequalAssignments(Factor factor) {
        Factor marginal = factor.marginalize(this.xs);
        double mass = 0.0d;
        AssignmentIterator it = this.xs.assignmentIterator();
        while (it.hasNext()) {
            if (!isAllEqual(it.assignment())) {
                mass += marginal.value(it);
            }
            it.advance();
        }
        return mass;
    }

    /** Returns true if every variable in xs takes the same value in {@code assignment}. */
    private boolean isAllEqual(Assignment assignment) {
        Object first = assignment.getObject(this.xs.get(0));
        for (int i = 1; i < this.xs.size(); i++) {
            if (!first.equals(assignment.getObject(this.xs.get(i)))) {
                return false;
            }
        }
        return true;
    }

    /** Returns a new factor over the same xs and alpha variables. */
    @Override // cc.mallet.grmm.types.Factor
    public Factor duplicate() {
        return new PottsTableFactor(this.xs, this.alpha);
    }

    @Override // cc.mallet.grmm.types.Factor
    public boolean isNaN() {
        return false;
    }

    /** Tolerance is ignored; this delegates to {@link #equals(Object)}. */
    @Override // cc.mallet.grmm.types.Factor
    public boolean almostEquals(Factor factor, double d) {
        return equals(factor);
    }

    /** Two Potts factors are equal when they have the same class, alpha, and xs. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        PottsTableFactor other = (PottsTableFactor) obj;
        return Objects.equals(this.alpha, other.alpha) && Objects.equals(this.xs, other.xs);
    }

    /** Consistent with {@link #equals(Object)}; keeps the original 29*h(alpha) + h(xs) formula. */
    @Override
    public int hashCode() {
        return (29 * Objects.hashCode(this.alpha)) + Objects.hashCode(this.xs);
    }
}
