package cc.mallet.grmm.learning;

import cc.mallet.grmm.inference.AbstractBeliefPropagation;
import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.inference.JunctionTreeInferencer;
import cc.mallet.grmm.inference.TRP;
import cc.mallet.grmm.types.AbstractTableFactor;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.Factors;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.LogTableFactor;
import cc.mallet.grmm.types.UndirectedModel;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.grmm.util.LabelsAssignment;
import cc.mallet.grmm.util.Models;
import cc.mallet.optimize.Optimizable;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.HashedSparseVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelsSequence;
import cc.mallet.types.Matrix;
import cc.mallet.types.Matrixn;
import cc.mallet.types.SparseVector;
import cc.mallet.util.ArrayUtils;
import cc.mallet.util.MalletLogger;
import gnu.trove.TDoubleArrayList;
import gnu.trove.THashMap;
import gnu.trove.TIntArrayList;
import gnu.trove.TObjectIntHashMap;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.OutputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Reader;
import java.io.Serializable;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import org.apache.jena.atlas.json.io.JSWriter;
import org.apache.xml.serializer.SerializerConstants;
import org.jdom.Element;
import org.jdom.JDOMException;
import org.jdom.input.SAXBuilder;

/* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/grmm/learning/ACRF.class */
public class ACRF implements Serializable {
    // Class-wide logger, shared by the nested classes in this file.
    private static transient Logger logger = MalletLogger.getLogger(ACRF.class.getName());
    // Clique templates; MaximizableACRF iterates these to size, read and write the parameters.
    Template[] templates;
    // NOTE(review): presumably applied to each unrolled graph after construction; the call site is not in this part of the file.
    private GraphPostProcessor graphProcessor;
    // Feature alphabet; used by MaximizableACRF.computeGradient to name features in log messages.
    Alphabet inputAlphabet;
    // Logged when a MaximizableACRF is created.
    int defaultFeatureIndex;
    private Pipe inputPipe;
    private static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 10.0d;
    private static final long serialVersionUID = 2865175696692468236L;
    List fixedPtls = new ArrayList(0);
    // Each MaximizableACRF works on its own duplicate() of this inferencer.
    private Inferencer globalInferencer = new TRP();
    // Max-product variant of TRP.
    private Inferencer viterbi = TRP.createForMaxProduct();
    // When true, computeLogLikelihood calls reportOnGraphCache() after each pass.
    private boolean cacheUnrolledGraphs = false;
    private transient Map graphCache = new THashMap();
    // Variance of the Gaussian prior used in the likelihood and gradient computations.
    private double gaussianPriorVariance = 10.0d;
    // When true, the data log-likelihood and its gradient are divided by the number of training instances.
    private boolean doSizeScale = false;
    // When non-null, MaximizableACRF.reportGradient writes its dump files into this directory.
    private transient File verboseOutputDirectory = null;

    /* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/grmm/learning/ACRF$BigramTemplate.class */
    /**
     * Sequence template that links one label factor across adjacent time steps.
     *
     * <p>For every time {@code t} in {@code [0, maxTime - 1)} a clique over the
     * variables at {@code (t, factor)} and {@code (t + 1, factor)} is added to the
     * unrolled graph, carrying the feature vector observed at time {@code t}.
     */
    public static class BigramTemplate extends SequenceTemplate {
        private static final long serialVersionUID = 8944142287103225874L;

        /** Index of the label factor this template connects. */
        int factor;

        public BigramTemplate(int i) {
            factor = i;
        }

        @Override
        public void addInstantiatedCliques(UnrolledGraph unrolledGraph, FeatureVectorSequence featureVectorSequence, LabelsAssignment labelsAssignment) {
            int t = 0;
            while (t < labelsAssignment.maxTime() - 1) {
                final int next = t + 1;
                Variable current = labelsAssignment.varOfIndex(t, factor);
                Variable following = labelsAssignment.varOfIndex(next, factor);
                FeatureVector features = featureVectorSequence.getFeatureVector(t);
                Variable[] pair = new Variable[] {current, following};

                // Only checked when assertions are enabled for the enclosing class.
                assert current != null : "Couldn't get label factor " + factor + " time " + t;
                assert following != null : "Couldn't get label factor " + factor + " time " + next;

                UnrolledVarSet clique = new UnrolledVarSet(unrolledGraph, this, pair, features);
                unrolledGraph.addClique(clique);
                t = next;
            }
        }

        @Override
        public String toString() {
            return "[BigramTemplate (" + factor + ")]";
        }

        /** Returns the index of the label factor this template connects. */
        public int getFactor() {
            return factor;
        }
    }

    /* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/grmm/learning/ACRF$FixedFactorTemplate.class */
    /**
     * Base class for templates whose factors are computed directly by
     * {@link #computeFactor} rather than from learned weights: it reports zero
     * parameters, empty weight vectors, and refuses to be marked trainable.
     */
    public static abstract class FixedFactorTemplate extends Template {

        /** Builds the factor for one instantiated clique. */
        @Override
        public abstract AbstractTableFactor computeFactor(UnrolledVarSet unrolledVarSet);

        /** Always {@code false}. */
        @Override
        public boolean isTrainable() {
            return false;
        }

        /**
         * Accepts only {@code false}.
         *
         * @throws IllegalArgumentException if {@code z} is {@code true}
         */
        @Override
        public void setTrainable(boolean z) {
            if (!z) {
                return;
            }
            throw new IllegalArgumentException("This template is never trainable.");
        }

        /** Contributes no parameters; always returns 0. */
        @Override
        public int initWeights(InstanceList instanceList) {
            final int noParameters = 0;
            return noParameters;
        }

        /** Returns a fresh empty array on every call. */
        @Override
        public SparseVector[] getWeights() {
            SparseVector[] none = new SparseVector[0];
            return none;
        }

        /** Returns a fresh default-constructed vector on every call. */
        @Override
        public SparseVector getDefaultWeights() {
            SparseVector empty = new SparseVector();
            return empty;
        }
    }

    /* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/grmm/learning/ACRF$GraphPostProcessor.class */
    /**
     * Serializable callback that receives an unrolled graph together with the
     * instance it belongs to.
     *
     * NOTE(review): presumably invoked after a graph is unrolled so callers can
     * adjust it; the call site is not visible in this part of the file.
     */
    public interface GraphPostProcessor extends Serializable {
        /**
         * @param unrolledGraph the graph to process
         * @param instance the instance associated with {@code unrolledGraph}
         */
        void process(UnrolledGraph unrolledGraph, Instance instance);
    }

    /* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/grmm/learning/ACRF$MaximizableACRF.class */
    public class MaximizableACRF implements Optimizable.ByGradientValue, Serializable {
        // Training instances; set once in the constructor.
        InstanceList trainData;
        // Last computed gradient, laid out as in getParameters (default weights first, then per-assignment weights).
        double[] cachedGradient;
        // True when cachedValue must be recomputed before use.
        boolean cachedValueStale;
        // True when cachedGradient must be recomputed before use.
        boolean cachedGradientStale;
        // Sum of the counts returned by each template's initWeights.
        private int numParameters;
        private static final boolean printGradient = false;
        private transient UnrolledGraph graph;
        // Private duplicate of the enclosing ACRF's global inferencer.
        protected Inferencer inferencer;
        // Empirical feature counts, indexed [template][assignment]; filled by collectConstraints.
        SparseVector[][] constraints;
        // Model-expected feature counts, indexed [template][assignment]; refilled on every likelihood pass.
        SparseVector[][] expectations;
        // Empirical counts for the per-template default weights.
        SparseVector[] defaultConstraints;
        // Expected counts for the per-template default weights.
        SparseVector[] defaultExpectations;
        // Placeholder until the first getValue() call replaces it.
        double cachedValue = -1.23456789E8d;
        // Instances whose log joint was infinite on the first likelihood pass; null until that pass runs.
        protected BitSet infiniteValues = null;
        private int totalNodes = 0;
        // Counter used by reportGradient to number its output files.
        private int gradCallNo = 0;

        /**
         * Asks every template to initialize its weights from {@code instanceList}
         * and adds the parameter counts they report to {@code numParameters}.
         */
        private void initWeights(InstanceList instanceList) {
            for (Template template : ACRF.this.templates) {
                int added = template.initWeights(instanceList);
                this.numParameters = this.numParameters + added;
            }
        }

        /**
         * Allocates the constraint and expectation accumulators as zeroed clones
         * of the templates' current weight vectors, so each accumulator has the
         * same sparsity pattern as the weights it is paired with.
         *
         * <p>{@code defaultConstraints}/{@code defaultExpectations} hold one vector
         * per template; {@code constraints}/{@code expectations} hold one vector
         * per template per assignment.
         */
        private void initConstraintsExpectations() {
            final int numTemplates = ACRF.this.templates.length;

            this.defaultConstraints = new SparseVector[numTemplates];
            this.defaultExpectations = new SparseVector[numTemplates];
            for (int t = 0; t < numTemplates; t++) {
                SparseVector defaultWeights = ACRF.this.templates[t].getDefaultWeights();
                this.defaultConstraints[t] = (SparseVector) defaultWeights.cloneMatrixZeroed();
                this.defaultExpectations[t] = (SparseVector) defaultWeights.cloneMatrixZeroed();
            }

            // Two-dimensional arrays: only the outer dimension is fixed here, because
            // each template may have a different number of assignments.
            this.constraints = new SparseVector[numTemplates][];
            this.expectations = new SparseVector[numTemplates][];
            for (int t = 0; t < numTemplates; t++) {
                SparseVector[] weights = ACRF.this.templates[t].getWeights();
                this.constraints[t] = new SparseVector[weights.length];
                this.expectations[t] = new SparseVector[weights.length];
                for (int a = 0; a < weights.length; a++) {
                    this.constraints[t][a] = (SparseVector) weights[a].cloneMatrixZeroed();
                    this.expectations[t][a] = (SparseVector) weights[a].cloneMatrixZeroed();
                }
            }
        }

        /** Zeroes every expectation accumulator (default and per-assignment). */
        void resetExpectations() {
            int t = 0;
            while (t < this.expectations.length) {
                this.defaultExpectations[t].setAll(0.0d);
                SparseVector[] row = this.expectations[t];
                for (int a = 0; a < row.length; a++) {
                    row[a].setAll(0.0d);
                }
                t++;
            }
        }

        /**
         * Prepares an optimizable view of the enclosing ACRF over
         * {@code instanceList}: duplicates the global inferencer, initializes the
         * template weights, allocates the accumulators and gradient buffer, marks
         * both caches stale, and collects the empirical constraints.
         */
        protected MaximizableACRF(InstanceList instanceList) {
            inferencer = ACRF.this.globalInferencer.duplicate();
            ACRF.logger.finest("Initializing MaximizableACRF.");

            trainData = instanceList;
            initWeights(trainData);
            initConstraintsExpectations();

            final int numInstances = trainData.size();
            cachedGradient = new double[numParameters];
            cachedGradientStale = true;
            cachedValueStale = true;

            ACRF.logger.info("Number of training instances = " + numInstances);
            ACRF.logger.info("Number of parameters = " + numParameters);
            ACRF.logger.info("Default feature index = " + ACRF.this.defaultFeatureIndex);
            describePrior();

            ACRF.logger.fine("Computing constraints");
            collectConstraints(trainData);
        }

        /** Logs the variance of the Gaussian prior in use. */
        private void describePrior() {
            final String message = "Using gaussian prior with variance " + ACRF.this.gaussianPriorVariance;
            ACRF.logger.info(message);
        }

        /** Returns the total number of parameters reported by the templates at initialization. */
        @Override
        public int getNumParameters() {
            return numParameters;
        }

        /**
         * Copies all parameters into {@code dArr}: first every template's default
         * weights, then every template's per-assignment weight vectors, in
         * template order.
         *
         * @throws IllegalArgumentException if {@code dArr.length} differs from
         *         {@link #getNumParameters()}
         */
        @Override
        public void getParameters(double[] dArr) {
            if (dArr.length != numParameters) {
                throw new IllegalArgumentException("Argument is not of the  correct dimensions");
            }
            int offset = 0;
            for (Template template : ACRF.this.templates) {
                double[] defaults = template.getDefaultWeights().getValues();
                System.arraycopy(defaults, 0, dArr, offset, defaults.length);
                offset += defaults.length;
            }
            for (Template template : ACRF.this.templates) {
                SparseVector[] weights = template.getWeights();
                for (int a = 0; a < weights.length; a++) {
                    double[] values = weights[a].getValues();
                    System.arraycopy(values, 0, dArr, offset, values.length);
                    offset += values.length;
                }
            }
        }

        /**
         * Overwrites all parameters from {@code dArr}, using the same layout as
         * {@link #getParameters}, and marks the cached value and gradient stale.
         *
         * @throws IllegalArgumentException if {@code dArr.length} differs from
         *         {@link #getNumParameters()}
         */
        @Override
        public void setParameters(double[] dArr) {
            if (dArr.length != numParameters) {
                throw new IllegalArgumentException("Argument is not of the  correct dimensions");
            }
            cachedGradientStale = true;
            cachedValueStale = true;

            int offset = 0;
            for (Template template : ACRF.this.templates) {
                double[] defaults = template.getDefaultWeights().getValues();
                System.arraycopy(dArr, offset, defaults, 0, defaults.length);
                offset += defaults.length;
            }
            for (Template template : ACRF.this.templates) {
                SparseVector[] weights = template.getWeights();
                for (int a = 0; a < weights.length; a++) {
                    double[] values = weights[a].getValues();
                    System.arraycopy(dArr, offset, values, 0, values.length);
                    offset += values.length;
                }
            }
        }

        /** Returns the live per-assignment expectation vectors of template {@code i} (not a copy). */
        public SparseVector[] getExpectations(int i) {
            final SparseVector[] row = expectations[i];
            return row;
        }

        /** Returns the live per-assignment constraint vectors of template {@code i} (not a copy). */
        public SparseVector[] getConstraints(int i) {
            final SparseVector[] row = constraints[i];
            return row;
        }

        /** Prints all parameters to standard output on one line, each followed by a tab. */
        private void printParameters() {
            final double[] params = new double[numParameters];
            getParameters(params);
            for (int p = 0; p < params.length; p++) {
                System.out.print(params[p] + "\t");
            }
            System.out.println();
        }

        /**
         * Returns the parameter at position {@code i} of the flat layout used by
         * {@link #getParameters(double[])}.
         *
         * <p>Previously this was a stub that returned 0.0 for every index. It now
         * reads through {@code getParameters}, which costs one full parameter
         * copy per call; use the bulk accessor in hot loops.
         *
         * @throws IndexOutOfBoundsException if {@code i} is outside
         *         {@code [0, getNumParameters())}
         */
        @Override // cc.mallet.optimize.Optimizable
        public double getParameter(int i) {
            if (i < 0 || i >= this.numParameters) {
                throw new IndexOutOfBoundsException("Parameter index " + i + " is outside [0, " + this.numParameters + ")");
            }
            double[] all = new double[this.numParameters];
            getParameters(all);
            return all[i];
        }

        /**
         * Sets the parameter at position {@code i} of the flat layout used by
         * {@link #getParameters(double[])} to {@code d}.
         *
         * <p>Previously this was a stub that silently discarded the write. It now
         * round-trips through {@code getParameters}/{@code setParameters}, so the
         * cached value and gradient are marked stale just as for a bulk update.
         *
         * @throws IndexOutOfBoundsException if {@code i} is outside
         *         {@code [0, getNumParameters())}
         */
        @Override // cc.mallet.optimize.Optimizable
        public void setParameter(int i, double d) {
            if (i < 0 || i >= this.numParameters) {
                throw new IndexOutOfBoundsException("Parameter index " + i + " is outside [0, " + this.numParameters + ")");
            }
            double[] all = new double[this.numParameters];
            getParameters(all);
            all[i] = d;
            setParameters(all);
        }

        /**
         * Returns the (cached) log-likelihood, recomputing it when stale.
         * Recomputing the value also marks the gradient stale. A NaN value is
         * logged, replaced by 0.0 in the cache, and 0.0 is returned.
         */
        @Override
        public double getValue() {
            if (cachedValueStale) {
                cachedValue = computeLogLikelihood();
                cachedValueStale = false;
                cachedGradientStale = true;
                ACRF.logger.info("getValue() (loglikelihood) = " + cachedValue);
            }
            if (!Double.isNaN(cachedValue)) {
                return cachedValue;
            }
            ACRF.logger.warning("value is NaN");
            cachedValue = 0.0d;
            return cachedValue;
        }

        /**
         * Computes the penalized log-likelihood of the training data under the
         * current weights and, as a side effect, refills the expectation
         * accumulators used by the gradient.
         *
         * <p>For each training instance the graph is unrolled, marginals are
         * computed, expectations are collected, and the log joint of the observed
         * assignment is added to the total. Instances with no variables are
         * skipped. Instances whose log joint is infinite on the first call are
         * remembered in {@code infiniteValues} and skipped; on later calls an
         * infinite value on any other instance aborts with negative infinity.
         * A NaN log joint also aborts with negative infinity after printing
         * debug information.
         *
         * <p>NOTE(review): the Gaussian prior term below sums only over
         * {@code getWeights()}, whereas {@code computeGradient} also subtracts a
         * prior term for {@code getDefaultWeights()} — confirm whether the value
         * and gradient are meant to disagree on the default weights.
         *
         * @return the log-likelihood (divided by the number of instances when
         *         size scaling is on) plus the Gaussian prior penalty, or
         *         {@code Double.NEGATIVE_INFINITY} on the abort paths above
         */
        protected double computeLogLikelihood() {
            double d = 0.0d;
            int size = this.trainData.size();
            long currentTimeMillis = System.currentTimeMillis();
            // j accumulates unroll time, j2 accumulates marginal-computation time (both in ms).
            long j = 0;
            long j2 = 0;
            // z is true only on the first call, when the set of infinite-valued instances is being built.
            boolean z = false;
            if (this.infiniteValues == null) {
                this.infiniteValues = new BitSet();
                z = true;
            }
            resetExpectations();
            for (int i = 0; i < size; i++) {
                Instance instance = this.trainData.get(i);
                long currentTimeMillis2 = System.currentTimeMillis();
                UnrolledGraph unroll = ACRF.this.unroll(instance);
                long currentTimeMillis3 = System.currentTimeMillis();
                j += currentTimeMillis3 - currentTimeMillis2;
                // Graphs without variables contribute nothing.
                if (unroll.numVariables() != 0) {
                    this.inferencer.computeMarginals(unroll);
                    j2 += System.currentTimeMillis() - currentTimeMillis3;
                    // Expectations are collected before the value is validated, so they are
                    // updated even for instances that end up skipped below.
                    collectExpectations(unroll, this.inferencer);
                    double lookupLogJoint = this.inferencer.lookupLogJoint(unroll.getAssignment());
                    if (!Double.isInfinite(lookupLogJoint)) {
                        if (Double.isNaN(lookupLogJoint)) {
                            System.out.println("NaN on instance " + i + JSWriter.ObjectPairSep + instance.getName());
                            printDebugInfo(unroll);
                            ACRF.logger.warning("Value is NaN in ACRF.getValue() Instance " + i + JSWriter.ObjectPairSep + "returning -infinity... ");
                            return Double.NEGATIVE_INFINITY;
                        }
                        d += lookupLogJoint;
                    } else if (z) {
                        // First pass: remember the instance and leave it out of the total.
                        ACRF.logger.warning("Instance " + instance.getName() + " has infinite value; skipping.");
                        this.infiniteValues.set(i);
                    } else if (!this.infiniteValues.get(i)) {
                        // Later pass: an instance that was finite at first has become infinite.
                        // NOTE(review): this message has no space before "returning".
                        ACRF.logger.warning("Infinite value on instance " + instance.getName() + "returning -infinity");
                        return Double.NEGATIVE_INFINITY;
                    }
                }
            }
            if (ACRF.this.doSizeScale) {
                d /= this.trainData.size();
            }
            // Gaussian prior: subtract w^2 / (2 * variance) for every valid per-assignment weight.
            double d2 = 2.0d * ACRF.this.gaussianPriorVariance;
            for (int i2 = 0; i2 < ACRF.this.templates.length; i2++) {
                SparseVector[] weights = ACRF.this.templates[i2].getWeights();
                for (int i3 = 0; i3 < weights.length; i3++) {
                    for (int i4 = 0; i4 < weights[i3].numLocations(); i4++) {
                        double valueAtLocation = weights[i3].valueAtLocation(i4);
                        // Infinite and NaN weights are logged by weightValid and left out of the penalty.
                        if (weightValid(valueAtLocation, i2, i3)) {
                            d += ((-valueAtLocation) * valueAtLocation) / d2;
                        }
                    }
                }
            }
            if (ACRF.this.cacheUnrolledGraphs) {
                ACRF.this.reportOnGraphCache();
            }
            ACRF.logger.info("ACRF Inference time (ms) = " + (System.currentTimeMillis() - currentTimeMillis));
            ACRF.logger.info("ACRF marginals time (ms) = " + j2);
            ACRF.logger.info("ACRF unroll time (ms) = " + j);
            ACRF.logger.info("getValue (loglikelihood) = " + d);
            return d;
        }

        /**
         * Copies the gradient into {@code dArr}, recomputing it when stale. A
         * stale value is refreshed first. The buffer length is checked only after
         * any recomputation has happened.
         *
         * @throws IllegalArgumentException if {@code dArr.length} differs from
         *         {@link #getNumParameters()}
         */
        @Override
        public void getValueGradient(double[] dArr) {
            final boolean recompute = cachedGradientStale;
            if (recompute) {
                if (cachedValueStale) {
                    getValue();
                }
                computeGradient();
                cachedGradientStale = false;
            }
            if (dArr.length == numParameters) {
                System.arraycopy(cachedGradient, 0, dArr, 0, cachedGradient.length);
                return;
            }
            throw new IllegalArgumentException("Incorrect length buffer to getValueGradient(). Expected " + numParameters + ", received " + dArr.length);
        }

        /**
         * Fills {@code cachedGradient} from the current constraints, expectations
         * and weights, using the same layout as {@code getParameters}: default
         * weights of every template first, then the per-assignment weights.
         *
         * <p>Each entry is {@code scale * (constraint - expectation) - weight / variance},
         * where {@code scale} is {@code 1 / trainData.size()} when size scaling is
         * enabled and 1 otherwise. Entries whose weight is infinite are set to 0
         * and logged.
         */
        private void computeGradient() {
            // Loop-invariant, so computed once instead of once per entry.
            final double scale = ACRF.this.doSizeScale ? 1.0d / this.trainData.size() : 1.0d;
            final double variance = ACRF.this.gaussianPriorVariance;
            int next = 0;

            for (int t = 0; t < ACRF.this.templates.length; t++) {
                SparseVector defaultWeights = ACRF.this.templates[t].getDefaultWeights();
                SparseVector constraint = this.defaultConstraints[t];
                SparseVector expectation = this.defaultExpectations[t];
                for (int loc = 0; loc < defaultWeights.numLocations(); loc++) {
                    this.cachedGradient[next++] = (scale * (constraint.valueAtLocation(loc) - expectation.valueAtLocation(loc)))
                            - (defaultWeights.valueAtLocation(loc) / variance);
                }
            }

            for (int t = 0; t < ACRF.this.templates.length; t++) {
                SparseVector[] weights = ACRF.this.templates[t].getWeights();
                for (int a = 0; a < weights.length; a++) {
                    SparseVector weightVec = weights[a];
                    SparseVector constraint = this.constraints[t][a];
                    SparseVector expectation = this.expectations[t][a];
                    for (int loc = 0; loc < weightVec.numLocations(); loc++) {
                        double w = weightVec.valueAtLocation(loc);
                        double grad;
                        if (Double.isInfinite(w)) {
                            // loc is a position inside the sparse vector; the alphabet is keyed by
                            // feature index, so translate before looking up the feature name.
                            ACRF.logger.warning("Infinite weight for node index " + a + " feature "
                                    + ACRF.this.inputAlphabet.lookupObject(weightVec.indexAtLocation(loc)));
                            grad = 0.0d;
                        } else {
                            grad = (scale * (constraint.valueAtLocation(loc) - expectation.valueAtLocation(loc))) - (w / variance);
                        }
                        this.cachedGradient[next++] = grad;
                    }
                }
            }
        }

        /**
         * When a verbose output directory is configured, writes the current
         * gradient, value, weights, constraints, expectations and a dump of every
         * unrolled training graph to numbered files in that directory. Does
         * nothing when no directory is configured.
         *
         * <p>Every writer is closed in a {@code finally} block so a failure midway
         * does not leak file handles.
         *
         * @throws RuntimeException wrapping any {@link IOException} raised while writing
         */
        private void reportGradient() {
            final File dir = ACRF.this.verboseOutputDirectory;
            if (dir == null) {
                return;
            }
            this.gradCallNo++;
            final String suffix = "-" + this.gradCallNo + ".txt";
            try {
                writeReportLine(new File(dir, "acrf-grad" + suffix), ArrayUtils.toString(this.cachedGradient));
                writeReportLine(new File(dir, "acrf-value" + suffix), String.valueOf(this.cachedValue));

                double[] params = new double[getNumParameters()];
                getParameters(params);
                writeReportLine(new File(dir, "acrf-weight" + suffix), ArrayUtils.toString(params));

                printVecs(new File(dir, "acrf-constraint" + suffix), this.defaultConstraints, this.constraints);
                printVecs(new File(dir, "acrf-exp" + suffix), this.defaultExpectations, this.expectations);

                PrintWriter dumps = new PrintWriter(new FileWriter(new File(dir, "acrf-dumps" + suffix)));
                try {
                    for (int i = 0; i < this.trainData.size(); i++) {
                        dumps.println(ACRF.this.unroll(this.trainData.get(i)));
                    }
                } finally {
                    dumps.close();
                }
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }

        /** Writes {@code content} followed by a line break to {@code file}, always closing the writer. */
        private void writeReportLine(File file, Object content) throws IOException {
            PrintWriter out = new PrintWriter(new FileWriter(file));
            try {
                out.println(content);
            } finally {
                out.close();
            }
        }

        /**
         * Writes every vector of {@code sparseVectorArr}, then every vector of
         * {@code sparseVectorArr2} in row order, one per line, to {@code file}.
         * The writer is closed even if printing a vector throws.
         *
         * @throws IOException if {@code file} cannot be opened for writing
         */
        private void printVecs(File file, SparseVector[] sparseVectorArr, SparseVector[][] sparseVectorArr2) throws IOException {
            PrintWriter out = new PrintWriter(new FileWriter(file));
            try {
                for (SparseVector vec : sparseVectorArr) {
                    out.println(vec);
                }
                for (SparseVector[] row : sparseVectorArr2) {
                    for (SparseVector vec : row) {
                        out.println(vec);
                    }
                }
            } finally {
                out.close();
            }
        }

        /**
         * Adds the model's expected feature counts for one unrolled graph to the
         * expectation accumulators. For every clique whose template index is not
         * -1, each assignment's marginal value weights the clique's feature
         * vector, and also increments the default expectation when that
         * assignment has a slot there.
         */
        private void collectExpectations(UnrolledGraph unrolledGraph, Inferencer inferencer) {
            for (Iterator cliques = unrolledGraph.unrolledVarSetIterator(); cliques.hasNext(); ) {
                UnrolledVarSet clique = (UnrolledVarSet) cliques.next();
                final int tidx = clique.tmpl.index;
                if (tidx == -1) {
                    continue;
                }
                Factor marginal = inferencer.lookupMarginal(clique);
                for (AssignmentIterator assn = marginal.assignmentIterator(); assn.hasNext(); assn.advance()) {
                    double prob = marginal.value(assn);
                    int assnIdx = assn.indexOfCurrentAssn();
                    this.expectations[tidx][assnIdx].plusEqualsSparse(clique.fv, prob);
                    boolean hasDefaultSlot = this.defaultExpectations[tidx].location(assnIdx) != -1;
                    if (hasDefaultSlot) {
                        this.defaultExpectations[tidx].incrementValue(assnIdx, prob);
                    }
                }
            }
        }

        /**
         * Accumulates the empirical feature counts (constraints) over
         * {@code instanceList} and adds each unrolled graph's variable count to
         * {@code totalNodes}.
         *
         * <p>For every clique whose template index is not -1, the clique's feature
         * vector is added to the constraint vector of its observed assignment,
         * and the default constraint for that assignment is incremented when it
         * has a slot.
         *
         * <p>Previously {@code totalNodes} was overwritten on each iteration, so
         * {@link #getTotalNodes()} reported only the size of the last graph
         * instead of the total across instances.
         */
        public void collectConstraints(InstanceList instanceList) {
            for (int i = 0; i < instanceList.size(); i++) {
                ACRF.logger.finest("*** Collecting constraints for instance " + i);
                UnrolledGraph unrolledGraph = new UnrolledGraph(instanceList.get(i), ACRF.this.templates, null, false);
                // Accumulate rather than overwrite: the field is a running total.
                this.totalNodes += unrolledGraph.numVariables();
                Iterator unrolledVarSetIterator = unrolledGraph.unrolledVarSetIterator();
                while (unrolledVarSetIterator.hasNext()) {
                    UnrolledVarSet unrolledVarSet = (UnrolledVarSet) unrolledVarSetIterator.next();
                    int tidx = unrolledVarSet.tmpl.index;
                    if (tidx != -1) {
                        int assnIdx = unrolledVarSet.lookupAssignmentNumber();
                        this.constraints[tidx][assnIdx].plusEqualsSparse(unrolledVarSet.fv);
                        if (this.defaultConstraints[tidx].location(assnIdx) != -1) {
                            this.defaultConstraints[tidx].incrementValue(assnIdx, 1.0d);
                        }
                    }
                }
            }
        }

        /**
         * Writes the cached gradient to the file named {@code str}, one value per
         * line. I/O failures are reported on standard error and otherwise
         * ignored (best-effort debugging aid). The stream is closed even if
         * writing fails partway.
         */
        void dumpGradientToFile(String str) {
            try {
                PrintStream out = new PrintStream(new FileOutputStream(str));
                try {
                    for (int i = 0; i < this.numParameters; i++) {
                        out.println(this.cachedGradient[i]);
                    }
                } finally {
                    out.close();
                }
            } catch (IOException e) {
                System.err.println("Could not open output file.");
                e.printStackTrace();
            }
        }

        /** Prints the default constraints and then the default expectations of every template to standard output. */
        void dumpDefaults() {
            System.out.println("Default constraints");
            int t = 0;
            for (SparseVector constraint : this.defaultConstraints) {
                System.out.println("Template " + t);
                constraint.print();
                t++;
            }

            System.out.println("Default expectations");
            t = 0;
            for (SparseVector expectation : this.defaultExpectations) {
                System.out.println("Template " + t);
                expectation.print();
                t++;
            }
        }

        /**
         * Debugging aid: prints the enclosing ACRF to standard error, then for
         * every clique of {@code unrolledGraph} prints the clique, its observed
         * assignment, the factor's value at that assignment, and the factor
         * itself to standard output.
         */
        void printDebugInfo(UnrolledGraph unrolledGraph) {
            ACRF.this.print(System.err);
            final Assignment observed = unrolledGraph.getAssignment();
            for (Iterator cliques = unrolledGraph.unrolledVarSetIterator(); cliques.hasNext(); ) {
                UnrolledVarSet clique = (UnrolledVarSet) cliques.next();
                System.out.println("Clique " + clique);
                dumpAssnForClique(observed, clique);
                Factor factor = unrolledGraph.factorOf((VarSet) clique);
                double valueAtObserved = factor.value(observed);
                System.out.println("Value = " + valueAtObserved);
                System.out.println(factor);
            }
        }

        /** Prints, for each variable in the clique, the object and the raw value that {@code assignment} gives it. */
        void dumpAssnForClique(Assignment assignment, UnrolledVarSet unrolledVarSet) {
            for (Iterator vars = unrolledVarSet.iterator(); vars.hasNext(); ) {
                Variable var = (Variable) vars.next();
                String line = var + " ==> " + assignment.getObject(var) + "  (" + assignment.get(var) + ")";
                System.out.println(line);
            }
        }

        /**
         * Reports whether a weight can take part in the prior computation.
         *
         * @param d the weight value
         * @param i template (clique) index, used only in the log message
         * @param i2 assignment index, used only in the log message
         * @return {@code false} (after logging a warning) when {@code d} is
         *         infinite or NaN, {@code true} otherwise
         */
        private boolean weightValid(double d, int i, int i2) {
            // The separator before "assignment" was missing, which ran the clique
            // number into the following word in the log output.
            if (Double.isInfinite(d)) {
                ACRF.logger.warning("Weight is infinite for clique " + i + " assignment " + i2);
                return false;
            }
            if (Double.isNaN(d)) {
                ACRF.logger.warning("Weight is Nan for clique " + i + " assignment " + i2);
                return false;
            }
            return true;
        }

        /**
         * Logs the total number of messages sent when the inferencer is a
         * belief-propagation or junction-tree inferencer; does nothing for other
         * inferencer types or when the reported count is -1.
         */
        public void report() {
            final int sent;
            if (this.inferencer instanceof AbstractBeliefPropagation) {
                sent = AbstractBeliefPropagation.getTotalMessagesSent();
            } else if (this.inferencer instanceof JunctionTreeInferencer) {
                JunctionTreeInferencer jti = (JunctionTreeInferencer) this.inferencer;
                sent = jti.getTotalMessagesSent();
            } else {
                sent = -1;
            }
            if (sent == -1) {
                return;
            }
            ACRF.logger.info("Total messages sent = " + sent);
        }

        /** Invalidates both the cached value and the cached gradient so the next request recomputes them. */
        public void forceStale() {
            cachedValueStale = true;
            cachedGradientStale = true;
        }

        /** Returns the node count recorded by {@code collectConstraints}. */
        public int getTotalNodes() {
            return totalNodes;
        }
    }

    /* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/grmm/learning/ACRF$PairwiseFactorTemplate.class */
    /**
     * Sequence template that links two label factors at the same time step.
     *
     * <p>For every time {@code t} in {@code [0, maxTime)} a clique over the
     * variables at {@code (t, factor0)} and {@code (t, factor1)} is added to the
     * unrolled graph, carrying the feature vector observed at time {@code t}.
     */
    public static class PairwiseFactorTemplate extends SequenceTemplate {
        private static final long serialVersionUID = 1;

        /** Index of the first label factor in each clique. */
        int factor0;
        /** Index of the second label factor in each clique. */
        int factor1;

        public PairwiseFactorTemplate(int i, int i2) {
            factor0 = i;
            factor1 = i2;
        }

        @Override
        public void addInstantiatedCliques(UnrolledGraph unrolledGraph, FeatureVectorSequence featureVectorSequence, LabelsAssignment labelsAssignment) {
            int t = 0;
            while (t < labelsAssignment.maxTime()) {
                Variable first = labelsAssignment.varOfIndex(t, factor0);
                Variable second = labelsAssignment.varOfIndex(t, factor1);
                FeatureVector features = featureVectorSequence.getFeatureVector(t);
                Variable[] pair = new Variable[] {first, second};

                // Only checked when assertions are enabled for the enclosing class.
                assert first != null : "Couldn't get label factor " + factor0 + " time " + t;
                assert second != null : "Couldn't get label factor " + factor1 + " time " + t;

                UnrolledVarSet clique = new UnrolledVarSet(unrolledGraph, this, pair, features);
                unrolledGraph.addClique(clique);
                t++;
            }
        }

        @Override
        public String toString() {
            // NOTE(review): the separator is taken from a Jena JSON-writer constant, which
            // looks like a decompiler substitution for a plain string literal — confirm.
            final String separator = JSWriter.ArraySep;
            return "[PairwiseFactorTemplate (" + factor0 + separator + factor1 + ")]";
        }
    }

    /* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/grmm/learning/ACRF$SequenceTemplate.class */
    /**
     * Template for sequence data: adapts the generic instance-based entry point
     * to one that receives the instance's data as a {@link FeatureVectorSequence}
     * and its target as a {@link LabelsAssignment}.
     */
    public static abstract class SequenceTemplate extends Template {

        /** Casts the instance's data and target to their sequence types and delegates to the typed overload. */
        @Override
        public void addInstantiatedCliques(UnrolledGraph unrolledGraph, Instance instance) {
            final FeatureVectorSequence features = (FeatureVectorSequence) instance.getData();
            final LabelsAssignment labels = (LabelsAssignment) instance.getTarget();
            addInstantiatedCliques(unrolledGraph, features, labels);
        }

        /** Adds this template's cliques for one sequence to {@code unrolledGraph}. */
        protected abstract void addInstantiatedCliques(UnrolledGraph unrolledGraph, FeatureVectorSequence featureVectorSequence, LabelsAssignment labelsAssignment);
    }

    /* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/grmm/learning/ACRF$Template.class */
    public static abstract class Template implements Serializable {
        private static final double SOME_UNSUPPORTED_THRESHOLD = 0.1d;
        // Weight vectors for this template; MaximizableACRF indexes the parallel constraint/expectation arrays by assignment number.
        protected SparseVector[] weights;
        private BitSet assignmentsPresent;
        // Position of this template in the model; MaximizableACRF skips cliques whose template index is -1.
        public int index;
        private SparseVector defaultWeights;
        private static final long serialVersionUID = -727618747254644076L;
        // Decompiler rendering of the compiler-generated assertion flag.
        static final /* synthetic */ boolean $assertionsDisabled;
        private boolean unsupportedWeightsAdded = false;
        // Selects sparse (true) or dense (false) weight initialization in initWeights.
        private boolean supportedOnly = true;
        private boolean trainable = true;

        public abstract void addInstantiatedCliques(UnrolledGraph unrolledGraph, Instance instance);

        /**
         * Hook with an empty default implementation; subclasses may override it.
         *
         * NOTE(review): presumably called with a clique's freshly computed factor
         * so a subclass can adjust it — the call site is not visible here.
         */
        protected void modifyPotential(UnrolledGraph unrolledGraph, UnrolledVarSet unrolledVarSet, AbstractTableFactor abstractTableFactor) {
            // Intentionally empty.
        }

        protected boolean isSupportedOnly() {
            return this.supportedOnly;
        }

        void setSupportedOnly(boolean z) {
            this.supportedOnly = z;
        }

        public boolean isUnsupportedWeightsAdded() {
            return this.unsupportedWeightsAdded;
        }

        protected BitSet getAssignmentsPresent() {
            return this.assignmentsPresent;
        }

        public SparseVector[] getWeights() {
            return this.weights;
        }

        public void setWeights(SparseVector[] sparseVectorArr) {
            if (this.weights != null && sparseVectorArr.length != this.weights.length) {
                throw new IllegalArgumentException("Weights length changed; was " + this.weights.length + " now is " + sparseVectorArr.length);
            }
            this.weights = sparseVectorArr;
        }

        public int initWeights(InstanceList instanceList) {
            ACRF.logger.info("Template " + this + " : weights " + (this.supportedOnly ? "with NO" : "with ALL") + " unsupported features...");
            return this.supportedOnly ? initSparseWeights(instanceList) : initDenseWeights(instanceList);
        }

        private int initDenseWeights(InstanceList instanceList) {
            int size = instanceList.getDataAlphabet().size();
            int cliqueSizeFromInstance = cliqueSizeFromInstance(instanceList);
            int allocateDefaultWeights = 0 + allocateDefaultWeights(cliqueSizeFromInstance);
            SparseVector[] sparseVectorArr = new SparseVector[cliqueSizeFromInstance];
            for (int i = 0; i < cliqueSizeFromInstance; i++) {
                sparseVectorArr[i] = new SparseVector(new double[size], false);
                if (this.weights != null) {
                    sparseVectorArr[i].plusEqualsSparse(this.weights[i]);
                }
                allocateDefaultWeights += size;
                ACRF.logger.info("ACRF template " + this + " weights [" + i + "] num features " + size);
            }
            ACRF.logger.info("ACRF template " + this + " total num weights = " + allocateDefaultWeights);
            this.weights = sparseVectorArr;
            return allocateDefaultWeights;
        }

        private int initSparseWeights(InstanceList instanceList) {
            int cliqueSizeFromInstance = cliqueSizeFromInstance(instanceList);
            BitSet[] bitSetArr = new BitSet[cliqueSizeFromInstance];
            for (int i = 0; i < cliqueSizeFromInstance; i++) {
                bitSetArr[i] = new BitSet();
            }
            this.assignmentsPresent = new BitSet(cliqueSizeFromInstance);
            collectWeightsPresent(instanceList, bitSetArr);
            if (this.weights != null) {
                addInCurrentWeights(bitSetArr);
            }
            int allocateDefaultWeights = 0 + allocateDefaultWeights(cliqueSizeFromInstance);
            SparseVector[] sparseVectorArr = new SparseVector[cliqueSizeFromInstance];
            int allocateNewWeights = allocateDefaultWeights + allocateNewWeights(bitSetArr, sparseVectorArr);
            ACRF.logger.info("ACRF template " + this + " total num weights = " + allocateNewWeights);
            this.weights = sparseVectorArr;
            return allocateNewWeights;
        }

        private int allocateNewWeights(BitSet[] bitSetArr, SparseVector[] sparseVectorArr) {
            int i = 0;
            for (int i2 = 0; i2 < bitSetArr.length; i2++) {
                int cardinality = bitSetArr[i2].cardinality();
                int[] iArr = new int[cardinality];
                int i3 = 0;
                while (i3 < cardinality) {
                    iArr[i3] = bitSetArr[i2].nextSetBit(i3 == 0 ? 0 : iArr[i3 - 1] + 1);
                    i3++;
                }
                sparseVectorArr[i2] = new HashedSparseVector(iArr, new double[cardinality], cardinality, cardinality, false, false, false);
                if (this.weights != null) {
                    sparseVectorArr[i2].plusEqualsSparse(this.weights[i2]);
                }
                i += cardinality;
                if (cardinality != 0) {
                    ACRF.logger.info("ACRF template " + this + " weights [" + i2 + "] num features " + cardinality);
                }
            }
            return i;
        }

        public int addSomeUnsupportedWeights(InstanceList instanceList) {
            this.unsupportedWeightsAdded = true;
            int length = this.weights.length;
            BitSet[] bitSetArr = new BitSet[length];
            for (int i = 0; i < length; i++) {
                bitSetArr[i] = new BitSet();
            }
            collectSomeUnsupportedWeights(instanceList, bitSetArr);
            addInCurrentWeights(bitSetArr);
            SparseVector[] sparseVectorArr = new SparseVector[length];
            int allocateNewWeights = allocateNewWeights(bitSetArr, sparseVectorArr);
            ACRF.logger.info(this + " some supported weights added = " + allocateNewWeights);
            this.weights = sparseVectorArr;
            return allocateNewWeights;
        }

        private void collectSomeUnsupportedWeights(InstanceList instanceList, BitSet[] bitSetArr) {
            for (int i = 0; i < instanceList.size(); i++) {
                Iterator unrolledVarSetIterator = new UnrolledGraph(instanceList.get(i), new Template[]{this}, new ArrayList(), true).unrolledVarSetIterator();
                while (unrolledVarSetIterator.hasNext()) {
                    UnrolledVarSet unrolledVarSet = (UnrolledVarSet) unrolledVarSetIterator.next();
                    Factor normalize = unrolledVarSet.getFactor().normalize();
                    AssignmentIterator assignmentIterator = normalize.assignmentIterator();
                    while (assignmentIterator.hasNext()) {
                        if (normalize.value(assignmentIterator) > 0.1d) {
                            addPresentFeatures(bitSetArr[assignmentIterator.indexOfCurrentAssn()], unrolledVarSet.fv);
                        }
                        assignmentIterator.advance();
                    }
                }
            }
        }

        private int allocateDefaultWeights(int i) {
            SparseVector sparseVector = new SparseVector(new double[i], false);
            if (this.defaultWeights != null) {
                sparseVector.plusEqualsSparse(this.defaultWeights);
            }
            this.defaultWeights = sparseVector;
            return i;
        }

        private int cliqueSizeFromInstance(InstanceList instanceList) {
            int weight;
            int i = 0;
            for (int i2 = 0; i2 < instanceList.size(); i2++) {
                Iterator unrolledVarSetIterator = new UnrolledGraph(instanceList.get(i2), new Template[]{this}, null, false).unrolledVarSetIterator();
                while (unrolledVarSetIterator.hasNext()) {
                    UnrolledVarSet unrolledVarSet = (UnrolledVarSet) unrolledVarSetIterator.next();
                    if (unrolledVarSet.tmpl == this && (weight = unrolledVarSet.weight()) > i) {
                        i = weight;
                    }
                }
            }
            if (i == 0) {
                ACRF.logger.warning("***ACRF: Don't know size of " + this + ". Never needed in training data.");
            }
            return i;
        }

        private void checkCliqueSizeConsistent(InstanceList instanceList) {
            int i = -1;
            for (int i2 = 0; i2 < instanceList.size(); i2++) {
                Instance instance = instanceList.get(i2);
                Iterator unrolledVarSetIterator = new UnrolledGraph(instance, new Template[]{this}, null, false).unrolledVarSetIterator();
                while (unrolledVarSetIterator.hasNext()) {
                    UnrolledVarSet unrolledVarSet = (UnrolledVarSet) unrolledVarSetIterator.next();
                    if (unrolledVarSet.tmpl == this && i != unrolledVarSet.weight()) {
                        System.err.println("Weight change for clique " + unrolledVarSet + " template " + this + " old = " + i + " new " + unrolledVarSet.weight());
                        for (int i3 = 0; i3 < unrolledVarSet.size(); i3++) {
                            Variable variable = unrolledVarSet.get(i3);
                            System.err.println(variable + "\t" + variable.getNumOutcomes());
                        }
                        if (i != -1) {
                            throw new IllegalStateException("Error on instance " + instance + ": Template " + this + " clique " + unrolledVarSet + " error.  Strange weight: was " + i + " now is " + unrolledVarSet.weight());
                        }
                        i = unrolledVarSet.weight();
                    }
                }
            }
        }

        private void addInCurrentWeights(BitSet[] bitSetArr) {
            for (int i = 0; i < this.weights.length; i++) {
                for (int i2 = 0; i2 < this.weights[i].numLocations(); i2++) {
                    bitSetArr[i].set(this.weights[i].indexAtLocation(i2));
                }
            }
        }

        private void collectWeightsPresent(InstanceList instanceList, BitSet[] bitSetArr) {
            for (int i = 0; i < instanceList.size(); i++) {
                UnrolledGraph unrolledGraph = new UnrolledGraph(instanceList.get(i), new Template[]{this}, null, false);
                collectTransitionsPresentForGraph(unrolledGraph);
                collectWeightsPresentForGraph(unrolledGraph, bitSetArr);
            }
        }

        private void collectTransitionsPresentForGraph(UnrolledGraph unrolledGraph) {
            Iterator unrolledVarSetIterator = unrolledGraph.unrolledVarSetIterator();
            while (unrolledVarSetIterator.hasNext()) {
                UnrolledVarSet unrolledVarSet = (UnrolledVarSet) unrolledVarSetIterator.next();
                if (unrolledVarSet.tmpl == this) {
                    this.assignmentsPresent.set(unrolledVarSet.lookupAssignmentNumber());
                }
            }
        }

        private void collectWeightsPresentForGraph(UnrolledGraph unrolledGraph, BitSet[] bitSetArr) {
            Iterator unrolledVarSetIterator = unrolledGraph.unrolledVarSetIterator();
            while (unrolledVarSetIterator.hasNext()) {
                UnrolledVarSet unrolledVarSet = (UnrolledVarSet) unrolledVarSetIterator.next();
                if (unrolledVarSet.tmpl == this) {
                    addPresentFeatures(bitSetArr[unrolledVarSet.lookupAssignmentNumber()], unrolledVarSet.fv);
                }
            }
        }

        private void addPresentFeatures(BitSet bitSet, FeatureVector featureVector) {
            for (int i = 0; i < featureVector.numLocations(); i++) {
                bitSet.set(featureVector.indexAtLocation(i));
            }
        }

        public AbstractTableFactor computeFactor(UnrolledVarSet unrolledVarSet) {
            Matrix createFactorMatrix = createFactorMatrix(unrolledVarSet);
            SparseVector[] weights = getWeights();
            for (int i = 0; i < createFactorMatrix.numLocations(); i++) {
                int indexAtLocation = createFactorMatrix.indexAtLocation(i);
                if (!$assertionsDisabled && indexAtLocation >= weights.length) {
                    throw new AssertionError("Error: Instantiating " + this + " on " + unrolledVarSet + " : Clique has too many assignments.\n  # of weights = " + weights.length + " clique weight = " + unrolledVarSet.weight());
                }
                createFactorMatrix.setValueAtLocation(i, weights[indexAtLocation].dotProduct((SparseVector) unrolledVarSet.fv) + getDefaultWeight(indexAtLocation));
            }
            LogTableFactor logTableFactor = new LogTableFactor(unrolledVarSet);
            logTableFactor.setValues(createFactorMatrix);
            return logTableFactor;
        }

        protected Matrix createFactorMatrix(UnrolledVarSet unrolledVarSet) {
            return new Matrixn(unrolledVarSet.varDimensions());
        }

        public double getDefaultWeight(int i) {
            return this.defaultWeights.value(i);
        }

        public SparseVector getDefaultWeights() {
            return this.defaultWeights;
        }

        public void setDefaultWeights(SparseVector sparseVector) {
            this.defaultWeights = sparseVector;
        }

        public void setDefaultWeight(int i, double d) {
            this.defaultWeights.setValue(i, d);
        }

        public boolean isTrainable() {
            return this.trainable;
        }

        public void setTrainable(boolean z) {
            this.trainable = z;
        }

        private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
            objectInputStream.defaultReadObject();
            if (this.assignmentsPresent == null) {
                this.assignmentsPresent = new BitSet(this.weights.length);
                this.assignmentsPresent.flip(0, this.assignmentsPresent.size());
            }
        }

        protected Assignment computeAssignment(Assignment assignment, VarSet varSet) {
            return (Assignment) assignment.marginalize(varSet);
        }

        static {
            $assertionsDisabled = !ACRF.class.desiredAssertionStatus();
        }
    }

    /* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/grmm/learning/ACRF$UnigramTemplate.class */
    /**
     * Sequence template that adds one single-variable clique per time step, over the label
     * variable of one fixed factor index.
     */
    public static class UnigramTemplate extends SequenceTemplate {
        private static final long serialVersionUID = 1;

        /** Factor index passed to {@code LabelsAssignment.varOfIndex} for every time step. */
        int factor;

        /**
         * @param i factor index whose label variable forms each clique
         */
        public UnigramTemplate(int i) {
            this.factor = i;
        }

        /**
         * For each time step below {@code labelsAssignment.maxTime()}, adds a clique holding that
         * step's label variable for {@link #factor}, paired with the step's feature vector.
         */
        @Override // cc.mallet.grmm.learning.ACRF.SequenceTemplate
        public void addInstantiatedCliques(UnrolledGraph unrolledGraph, FeatureVectorSequence featureVectorSequence, LabelsAssignment labelsAssignment) {
            for (int time = 0; time < labelsAssignment.maxTime(); time++) {
                Variable label = labelsAssignment.varOfIndex(time, this.factor);
                FeatureVector features = featureVectorSequence.getFeatureVector(time);
                assert label != null : "Couldn't get label factor " + this.factor + " time " + time;
                unrolledGraph.addClique(new UnrolledVarSet(unrolledGraph, this, new Variable[]{label}, features));
            }
        }

        @Override
        public String toString() {
            return "[UnigramTemplate (" + this.factor + ")]";
        }
    }

    /* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/grmm/learning/ACRF$UnrolledGraph.class */
    public static class UnrolledGraph extends UndirectedModel {
        List allVars;
        List cliques;
        int numSlices;
        boolean isCached;
        Instance instance;
        FeatureVectorSequence fvs;
        private Assignment assignment;
        LabelAlphabet[] outputAlphabets;
        ACRF acrf;
        List allTemplates;
        private boolean isFactorsAdded;
        private THashMap uvsMap;
        private double[] lastResids;
        TObjectIntHashMap observedVars;

        public UnrolledGraph(Instance instance, Template[] templateArr, Template[] templateArr2) {
            this(instance, templateArr, Arrays.asList(templateArr2));
        }

        UnrolledGraph(Instance instance, Template[] templateArr, List list) {
            this(instance, templateArr, list, true);
        }

        public UnrolledGraph(Instance instance, Template[] templateArr, List list, boolean z) {
            super(initialCapacity(instance));
            this.allVars = new ArrayList();
            this.cliques = new ArrayList();
            this.isCached = false;
            this.isFactorsAdded = false;
            this.uvsMap = new THashMap();
            this.observedVars = new TObjectIntHashMap();
            this.instance = instance;
            this.fvs = (FeatureVectorSequence) instance.getData();
            this.assignment = (Assignment) instance.getTarget();
            this.allTemplates = new ArrayList();
            if (list != null) {
                this.allTemplates.addAll(list);
            }
            this.allTemplates.addAll(Arrays.asList(templateArr));
            setupGraph();
            if (z) {
                computeCPFs();
            }
        }

        private static int initialCapacity(Instance instance) {
            if (instance.getData() == null) {
                return 8;
            }
            return 3 * ((FeatureVectorSequence) instance.getData()).size();
        }

        private void setupGraph() {
            Iterator it = this.allTemplates.iterator();
            while (it.hasNext()) {
                ((Template) it.next()).addInstantiatedCliques(this, this.instance);
            }
        }

        public void addClique(UnrolledVarSet unrolledVarSet) {
            this.cliques.add(unrolledVarSet);
        }

        private void computeCPFs() {
            this.isFactorsAdded = true;
            TDoubleArrayList tDoubleArrayList = new TDoubleArrayList();
            for (UnrolledVarSet unrolledVarSet : this.cliques) {
                AbstractTableFactor computeFactor = unrolledVarSet.tmpl.computeFactor(unrolledVarSet);
                addFactorInternal(unrolledVarSet, computeFactor);
                unrolledVarSet.tmpl.modifyPotential(this, unrolledVarSet, computeFactor);
                this.uvsMap.put(computeFactor, unrolledVarSet);
                tDoubleArrayList.add(Factors.distLinf(new LogTableFactor(unrolledVarSet), computeFactor));
            }
            this.lastResids = tDoubleArrayList.toNativeArray();
        }

        private void addFactorInternal(UnrolledVarSet unrolledVarSet, Factor factor) {
            unrolledVarSet.setFactor(factor);
            Factor factorOf = factorOf(factor.varSet());
            if (factorOf == null) {
                addFactor(factor);
            } else if (factorOf instanceof FactorGraph) {
                factorOf.multiplyBy(factor);
            } else {
                divideBy(factorOf);
                addFactor(new FactorGraph(new Factor[]{factor, factorOf}));
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        public void recomputeFactors() {
            this.lastResids = new double[factors().size()];
            for (UnrolledVarSet unrolledVarSet : this.cliques) {
                AbstractTableFactor abstractTableFactor = (AbstractTableFactor) unrolledVarSet.getFactor();
                AbstractTableFactor computeFactor = unrolledVarSet.tmpl.computeFactor(unrolledVarSet);
                this.lastResids[getIndex(abstractTableFactor)] = Factors.distLinf((AbstractTableFactor) abstractTableFactor.duplicate().normalize(), (AbstractTableFactor) computeFactor.duplicate().normalize());
                abstractTableFactor.setValues(computeFactor.getLogValueMatrix());
                unrolledVarSet.tmpl.modifyPotential(this, unrolledVarSet, abstractTableFactor);
            }
        }

        public double[] getLastResids() {
            return this.lastResids;
        }

        int getMaxTime() {
            return this.fvs.size();
        }

        int getNumFactors() {
            return this.outputAlphabets.length;
        }

        public Assignment getAssignment() {
            return this.assignment;
        }

        private boolean isObserved(Variable variable) {
            return this.observedVars.contains(variable);
        }

        public void setObserved(Variable variable, int i) {
            this.observedVars.put(variable, i);
        }

        public int observedValue(Variable variable) {
            return this.observedVars.get(variable);
        }

        public Iterator unrolledVarSetIterator() {
            return this.cliques.iterator();
        }

        public UnrolledVarSet getUnrolledVarSet(int i) {
            return (UnrolledVarSet) this.cliques.get(i);
        }

        public int getIndex(VarSet varSet) {
            return this.cliques.indexOf(varSet);
        }

        @Override // cc.mallet.grmm.types.FactorGraph
        public Variable get(int i) {
            return this.isFactorsAdded ? super.get(i) : (Variable) this.allVars.get(i);
        }

        @Override // cc.mallet.grmm.types.FactorGraph
        public int getIndex(Variable variable) {
            return this.isFactorsAdded ? super.getIndex(variable) : this.allVars.indexOf(variable);
        }

        public double getLogNumAssignments() {
            double d = 0.0d;
            for (int i = 0; i < numVariables(); i++) {
                d += Math.log(get(i).getNumOutcomes());
            }
            return d;
        }

        public Variable varOfIndex(int i, int i2) {
            return ((LabelsAssignment) this.instance.getTarget()).varOfIndex(i, i2);
        }

        public int numSlices() {
            return ((LabelsAssignment) this.instance.getTarget()).numSlices();
        }

        public double[] computeCurrentResids() {
            this.lastResids = new double[factors().size()];
            for (UnrolledVarSet unrolledVarSet : this.cliques) {
                AbstractTableFactor abstractTableFactor = (AbstractTableFactor) unrolledVarSet.getFactor();
                this.lastResids[getIndex(abstractTableFactor)] = Factors.distLinf(abstractTableFactor, unrolledVarSet.tmpl.computeFactor(unrolledVarSet));
            }
            return this.lastResids;
        }

        public UnrolledVarSet getUnrolledVarSet(Factor factor) {
            return (UnrolledVarSet) this.uvsMap.get(factor);
        }
    }

    /* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/grmm/learning/ACRF$UnrolledVarSet.class */
    /**
     * A clique of an {@link UnrolledGraph}: the set of variables together with the template that
     * created it, the feature vector it was created with, and the factor currently attached.
     */
    public static class UnrolledVarSet extends HashVarSet {
        Template tmpl;
        FeatureVector fv;
        Variable[] vars;
        Factor factor;
        UnrolledGraph graph;
        /** Distance between the two most recently attached factors; see {@link #setFactor}. */
        double lastChange;

        /** @return the number of outcomes of each variable, in this set's order */
        public int[] varDimensions() {
            int count = size();
            int[] dims = new int[count];
            for (int pos = 0; pos < count; pos++) {
                dims[pos] = get(pos).getNumOutcomes();
            }
            return dims;
        }

        public UnrolledVarSet(UnrolledGraph unrolledGraph, Template template, Variable[] variableArr, FeatureVector featureVector) {
            super(variableArr);
            this.graph = unrolledGraph;
            this.vars = variableArr;
            this.tmpl = template;
            this.fv = featureVector;
        }

        /** Expands a single assignment index into an {@link Assignment} over {@code vars}. */
        Assignment getAssignmentByNumber(int i) {
            int[] dims = varDimensions();
            int[] outcomes = new int[dims.length];
            Matrixn.singleToIndices(i, outcomes, dims);
            return new Assignment(this.vars, outcomes);
        }

        /** @return the single index of {@link #lookupAssignment()} */
        public final int lookupAssignmentNumber() {
            Assignment restricted = lookupAssignment();
            return restricted.singleIndex();
        }

        /** @return the graph's assignment as computed by the template for this clique */
        public final Assignment lookupAssignment() {
            Assignment full = this.graph.getAssignment();
            return this.tmpl.computeAssignment(full, this);
        }

        /** @return the single index of the outcomes {@code assignment} gives to {@code vars} */
        public int lookupNumberOfAssignment(Assignment assignment) {
            int[] dims = varDimensions();
            int[] outcomes = new int[dims.length];
            int pos = 0;
            while (pos < outcomes.length) {
                outcomes[pos] = assignment.get(this.vars[pos]);
                pos++;
            }
            return Matrixn.singleIndex(dims, outcomes);
        }

        public Template getTemplate() {
            return this.tmpl;
        }

        public FeatureVector getFv() {
            return this.fv;
        }

        public Factor getFactor() {
            return this.factor;
        }

        /**
         * Attaches a factor. When one was already attached, first records the L-infinity
         * distance between the new and the previous factor in {@link #lastChange}.
         */
        public void setFactor(Factor factor) {
            Factor previous = this.factor;
            if (previous != null) {
                this.lastChange = Factors.distLinf((AbstractTableFactor) factor, (AbstractTableFactor) previous);
            }
            this.factor = factor;
        }

        public double getLastChange() {
            return this.lastChange;
        }
    }

    public ACRF(Pipe pipe, Template[] templateArr) throws IllegalArgumentException {
        this.inputPipe = pipe;
        this.templates = templateArr;
        this.inputAlphabet = pipe.getDataAlphabet();
        this.defaultFeatureIndex = this.inputAlphabet.size();
        for (int i = 0; i < this.templates.length; i++) {
            this.templates[i].index = i;
        }
    }

    /** @return the data alphabet obtained from the input pipe at construction */
    public Alphabet getInputAlphabet() {
        return inputAlphabet;
    }

    /** @return the feature index that {@code print} labels as {@code DEFAULT} */
    public int getDefaultFeatureIndex() {
        return defaultFeatureIndex;
    }

    /** @return the inferencer stored by {@code setInferencer} */
    public Inferencer getInferencer() {
        return globalInferencer;
    }

    /** Stores the inferencer returned by {@code getInferencer}. */
    public void setInferencer(Inferencer inferencer) {
        globalInferencer = inferencer;
    }

    /** @return the inferencer that {@code bestAssignment} passes to {@code Models.bestAssignment} */
    public Inferencer getViterbiInferencer() {
        return viterbi;
    }

    /** Sets the inferencer that {@code bestAssignment} passes to {@code Models.bestAssignment}. */
    public void setViterbiInferencer(Inferencer inferencer) {
        viterbi = inferencer;
    }

    /** @return the current value of the {@code doSizeScale} flag */
    public boolean isDoSizeScale() {
        return doSizeScale;
    }

    /** Sets the {@code doSizeScale} flag. */
    public void setDoSizeScale(boolean z) {
        doSizeScale = z;
    }

    /** Applies the supported-only setting to every template of this ACRF. */
    public void setSupportedOnly(boolean z) {
        for (Template template : this.templates) {
            template.setSupportedOnly(z);
        }
    }

    /** @return whether {@code unroll} stores and reuses graphs in the graph cache */
    public boolean isCacheUnrolledGraphs() {
        return cacheUnrolledGraphs;
    }

    /** Enables or disables storing and reusing unrolled graphs in the graph cache. */
    public void setCacheUnrolledGraphs(boolean z) {
        cacheUnrolledGraphs = z;
    }

    /**
     * Replaces the fixed potentials with the given templates and marks each with index -1.
     * The templates are copied into a growable list: {@code Arrays.asList} alone yields a
     * fixed-size view on which the later {@code add} calls made by {@code addFixedPotential}
     * and {@code addFixedPotentials} would throw {@link UnsupportedOperationException}.
     *
     * @param templateArr the new fixed templates
     */
    public void setFixedPotentials(Template[] templateArr) {
        this.fixedPtls = new ArrayList<Template>(Arrays.asList(templateArr));
        for (Template template : templateArr) {
            template.index = -1;
        }
    }

    /**
     * Appends each template to the fixed potentials, after marking it untrainable, and then
     * sets its index to -1.
     */
    public void addFixedPotentials(Template[] templateArr) {
        for (int pos = 0; pos < templateArr.length; pos++) {
            Template fixed = templateArr[pos];
            fixed.setTrainable(false);
            this.fixedPtls.add(fixed);
            fixed.index = -1;
        }
    }

    /** @return the template array passed to the constructor (not a copy) */
    public Template[] getTemplates() {
        return templates;
    }

    /** @return the pipe passed to the constructor */
    public Pipe getInputPipe() {
        return inputPipe;
    }

    /** @return a new array holding the current fixed potentials */
    public Template[] getFixedTemplates() {
        Template[] target = new Template[this.fixedPtls.size()];
        return (Template[]) this.fixedPtls.toArray(target);
    }

    /**
     * Marks the template untrainable, appends it to the fixed potentials and sets its
     * index to -1.
     */
    public void addFixedPotential(Template template) {
        template.setTrainable(false);
        fixedPtls.add(template);
        template.index = -1;
    }

    /** @return the stored Gaussian prior variance */
    public double getGaussianPriorVariance() {
        return gaussianPriorVariance;
    }

    /** Stores the Gaussian prior variance. */
    public void setGaussianPriorVariance(double d) {
        gaussianPriorVariance = d;
    }

    /** Sets the post-processor applied to each newly unrolled graph; may be {@code null}. */
    public void setGraphProcessor(GraphPostProcessor graphPostProcessor) {
        graphProcessor = graphPostProcessor;
    }

    /** @return a new {@code MaximizableACRF} over the given instances */
    public Optimizable.ByGradientValue getMaximizable(InstanceList instanceList) {
        MaximizableACRF maximizable = new MaximizableACRF(instanceList);
        return maximizable;
    }

    /**
     * Computes the best assignment of every instance.
     *
     * @return a list of {@link Assignment}s, in instance order
     */
    public List bestAssignment(InstanceList instanceList) {
        List<Assignment> results = new ArrayList<Assignment>(instanceList.size());
        int pos = 0;
        while (pos < instanceList.size()) {
            results.add(bestAssignment(instanceList.get(pos)));
            pos++;
        }
        return results;
    }

    /** Unrolls the instance and returns {@code Models.bestAssignment} under the Viterbi inferencer. */
    public Assignment bestAssignment(Instance instance) {
        UnrolledGraph graph = unroll(instance);
        return Models.bestAssignment(graph, this.viterbi);
    }

    /**
     * Computes the best label sequence of every instance.
     *
     * @return a list of {@link LabelsSequence}s, in instance order
     */
    public List getBestLabels(InstanceList instanceList) {
        List<LabelsSequence> results = new ArrayList<LabelsSequence>(instanceList.size());
        int pos = 0;
        while (pos < instanceList.size()) {
            results.add(getBestLabels(instanceList.get(pos)));
            pos++;
        }
        return results;
    }

    /**
     * Converts the instance's best assignment to a labels sequence using the instance target,
     * which is cast to {@link LabelsAssignment}.
     */
    public LabelsSequence getBestLabels(Instance instance) {
        LabelsAssignment labels = (LabelsAssignment) instance.getTarget();
        Assignment best = bestAssignment(instance);
        return labels.toLabelsSequence(best);
    }

    /**
     * Returns the unrolled graph for an instance with factors computed. With caching enabled a
     * cached graph has its factors recomputed and is reused; otherwise a new graph is built and
     * passed to the graph processor, if one is set. With caching enabled the returned graph is
     * always (re)stored in the cache.
     */
    public UnrolledGraph unroll(Instance instance) {
        boolean cacheHit = this.cacheUnrolledGraphs && this.graphCache.containsKey(instance);
        UnrolledGraph graph;
        if (!cacheHit) {
            graph = new UnrolledGraph(instance, this.templates, this.fixedPtls);
            if (this.graphProcessor != null) {
                this.graphProcessor.process(graph, instance);
            }
        } else {
            graph = (UnrolledGraph) this.graphCache.get(instance);
            graph.recomputeFactors();
        }
        if (this.cacheUnrolledGraphs) {
            this.graphCache.put(instance, graph);
        }
        return graph;
    }

    /**
     * Like {@code unroll}, but a newly built graph has no factors computed. The cache is
     * shared with {@code unroll}.
     * NOTE(review): on a cache hit this calls {@code recomputeFactors()}, which reads each
     * clique's factor; confirm that a cached graph built here (without factors) is handled.
     */
    public UnrolledGraph unrollStructureOnly(Instance instance) {
        boolean cacheHit = this.cacheUnrolledGraphs && this.graphCache.containsKey(instance);
        UnrolledGraph graph;
        if (!cacheHit) {
            graph = new UnrolledGraph(instance, this.templates, this.fixedPtls, false);
            if (this.graphProcessor != null) {
                this.graphProcessor.process(graph, instance);
            }
        } else {
            graph = (UnrolledGraph) this.graphCache.get(instance);
            graph.recomputeFactors();
        }
        if (this.cacheUnrolledGraphs) {
            this.graphCache.put(instance, graph);
        }
        return graph;
    }

    /** Logs the number of graphs currently held in the graph cache. */
    /* JADX INFO: Access modifiers changed from: private */
    public void reportOnGraphCache() {
        String message = "Number of cached graphs = " + this.graphCache.size();
        logger.info(message);
    }

    /**
     * Writes a textual dump of every template's default weights and per-assignment weights.
     * Features are printed by their alphabet entry, except the default feature index, which is
     * printed as {@code DEFAULT}. Templates whose weights have not been allocated yet are
     * printed without weight lines. The stream is flushed but not closed.
     *
     * @param outputStream destination of the dump; remains open
     */
    public void print(OutputStream outputStream) {
        PrintStream out = new PrintStream(outputStream);
        out.println("ACRF. Number of templates: == " + this.templates.length);
        out.println("Weights");
        for (int t = 0; t < this.templates.length; t++) {
            Template template = this.templates[t];
            out.println("TEMPLATE " + t + " == " + template);
            out.println("Default weights: ");
            SparseVector defaults = template.getDefaultWeights();
            if (defaults != null) {
                for (int loc = 0; loc < defaults.numLocations(); loc++) {
                    out.println(" [" + defaults.indexAtLocation(loc) + "] = " + defaults.valueAtLocation(loc));
                }
            }
            SparseVector[] weights = template.getWeights();
            if (weights == null) {
                continue;
            }
            for (int assn = 0; assn < weights.length; assn++) {
                out.println("Assignment " + assn);
                SparseVector vector = weights[assn];
                for (int loc = 0; loc < vector.numLocations(); loc++) {
                    int feature = vector.indexAtLocation(loc);
                    if (feature == this.defaultFeatureIndex) {
                        out.print("DEFAULT");
                    } else {
                        out.print(this.inputAlphabet.lookupObject(feature));
                    }
                    out.println("  " + vector.valueAtLocation(loc));
                }
            }
        }
        // Push everything through to the caller's stream without closing it.
        out.flush();
    }

    /**
     * Prints, for each clique index, a header line followed by that clique's values. Stops at
     * the first {@link IOException}, reporting it on {@code System.err}.
     */
    private static void dumpValues(String str, SparseVector[][] sparseVectorArr) {
        try {
            for (int clique = 0; clique < sparseVectorArr.length; clique++) {
                System.out.println(str + " Clique: " + clique);
                writeCliqueValues(sparseVectorArr[clique]);
            }
        } catch (IOException e) {
            System.err.println("Error writing to file!");
            e.printStackTrace();
        }
    }

    /**
     * Prints to {@code System.out} the number of assignments and, per assignment, the number of
     * stored locations followed by one line per (index, value) pair.
     */
    private static void writeCliqueValues(SparseVector[] sparseVectorArr) throws IOException {
        System.out.println("Num assignments = " + sparseVectorArr.length);
        int assn = 0;
        for (SparseVector vector : sparseVectorArr) {
            System.out.println("Num locations = " + vector.numLocations());
            for (int loc = 0; loc < vector.numLocations(); loc++) {
                System.out.print("sparse [" + assn + "][" + vector.indexAtLocation(loc) + "] = ");
                System.out.println(vector.valueAtLocation(loc));
            }
            assn++;
        }
    }

    /**
     * Prints every unrolled clique of a graph, and its factor when one exists, to standard
     * output.
     *
     * @param unrolledGraph the unrolled graph to dump
     */
    private void dumpOneGraph(UnrolledGraph unrolledGraph) {
        // The result is not used here; the call is retained because it may have side
        // effects that are not visible from this method.
        unrolledGraph.getAssignment();
        for (Iterator cliques = unrolledGraph.unrolledVarSetIterator(); cliques.hasNext(); ) {
            UnrolledVarSet clique = (UnrolledVarSet) cliques.next();
            System.out.println("Clique " + clique);
            Factor factor = unrolledGraph.factorOf((VarSet) clique);
            if (factor == null) {
                // No factor registered for this clique: only the clique line is printed.
                continue;
            }
            System.out.println(factor);
        }
    }

    /**
     * Unrolls each instance of the list in order and dumps the resulting graph to standard
     * output, preceded by a header line containing the instance's position and name.
     *
     * @param instanceList instances to unroll and dump
     */
    public void dumpUnrolledGraphs(InstanceList instanceList) {
        int position = 0;
        while (position < instanceList.size()) {
            Instance current = instanceList.get(position);
            // NOTE(review): the separator constant comes from an unrelated JSON writer class;
            // confirm whether a plain literal was intended here.
            String header = "INSTANCE " + position + JSWriter.ObjectPairSep + current.getName();
            System.out.println(header);
            dumpOneGraph(unroll(current));
            position++;
        }
    }

    /**
     * Loads template weights from the XML text produced by {@link #writeWeightsText(Writer)}.
     *
     * <p>For every {@code TEMPLATE} child of the root element, the template at the position
     * given by the {@code IDX} attribute is looked up and its class name is compared with the
     * {@code NAME} attribute. The {@code DEFAULT_WEIGHTS} text is parsed with raw integer
     * indices, and each {@code WEIGHT} child of {@code WEIGHTS} is parsed against the input
     * alphabet and stored at the slot named by its own {@code IDX} attribute. Slots for which
     * the document contains no {@code WEIGHT} element stay {@code null}.
     *
     * @param reader source of the XML document
     * @throws IOException if reading or parsing a weight vector fails
     * @throws RuntimeException if a template's class name does not match the document, or
     *     wrapping a {@code JDOMException} raised while building the document
     */
    public void readWeightsFromText(Reader reader) throws IOException {
        try {
            for (Element templateElt : new SAXBuilder().build(reader).getRootElement().getChildren("TEMPLATE")) {
                String declaredClass = templateElt.getAttributeValue("NAME");
                int templateIdx = Integer.parseInt(templateElt.getAttributeValue("IDX"));
                Template target = this.templates[templateIdx];
                String actualClass = target.getClass().getName();
                if (!actualClass.equals(declaredClass)) {
                    throw new RuntimeException("Expected template " + target + "; got " + declaredClass);
                }

                // Default weights are keyed by integer index, hence no alphabet.
                String defaultsText = templateElt.getChild("DEFAULT_WEIGHTS").getText();
                SparseVector parsedDefaults = readSparseVector(defaultsText, null);

                Element weightsElt = templateElt.getChild("WEIGHTS");
                int declaredSize = Integer.parseInt(weightsElt.getAttributeValue("SIZE"));
                SparseVector[] parsedWeights = new SparseVector[declaredSize];
                for (Element weightElt : weightsElt.getChildren("WEIGHT")) {
                    int slot = Integer.parseInt(weightElt.getAttributeValue("IDX"));
                    parsedWeights[slot] = readSparseVector(weightElt.getText(), getInputAlphabet());
                }

                // The template is only modified once both parts have been parsed.
                target.setDefaultWeights(parsedDefaults);
                target.weights = parsedWeights;
            }
        } catch (JDOMException jdomFailure) {
            throw new RuntimeException(jdomFailure);
        }
    }

    /**
     * Parses newline-separated {@code key<TAB>value} entries into a sparse vector.
     *
     * <p>Lines consisting only of whitespace are skipped. When {@code alphabet} is non-null
     * the key is resolved with {@code alphabet.lookupIndex(key)}; otherwise the key is parsed
     * as a decimal integer index. Entries are stored in the order in which they appear.
     *
     * <p>NOTE(review): {@code writeWeightVector} emits keys of the form {@code IDX<n>} for
     * indices at or beyond the alphabet size; here such keys are handed to
     * {@code lookupIndex} like any other key. Confirm that this round-trips as intended.
     *
     * @param str text block holding one entry per line
     * @param alphabet dictionary used to resolve keys, or {@code null} for integer keys
     * @return a sparse vector with one location per non-blank line
     * @throws IOException if a non-blank line does not contain both a key and a value
     *     separated by a tab
     * @throws NumberFormatException if an integer key or a value cannot be parsed
     */
    private SparseVector readSparseVector(String str, Alphabet alphabet) throws IOException {
        // Compiled once per call rather than once per line.
        Pattern blankLine = Pattern.compile("^\\s*$");
        TIntArrayList indices = new TIntArrayList();
        TDoubleArrayList values = new TDoubleArrayList();
        for (String line : str.split("\n")) {
            if (blankLine.matcher(line).matches()) {
                continue;
            }
            String[] fields = line.split("\t");
            if (fields.length < 2) {
                // Fail with the offending line instead of an ArrayIndexOutOfBoundsException,
                // and before the alphabet is consulted for a key that has no value.
                throw new IOException("Malformed weight entry, expected <key>\\t<value>: \"" + line + "\"");
            }
            int index = alphabet != null ? alphabet.lookupIndex(fields[0]) : Integer.parseInt(fields[0]);
            double value = Double.parseDouble(fields[1]);
            indices.add(index);
            values.add(value);
        }
        return new SparseVector(indices.toNativeArray(), values.toNativeArray());
    }

    /**
     * Writes the weights of every template as an XML-style text document.
     *
     * <p>The document has a {@code CRF} root with one {@code TEMPLATE} element per template,
     * carrying the template's class name and position. Each template element contains a
     * {@code DEFAULT_WEIGHTS} section of {@code index<TAB>value} lines and a {@code WEIGHTS}
     * section holding one {@code WEIGHT} element per weight vector.
     *
     * <p>The writer is wrapped in a {@link PrintWriter}; this method neither flushes nor
     * closes it.
     *
     * @param writer destination for the text
     */
    public void writeWeightsText(Writer writer) {
        PrintWriter out = new PrintWriter(writer);
        out.println("<CRF>");
        for (int t = 0; t < this.templates.length; t++) {
            Template current = this.templates[t];
            String className = current.getClass().getName();
            out.println("<TEMPLATE NAME=\"" + className + "\" IDX=\"" + t + "\" >");

            out.println("<DEFAULT_WEIGHTS>");
            SparseVector defaults = current.getDefaultWeights();
            for (int loc = 0; loc < defaults.numLocations(); loc++) {
                out.println(defaults.indexAtLocation(loc) + "\t" + defaults.valueAtLocation(loc));
            }
            out.println("</DEFAULT_WEIGHTS>");
            out.println();

            SparseVector[] perAssignment = current.getWeights();
            out.println("<WEIGHTS SIZE=\"" + perAssignment.length + "\">");
            for (int a = 0; a < perAssignment.length; a++) {
                out.println("<WEIGHT IDX=\"" + a + "\">");
                writeWeightVector(out, perAssignment[a]);
                out.println();
                out.println("</WEIGHT>");
            }
            out.println("</WEIGHTS>");

            out.println("</TEMPLATE>");
        }
        out.println("</CRF>");
    }

    /**
     * Writes one weight vector as {@code key<TAB>value} lines between the
     * {@code SerializerConstants} CDATA open and close delimiters.
     *
     * <p>The key is the input-alphabet entry for the location's index when the index is
     * below the alphabet size, and the literal {@code IDX} followed by the index otherwise.
     *
     * @param printWriter destination for the text
     * @param sparseVector vector whose locations are written in storage order
     */
    private void writeWeightVector(PrintWriter printWriter, SparseVector sparseVector) {
        printWriter.println(SerializerConstants.CDATA_DELIMITER_OPEN);
        Alphabet features = getInputAlphabet();
        for (int loc = 0; loc < sparseVector.numLocations(); loc++) {
            int featureIdx = sparseVector.indexAtLocation(loc);
            double weight = sparseVector.valueAtLocation(loc);
            // Indices outside the alphabet have no name, so a synthetic key is used instead.
            Object key = featureIdx < features.size()
                    ? features.lookupObject(featureIdx)
                    : "IDX" + featureIdx;
            printWriter.println(key + "\t" + weight);
        }
        printWriter.println(SerializerConstants.CDATA_DELIMITER_CLOSE);
    }

    /**
     * Builds an ACRF whose templates are one {@code BigramTemplate} per level and one
     * {@code PairwiseFactorTemplate} for every pair of adjacent levels.
     *
     * <p>Templates are added level by level: the bigram template of a level comes first,
     * followed by the pairwise template linking it to the next level when there is one.
     *
     * @param pipe pipe handed to the ACRF constructor
     * @param i number of levels; a value of zero or less yields an ACRF with no templates
     * @return the newly constructed ACRF
     */
    public static ACRF makeFactorial(Pipe pipe, int i) {
        List<Template> templates = new ArrayList<Template>();
        for (int level = 0; level < i; level++) {
            templates.add(new BigramTemplate(level));
            int next = level + 1;
            if (next < i) {
                templates.add(new PairwiseFactorTemplate(level, next));
            }
        }
        Template[] asArray = templates.toArray(new Template[templates.size()]);
        return new ACRF(pipe, asArray);
    }

    /**
     * Java serialization hook: restores the default-serialized state and then installs a
     * fresh, empty graph cache.
     *
     * @param objectInputStream stream the object is being read from
     * @throws IOException if the default read fails
     * @throws ClassNotFoundException if a class of a serialized field cannot be found
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        // Assigned after the default read so the new map is the value the object ends up with.
        // NOTE(review): presumably graphCache is not serialized; confirm against the field declaration.
        this.graphCache = new THashMap();
    }

    /**
     * Stores the directory used for verbose output.
     *
     * @param file value assigned to the {@code verboseOutputDirectory} field; it is stored
     *     as given, without validation
     */
    public void setVerboseOutputDirectory(File file) {
        this.verboseOutputDirectory = file;
    }
}
