package cc.mallet.classify;

import cc.mallet.optimize.ConjugateGradient;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.Labels;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Arrays;
import java.util.Iterator;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/classify/RankMaxEntTrainer.class */
/**
 * Trains a {@link RankMaxEnt} classifier.
 *
 * <p>Each training instance's data is a {@link FeatureVectorSequence} of candidates, and its target
 * holds, as an integer string, the index of the correct candidate ({@code -1} means "no target";
 * such instances are skipped). Training maximizes the instance-weighted log of the classifier's
 * score for the correct candidate, minus a Gaussian prior penalty. It uses
 * {@link LimitedMemoryBFGS}, followed by a {@link ConjugateGradient} pass when
 * {@code numIterations == Integer.MAX_VALUE}.
 */
public class RankMaxEntTrainer extends MaxEntTrainer {
    private static final Logger logger = MalletLogger.getLogger(RankMaxEntTrainer.class.getName());
    private static final Logger progressLogger = MalletProgressMessageLogger.getLogger(RankMaxEntTrainer.class.getName() + "-pl");
    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 1;

    /**
     * The optimizable objective for a {@link RankMaxEnt}.
     *
     * <p>Only weight row 0 receives likelihood gradient and constraint counts. The parameter
     * array is nevertheless sized for {@code numLabels} (= 2) rows of {@code numFeatures} entries.
     */
    public class MaximizableTrainer implements Optimizable.ByGradientValue {
        /** Flattened weights, {@code numLabels * numFeatures} long; shared with {@link #theClassifier}. */
        double[] parameters;
        /** Instance-weighted feature counts of the correct candidates (row 0 only). */
        double[] constraints;
        /** Gradient buffer; holds negated expected counts after {@link #getValue()}. */
        double[] cachedGradient;
        RankMaxEnt theClassifier;
        InstanceList trainingList;
        double cachedValue;
        boolean cachedValueStale;
        boolean cachedGradientStale;
        int numLabels;
        int numFeatures;
        /** Index of the always-on default (bias) feature: the last slot of each row. */
        int defaultFeatureIndex;
        FeatureSelection featureSelection;
        FeatureSelection[] perLabelFeatureSelection;

        /** Creates an uninitialized trainer; every field keeps its Java default (null/0/false). */
        public MaximizableTrainer() {
        }

        /**
         * Builds the objective for {@code instanceList}, precomputing the constraint counts.
         *
         * @param instanceList training data; each instance's data must be a {@link FeatureVectorSequence}
         * @param rankMaxEnt existing classifier to continue training, or {@code null} to start
         *     from zero weights
         */
        public MaximizableTrainer(InstanceList instanceList, RankMaxEnt rankMaxEnt) {
            this.trainingList = instanceList;
            Alphabet dataAlphabet = instanceList.getDataAlphabet();
            this.numLabels = 2;
            // One extra slot per row for the default feature.
            this.numFeatures = dataAlphabet.size() + 1;
            this.defaultFeatureIndex = this.numFeatures - 1;
            // New Java arrays are already zero-filled.
            this.parameters = new double[this.numLabels * this.numFeatures];
            this.constraints = new double[this.numLabels * this.numFeatures];
            this.cachedGradient = new double[this.numLabels * this.numFeatures];
            this.featureSelection = instanceList.getFeatureSelection();
            this.perLabelFeatureSelection = instanceList.getPerLabelFeatureSelection();
            // Make sure the default feature is part of any active feature selection.
            if (this.featureSelection != null) {
                this.featureSelection.add(this.defaultFeatureIndex);
            }
            if (this.perLabelFeatureSelection != null) {
                for (FeatureSelection selection : this.perLabelFeatureSelection) {
                    selection.add(this.defaultFeatureIndex);
                }
            }
            assert this.featureSelection == null || this.perLabelFeatureSelection == null;
            if (rankMaxEnt != null) {
                // Continue training: adopt the classifier's own parameter array and selections,
                // so parameter updates write straight into it.
                this.theClassifier = rankMaxEnt;
                this.parameters = rankMaxEnt.parameters;
                this.featureSelection = rankMaxEnt.featureSelection;
                this.perLabelFeatureSelection = rankMaxEnt.perClassFeatureSelection;
                this.defaultFeatureIndex = rankMaxEnt.defaultFeatureIndex;
                assert rankMaxEnt.getInstancePipe() == instanceList.getPipe();
            } else {
                this.theClassifier = new RankMaxEnt(instanceList.getPipe(), this.parameters, this.featureSelection, this.perLabelFeatureSelection);
            }
            this.cachedValueStale = true;
            this.cachedGradientStale = true;
            logger.fine("Number of instances in training list = " + this.trainingList.size());
            for (Instance instance : this.trainingList) {
                double instanceWeight = this.trainingList.getInstanceWeight(instance);
                FeatureVectorSequence candidates = (FeatureVectorSequence) instance.getData();
                Object target = instance.getTarget();
                // With tied targets (Labels), the first label is the reference candidate.
                Label trueLabel = target instanceof Labels ? ((Labels) target).get(0) : (Label) target;
                int trueIndex = Integer.parseInt(trueLabel.getBestLabel().getEntry().toString());
                if (trueIndex == -1) {
                    logger.warning("True label is -1. Skipping...");
                    continue;
                }
                FeatureVector trueVector = candidates.get(trueIndex);
                Alphabet alphabet = trueVector.getAlphabet();
                assert alphabet == dataAlphabet;
                assert !Double.isNaN(instanceWeight) : "instanceWeight is NaN";
                MatrixOps.rowPlusEquals(this.constraints, this.numFeatures, 0, trueVector, instanceWeight);
                boolean hasNaN = false;
                for (int loc = 0; loc < trueVector.numLocations(); loc++) {
                    if (Double.isNaN(trueVector.valueAtLocation(loc))) {
                        logger.info("NaN for feature " + alphabet.lookupObject(trueVector.indexAtLocation(loc)).toString());
                        hasNaN = true;
                    }
                }
                if (hasNaN) {
                    logger.info("NaN in instance: " + instance.getName());
                }
                // The default feature contributes value 1.0 for every correct candidate (row 0).
                this.constraints[this.defaultFeatureIndex] += instanceWeight;
            }
        }

        /** Returns the classifier whose parameters this objective optimizes. */
        public RankMaxEnt getClassifier() {
            return this.theClassifier;
        }

        @Override // cc.mallet.optimize.Optimizable
        public double getParameter(int i) {
            return this.parameters[i];
        }

        /** Sets one parameter and invalidates the cached value and gradient. */
        @Override // cc.mallet.optimize.Optimizable
        public void setParameter(int i, double d) {
            this.cachedValueStale = true;
            this.cachedGradientStale = true;
            this.parameters[i] = d;
        }

        @Override // cc.mallet.optimize.Optimizable
        public int getNumParameters() {
            return this.parameters.length;
        }

        /**
         * Copies all parameters into {@code buffer}.
         *
         * @param buffer destination; must be exactly {@link #getNumParameters()} long
         * @throws IllegalArgumentException if {@code buffer} is null or has the wrong length
         *     (a freshly allocated array could never reach the caller)
         */
        @Override // cc.mallet.optimize.Optimizable
        public void getParameters(double[] buffer) {
            if (buffer == null || buffer.length != this.parameters.length) {
                throw new IllegalArgumentException("parameter buffer must have length " + this.parameters.length
                        + " but was " + (buffer == null ? "null" : String.valueOf(buffer.length)));
            }
            System.arraycopy(this.parameters, 0, buffer, 0, this.parameters.length);
        }

        /**
         * Replaces all parameters with a copy of {@code dArr} and invalidates the caches.
         * If the length differs, a new array is allocated; it is then no longer the
         * classifier's array.
         */
        @Override // cc.mallet.optimize.Optimizable
        public void setParameters(double[] dArr) {
            assert dArr != null;
            this.cachedValueStale = true;
            this.cachedGradientStale = true;
            if (dArr.length != this.parameters.length) {
                this.parameters = new double[dArr.length];
            }
            System.arraycopy(dArr, 0, this.parameters, 0, dArr.length);
        }

        /**
         * Returns the penalized log-likelihood.
         *
         * <p>As a side effect, fills {@link #cachedGradient} with the negated, instance-weighted
         * expected feature counts.
         *
         * <p>If any instance yields an infinite term, the method logs a warning and returns
         * immediately with the (infinite) partial value. The gradient is left partially
         * accumulated.
         *
         * @throws IllegalStateException if an instance target is neither a {@link Label} nor {@link Labels}
         */
        @Override // cc.mallet.optimize.Optimizable.ByGradientValue
        public double getValue() {
            if (!this.cachedValueStale) {
                return this.cachedValue;
            }
            this.cachedValue = 0.0d;
            this.cachedGradientStale = true;
            MatrixOps.setAll(this.cachedGradient, 0.0d);
            for (Instance instance : this.trainingList) {
                FeatureVectorSequence candidates = (FeatureVectorSequence) instance.getData();
                double[] scores = new double[candidates.size()];
                double instanceWeight = this.trainingList.getInstanceWeight(instance);
                Object target = instance.getTarget();
                int trueIndex;
                if (target instanceof Label) {
                    trueIndex = Integer.parseInt(((Label) target).toString());
                    if (trueIndex == -1) {
                        continue; // no target: also skipped by the constructor
                    }
                    assert trueIndex >= 0 && trueIndex < candidates.size();
                    this.theClassifier.getClassificationScores(instance, scores);
                } else if (target instanceof Labels) {
                    Labels labels = (Labels) target;
                    int[] tiedIndices = new int[labels.size()];
                    for (int k = 0; k < labels.size(); k++) {
                        tiedIndices[k] = Integer.parseInt(labels.get(k).toString());
                    }
                    trueIndex = tiedIndices[0];
                    if (trueIndex == -1) {
                        continue; // matches the constructor, which skipped this instance too
                    }
                    this.theClassifier.getClassificationScoresForTies(instance, scores, tiedIndices);
                } else {
                    throw new IllegalStateException("Instance " + instance.getName()
                            + " has unsupported target type: " + (target == null ? "null" : target.getClass().getName()));
                }
                double instanceCost = -(instanceWeight * Math.log(scores[trueIndex]));
                if (Double.isNaN(instanceCost)) {
                    logger.fine("MaxEntTrainer: Instance " + instance.getName() + " has NaN value. log(scores)= " + Math.log(scores[trueIndex]) + " scores = " + scores[trueIndex] + " has instance weight = " + instanceWeight);
                }
                if (Double.isInfinite(instanceCost)) {
                    logger.warning("Instance " + instance.getSource() + " has infinite value; skipping value and gradient");
                    this.cachedValue -= instanceCost;
                    this.cachedValueStale = false;
                    return -instanceCost;
                }
                this.cachedValue += instanceCost;
                // Expected counts under the model: each candidate's features weighted by its score.
                for (int c = 0; c < candidates.size(); c++) {
                    double score = scores[c];
                    if (score == 0.0d) {
                        continue;
                    }
                    assert !Double.isInfinite(score);
                    double contribution = -instanceWeight * score;
                    MatrixOps.rowPlusEquals(this.cachedGradient, this.numFeatures, 0, candidates.get(c), contribution);
                    this.cachedGradient[this.defaultFeatureIndex] += contribution;
                }
            }
            // Gaussian prior penalty: sum of theta^2 / (2 * variance), added to the cost.
            double twoVariance = 2.0d * RankMaxEntTrainer.this.gaussianPriorVariance;
            for (int row = 0; row < this.numLabels; row++) {
                for (int f = 0; f < this.numFeatures; f++) {
                    double weight = this.parameters[row * this.numFeatures + f];
                    this.cachedValue += (weight * weight) / twoVariance;
                }
            }
            // Everything was accumulated as a cost; negate it to obtain the value to maximize.
            this.cachedValue *= -1.0d;
            this.cachedValueStale = false;
            progressLogger.info("Value (loglikelihood) = " + this.cachedValue);
            return this.cachedValue;
        }

        /**
         * Copies the gradient of {@link #getValue()} into {@code dArr}, recomputing it if stale.
         *
         * <p>The gradient is the empirical (constraint) counts minus the expected counts, minus
         * {@code parameters / variance}. Entries are then masked via
         * {@code MatrixOps.rowSetAll(..., selection, false)}, presumably zeroing features outside
         * the active feature selection.
         *
         * @param dArr destination; expected to be exactly {@link #getNumParameters()} long
         */
        @Override // cc.mallet.optimize.Optimizable.ByGradientValue
        public void getValueGradient(double[] dArr) {
            if (this.cachedGradientStale) {
                if (this.cachedValueStale) {
                    getValue();
                }
                MatrixOps.plusEquals(this.cachedGradient, this.constraints);
                MatrixOps.plusEquals(this.cachedGradient, this.parameters, (-1.0d) / RankMaxEntTrainer.this.gaussianPriorVariance);
                MatrixOps.substitute(this.cachedGradient, Double.NEGATIVE_INFINITY, 0.0d);
                for (int row = 0; row < this.numLabels; row++) {
                    FeatureSelection selection = this.perLabelFeatureSelection == null
                            ? this.featureSelection
                            : this.perLabelFeatureSelection[row];
                    MatrixOps.rowSetAll(this.cachedGradient, this.numFeatures, row, 0.0d, selection, false);
                }
                this.cachedGradientStale = false;
            }
            assert dArr != null && dArr.length == this.parameters.length;
            System.arraycopy(this.cachedGradient, 0, dArr, 0, this.cachedGradient.length);
        }
    }

    public RankMaxEntTrainer() {
    }

    /** @param d passed unchanged to the {@link MaxEntTrainer} constructor */
    public RankMaxEntTrainer(double d) {
        super(d);
    }

    /**
     * Returns a fresh objective for {@code instanceList}, or an uninitialized one when it is null.
     */
    public Optimizable.ByGradientValue getMaximizableTrainer(InstanceList instanceList) {
        return instanceList == null ? new MaximizableTrainer() : new MaximizableTrainer(instanceList, null);
    }

    /**
     * Trains a {@link RankMaxEnt} on {@code instanceList}.
     *
     * <p>Training continues from {@code initialClassifier} when one is set; it must then be a
     * {@link RankMaxEnt}. An {@link IllegalArgumentException} thrown by an optimizer is logged
     * and treated as convergence.
     *
     * @return the trained classifier
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // cc.mallet.classify.MaxEntTrainer, cc.mallet.classify.ClassifierTrainer
    public MaxEnt train(InstanceList instanceList) {
        logger.fine("trainingSet.size() = " + instanceList.size());
        MaximizableTrainer maximizableTrainer = new MaximizableTrainer(instanceList, (RankMaxEnt) this.initialClassifier);
        LimitedMemoryBFGS bfgs = new LimitedMemoryBFGS(maximizableTrainer);
        // Step one iteration at a time so numIterations bounds the total work.
        for (int i = 0; i < this.numIterations; i++) {
            boolean converged;
            try {
                converged = bfgs.optimize(1);
            } catch (IllegalArgumentException e) {
                logger.log(Level.INFO, "Catching exception; saying converged.", e);
                converged = true;
            }
            if (converged) {
                break;
            }
        }
        if (this.numIterations == Integer.MAX_VALUE) {
            // Unbounded training: run an extra conjugate-gradient pass from the L-BFGS result,
            // presumably to squeeze out further likelihood after L-BFGS stops.
            this.optimizer = new ConjugateGradient(maximizableTrainer);
            try {
                this.optimizer.optimize();
            } catch (IllegalArgumentException e) {
                logger.log(Level.INFO, "Catching exception; saying converged.", e);
            }
        }
        progressLogger.info("\n");
        return maximizableTrainer.getClassifier();
    }

    @Override // cc.mallet.classify.MaxEntTrainer
    public String toString() {
        return "RankMaxEntTrainer,numIterations=" + this.numIterations + ",gaussianPriorVariance=" + this.gaussianPriorVariance;
    }

    /** Writes default fields followed by the serial format version. */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.defaultWriteObject();
        objectOutputStream.writeInt(CURRENT_SERIAL_VERSION);
    }

    /** Reads default fields and consumes the serial format version (currently not inspected). */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        objectInputStream.readInt();
    }
}
