package cc.mallet.fst.semi_supervised;

import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFCacheStaleIndicator;
import cc.mallet.fst.CRFOptimizableByBatchLabelLikelihood;
import cc.mallet.fst.CRFOptimizableByGradientValues;
import cc.mallet.fst.CRFOptimizableByLabelLikelihood;
import cc.mallet.fst.CRFTrainerByLabelLikelihood;
import cc.mallet.fst.CRFTrainerByThreadedLabelLikelihood;
import cc.mallet.fst.ThreadedOptimizable;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.InstanceList;
import java.util.ArrayList;

/* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/fst/semi_supervised/CRFTrainerByLikelihoodAndGE.class */
/**
 * Trains a {@link CRF} with L-BFGS on two objectives at once: a label-likelihood
 * objective computed on one instance list and a GE-constraint objective
 * ({@link CRFOptimizableByGE}) computed on a second instance list. Both objectives are
 * handed to a single {@link CRFOptimizableByGradientValues}.
 *
 * <p>Optionally, the CRF is first trained with plain label likelihood on the first
 * instance list (see {@link #setInitSupervised(boolean)}).
 *
 * <p>Use {@link #train(InstanceList, InstanceList, int)}; the two-argument
 * {@link #train(InstanceList, int)} always throws.
 */
public class CRFTrainerByLikelihoodAndGE extends TransducerTrainer {
    /** Model whose parameters are optimized in place. */
    private final CRF crf;
    /** Constraints handed to {@link CRFOptimizableByGE}. */
    private final ArrayList<GEConstraint> constraints;
    /** State/label map handed to {@link CRFOptimizableByGE}. */
    private final StateLabelMap map;
    /** Never written by this class, so {@link #getIteration()} always reports 0. */
    private int iteration = 0;
    /** Result of the most recent {@code optimize} call that returned normally. */
    private boolean converged = false;
    /** Weight passed to the GE objective. */
    private double geWeight = 1.0d;
    /** Whether to run a supervised label-likelihood training pass before the joint optimization. */
    private boolean initSupervised = false;
    /** Gaussian prior variance applied to the label-likelihood objective(s). */
    private double gpv = 10.0d;
    /** 1 selects the single-threaded implementations; any other value selects the threaded ones. */
    private int numThreads = 1;
    /** Iteration limit for the optional supervised initialization pass. */
    private int supIterations = Integer.MAX_VALUE;

    /**
     * Creates a trainer for the given model.
     *
     * @param crf the CRF to train
     * @param arrayList the GE constraints
     * @param stateLabelMap the state/label map passed to the GE objective
     */
    public CRFTrainerByLikelihoodAndGE(CRF crf, ArrayList<GEConstraint> arrayList, StateLabelMap stateLabelMap) {
        this.crf = crf;
        this.constraints = arrayList;
        this.map = stateLabelMap;
    }

    /**
     * Sets the weight passed to the GE objective (default 1.0).
     *
     * @param d the GE weight
     */
    public void setGEWeight(double d) {
        this.geWeight = d;
    }

    /**
     * Sets the Gaussian prior variance used by the label-likelihood objective(s) (default 10.0).
     *
     * @param d the variance
     */
    public void setGaussianPriorVariance(double d) {
        this.gpv = d;
    }

    /**
     * Enables or disables the supervised label-likelihood pass that runs before the joint
     * optimization (default disabled).
     *
     * @param z {@code true} to run the supervised pass first
     */
    public void setInitSupervised(boolean z) {
        this.initSupervised = z;
    }

    /**
     * Sets the iteration limit of the supervised initialization pass
     * (default {@link Integer#MAX_VALUE}).
     *
     * @param i the iteration limit
     */
    public void setSupervisedIterations(int i) {
        this.supIterations = i;
    }

    /**
     * Sets the thread count (default 1). A value of 1 selects the single-threaded
     * implementations; any other value selects the threaded ones.
     *
     * @param i the number of threads
     */
    public void setNumThreads(int i) {
        this.numThreads = i;
    }

    /** Returns the CRF supplied to the constructor. */
    @Override // cc.mallet.fst.TransducerTrainer
    public Transducer getTransducer() {
        return this.crf;
    }

    /**
     * Returns the iteration counter. Nothing in this class updates it, so the value is
     * currently always 0.
     */
    @Override // cc.mallet.fst.TransducerTrainer
    public int getIteration() {
        return this.iteration;
    }

    /**
     * Returns the convergence flag reported by the most recent {@code optimize} call that
     * returned normally, or {@code false} if there has been none.
     */
    @Override // cc.mallet.fst.TransducerTrainer
    public boolean isFinishedTraining() {
        return this.converged;
    }

    /**
     * Runs the joint label-likelihood + GE optimization.
     *
     * <p>The optimizer is run twice with a {@code reset()} in between. An exception thrown
     * by either run is printed and otherwise ignored, so a failed first run does not prevent
     * the second. Worker threads created for the threaded implementations are shut down
     * even when training fails.
     *
     * @param instanceList instances for the label-likelihood objective (and for the optional
     *     supervised initialization)
     * @param instanceList2 instances for the GE objective
     * @param i iteration limit passed to each of the two optimizer runs
     * @return the convergence flag (see {@link #isFinishedTraining()})
     */
    public boolean train(InstanceList instanceList, InstanceList instanceList2, int i) {
        System.err.println(instanceList.size());
        System.err.println(instanceList2.size());

        if (this.initSupervised) {
            initializeWithSupervisedTraining(instanceList);
            runEvaluators();
        }

        Optimizable.ByGradientValue likelihood = createLabelLikelihoodOptimizable(instanceList);
        CRFOptimizableByGE ge = null;
        try {
            ge = new CRFOptimizableByGE(this.crf, this.constraints, instanceList2, this.map, this.numThreads, this.geWeight);
            // NOTE(review): presumably an infinite variance disables the prior on the GE term so
            // that only the likelihood term regularizes -- confirm against CRFOptimizableByGE.
            ge.setGaussianPriorVariance(Double.POSITIVE_INFINITY);

            LimitedMemoryBFGS bfgs = new LimitedMemoryBFGS(
                    new CRFOptimizableByGradientValues(this.crf, new Optimizable.ByGradientValue[]{likelihood, ge}));
            optimizeBestEffort(bfgs, i);
            bfgs.reset();
            optimizeBestEffort(bfgs, i);
        } finally {
            // Release worker threads on every exit path, not only after a clean run.
            if (this.numThreads > 1) {
                ((ThreadedOptimizable) likelihood).shutdown();
                if (ge != null) {
                    ge.shutdown();
                }
            }
        }
        return this.converged;
    }

    /**
     * Unsupported: this trainer needs two instance lists.
     *
     * @throws RuntimeException always
     */
    @Override // cc.mallet.fst.TransducerTrainer
    public boolean train(InstanceList instanceList, int i) {
        throw new RuntimeException("Must use train(InstanceList trainingSet, InstanceList unlabeledSet, int numIterations) instead");
    }

    /** Trains the CRF with label likelihood only, without adding factors, for up to {@code supIterations} iterations. */
    private void initializeWithSupervisedTraining(InstanceList labeled) {
        if (this.numThreads == 1) {
            CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(this.crf);
            trainer.setAddNoFactors(true);
            trainer.setGaussianPriorVariance(this.gpv);
            trainer.train(labeled, this.supIterations);
            return;
        }

        CRFTrainerByThreadedLabelLikelihood threadedTrainer = new CRFTrainerByThreadedLabelLikelihood(this.crf, this.numThreads);
        try {
            threadedTrainer.setAddNoFactors(true);
            threadedTrainer.setGaussianPriorVariance(this.gpv);
            threadedTrainer.train(labeled, this.supIterations);
        } finally {
            // Shut the trainer down even if training throws.
            threadedTrainer.shutdown();
        }
    }

    /** Builds the single-threaded or threaded label-likelihood objective, with the configured prior variance. */
    private Optimizable.ByGradientValue createLabelLikelihoodOptimizable(InstanceList labeled) {
        if (this.numThreads == 1) {
            CRFOptimizableByLabelLikelihood single = new CRFOptimizableByLabelLikelihood(this.crf, labeled);
            single.setGaussianPriorVariance(this.gpv);
            return single;
        }
        CRFOptimizableByBatchLabelLikelihood batch = new CRFOptimizableByBatchLabelLikelihood(this.crf, labeled, this.numThreads);
        ThreadedOptimizable threaded = new ThreadedOptimizable(
                batch, labeled, this.crf.getParameters().getNumFactors(), new CRFCacheStaleIndicator(this.crf));
        batch.setGaussianPriorVariance(this.gpv);
        return threaded;
    }

    /** Runs one optimizer pass; on an exception, prints it and leaves {@code converged} unchanged. */
    private void optimizeBestEffort(LimitedMemoryBFGS optimizer, int maxIterations) {
        try {
            this.converged = optimizer.optimize(maxIterations);
        } catch (Exception e) {
            // Deliberately best-effort: the caller resets the optimizer and tries again.
            e.printStackTrace();
        }
    }
}
