package edu.stanford.nlp.classify;

import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.objectbank.ObjectBank;
import edu.stanford.nlp.optimization.Minimizer;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.ReflectionLoading;
import edu.stanford.nlp.util.StringUtils;
import java.io.File;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Properties;
import org.apache.xpath.XPath;

/* loaded from: input_file:edu/stanford/nlp/classify/LogisticClassifier.class */
/**
 * A binary logistic-regression classifier with labels of type {@code L} and features of type
 * {@code F}.
 *
 * <p>The model is a single weight vector aligned with {@link #getFeatureIndex()}. A datum's score
 * is the dot product of its features with that vector:
 * <ul>
 *   <li>binary features contribute their weight;</li>
 *   <li>real-valued features contribute weight times value;</li>
 *   <li>features missing from the index contribute nothing.</li>
 * </ul>
 * A score strictly greater than zero selects the internal positive class; anything else selects
 * the internal negative class.
 *
 * @param <L> label type
 * @param <F> feature type
 */
public class LogisticClassifier<L, F> implements Classifier<L, F>, Serializable, RVFClassifier<L, F> {
    private static final long serialVersionUID = 6672245467246897192L;

    /** Fully qualified name of the L1-regularising minimizer, loaded reflectively when needed. */
    private static final String OWLQN_MINIMIZER = "edu.stanford.nlp.optimization.OWLQNMinimizer";

    /** Weight of feature {@code f} is stored at {@code weights[featureIndex.indexOf(f)]}. */
    private double[] weights;
    private Index<F> featureIndex;

    /** {@code classes[0]} is the internal negative label; {@code classes[1]} is the positive one. */
    @SuppressWarnings("unchecked")
    private L[] classes = (L[]) ErasureUtils.mkTArray(Object.class, 2);

    @Deprecated
    private LogPrior prior;

    @Deprecated
    private boolean biased = false;

    /**
     * Renders each weight as {@code "<positive label> / <feature> = <weight>"}.
     *
     * <p>Entries are appended back to back, with no separator. Returns the empty string when no
     * feature index has been set (for example, before training).
     */
    @Override
    public String toString() {
        if (featureIndex == null) {
            return "";
        }
        StringBuilder sb = new StringBuilder();
        for (F f : featureIndex) {
            sb.append(classes[1]).append(" / ").append(f).append(" = ").append(weights[featureIndex.indexOf(f)]);
        }
        return sb.toString();
    }

    /** Returns the label that a positive score maps to. */
    public L getLabelForInternalPositiveClass() {
        return classes[1];
    }

    /** Returns the label that a non-positive score maps to. */
    public L getLabelForInternalNegativeClass() {
        return classes[0];
    }

    /** Returns every weight keyed by {@code "<positive label> / <feature>"}. */
    public Counter<String> weightsAsCounter() {
        Counter<String> counter = new ClassicCounter<>();
        for (F f : featureIndex) {
            counter.incrementCount(classes[1] + " / " + f, weights[featureIndex.indexOf(f)]);
        }
        return counter;
    }

    /** Returns the non-zero weights keyed by the feature itself. */
    public Counter<F> weightsAsGenericCounter() {
        Counter<F> counter = new ClassicCounter<>();
        for (F f : featureIndex) {
            double weight = weights[featureIndex.indexOf(f)];
            if (weight != 0.0) {
                counter.setCount(f, weight);
            }
        }
        return counter;
    }

    public Index<F> getFeatureIndex() {
        return featureIndex;
    }

    /** Returns the live weight array (not a copy). */
    public double[] getWeights() {
        return weights;
    }

    /**
     * Builds an already-trained classifier.
     *
     * @param weights weight vector aligned with {@code featureIndex}; stored without copying
     * @param featureIndex maps features to positions in {@code weights}
     * @param classes {@code classes[0]} is the negative label and {@code classes[1]} the positive
     *     one; the array is stored without copying
     */
    public LogisticClassifier(double[] weights, Index<F> featureIndex, L[] classes) {
        this.weights = weights;
        this.featureIndex = featureIndex;
        this.classes = classes;
    }

    /** @deprecated Use {@link LogisticClassifierFactory}. Uses a quadratic {@link LogPrior}. */
    @Deprecated
    public LogisticClassifier(boolean biased) {
        this(new LogPrior(LogPrior.LogPriorType.QUADRATIC), biased);
    }

    /** @deprecated Use {@link LogisticClassifierFactory}. */
    @Deprecated
    public LogisticClassifier(LogPrior prior) {
        this.prior = prior;
    }

    /** @deprecated Use {@link LogisticClassifierFactory}. */
    @Deprecated
    public LogisticClassifier(LogPrior prior, boolean biased) {
        this.prior = prior;
        this.biased = biased;
    }

    /** Returns a new mutable list holding the negative label followed by the positive label. */
    @Override
    public Collection<L> labels() {
        Collection<L> labels = new LinkedList<>();
        labels.add(classes[0]);
        labels.add(classes[1]);
        return labels;
    }

    /** Classifies a datum; {@link RVFDatum}s are scored with their real-valued feature counts. */
    @Override
    public L classOf(Datum<L, F> example) {
        if (example instanceof RVFDatum<?, ?>) {
            return classOfRVFDatum((RVFDatum<L, F>) example);
        }
        return classOf(example.asFeatures());
    }

    @Override
    @Deprecated
    public L classOf(RVFDatum<L, F> example) {
        return classOfRVFDatum(example);
    }

    private L classOfRVFDatum(RVFDatum<L, F> example) {
        return classOf(example.asFeaturesCounter());
    }

    /** Returns the positive label iff the weighted score of {@code features} is greater than 0. */
    public L classOf(Counter<F> features) {
        return scoreOf(features) > 0.0 ? classes[1] : classes[0];
    }

    /** Returns the positive label iff the summed weight of {@code features} is greater than 0. */
    public L classOf(Collection<F> features) {
        return scoreOf(features) > 0.0 ? classes[1] : classes[0];
    }

    /**
     * Sums the weights of the given binary features.
     *
     * <p>Unknown features are skipped. A feature listed more than once is counted once per
     * occurrence.
     */
    public double scoreOf(Collection<F> features) {
        double score = 0.0;
        for (F f : features) {
            int index = featureIndex.indexOf(f);
            if (index >= 0) {
                score += weights[index];
            }
        }
        return score;
    }

    /** Returns the sum of weight times count over the known features in {@code features}. */
    public double scoreOf(Counter<F> features) {
        double score = 0.0;
        for (F f : features.keySet()) {
            int index = featureIndex.indexOf(f);
            if (index >= 0) {
                score += weights[index] * features.getCount(f);
            }
        }
        return score;
    }

    /** Returns each known feature's contribution (weight times count) to the score. */
    public Counter<F> justificationOf(Counter<F> features) {
        Counter<F> contributions = new ClassicCounter<>();
        for (F f : features.keySet()) {
            int index = featureIndex.indexOf(f);
            if (index >= 0) {
                contributions.incrementCount(f, weights[index] * features.getCount(f));
            }
        }
        return contributions;
    }

    /**
     * Returns each known feature's contribution to the score.
     *
     * <p>A feature listed more than once accumulates its weight once per occurrence.
     */
    public Counter<F> justificationOf(Collection<F> features) {
        Counter<F> contributions = new ClassicCounter<>();
        for (F f : features) {
            int index = featureIndex.indexOf(f);
            if (index >= 0) {
                contributions.incrementCount(f, weights[index]);
            }
        }
        return contributions;
    }

    /** Returns {@code -score} for the negative label and {@code +score} for the positive one. */
    @Override
    public Counter<L> scoresOf(Datum<L, F> example) {
        if (example instanceof RVFDatum<?, ?>) {
            return scoresOfRVFDatum((RVFDatum<L, F>) example);
        }
        return symmetricScores(scoreOf(example.asFeatures()));
    }

    @Override
    @Deprecated
    public Counter<L> scoresOf(RVFDatum<L, F> example) {
        return scoresOfRVFDatum(example);
    }

    private Counter<L> scoresOfRVFDatum(RVFDatum<L, F> example) {
        return symmetricScores(scoreOf(example.asFeaturesCounter()));
    }

    /** Packs a raw score into a counter: negative label gets {@code -score}, positive {@code score}. */
    private Counter<L> symmetricScores(double score) {
        Counter<L> scores = new ClassicCounter<>();
        scores.setCount(classes[0], -score);
        scores.setCount(classes[1], score);
        return scores;
    }

    /** Returns the model probability of the datum's own label given its features. */
    public double probabilityOf(Datum<L, F> example) {
        if (example instanceof RVFDatum<?, ?>) {
            return probabilityOfRVFDatum((RVFDatum<L, F>) example);
        }
        return probabilityOf(example.asFeatures(), example.label());
    }

    /**
     * Returns the logistic probability of {@code label} given binary {@code features}.
     *
     * <p>Any label not equal to the negative class is treated as the positive class.
     */
    public double probabilityOf(Collection<F> features, L label) {
        return logisticProbability(label, scoreOf(features));
    }

    public double probabilityOf(RVFDatum<L, F> example) {
        return probabilityOfRVFDatum(example);
    }

    private double probabilityOfRVFDatum(RVFDatum<L, F> example) {
        return probabilityOf(example.asFeaturesCounter(), example.label());
    }

    /**
     * Returns the logistic probability of {@code label} given real-valued {@code features}.
     *
     * <p>Any label not equal to the negative class is treated as the positive class.
     */
    public double probabilityOf(Counter<F> features, L label) {
        return logisticProbability(label, scoreOf(features));
    }

    /**
     * Returns {@code 1 / (1 + exp(sign * score))}, where {@code sign} is {@code +1} for the
     * negative label and {@code -1} otherwise.
     */
    private double logisticProbability(L label, double score) {
        double sign = label.equals(classes[0]) ? 1.0 : -1.0;
        return 1.0 / (1.0 + Math.exp(sign * score));
    }

    /**
     * Trains with per-datum weights using QN minimisation (tolerance 1e-4) and this classifier's
     * prior.
     *
     * @deprecated Use {@link LogisticClassifierFactory}.
     * @throws RuntimeException if the dataset does not have exactly two labels
     * @throws IllegalArgumentException if the dataset is neither a {@link Dataset} nor an
     *     {@link RVFDataset}
     */
    @Deprecated
    public void trainWeightedData(GeneralDataset<L, F> data, float[] dataWeights) {
        requireBinary(data);
        LogisticObjectiveFunction objective;
        if (data instanceof Dataset<?, ?>) {
            objective = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getLabelsArray(), prior, dataWeights);
        } else if (data instanceof RVFDataset<?, ?>) {
            objective = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getValuesArray(), data.getLabelsArray(), prior, dataWeights);
        } else {
            throw unsupportedDataset(data);
        }
        weights = new QNMinimizer(objective).minimize(objective, 1.0e-4, new double[data.numFeatureTypes()]);
        adoptIndices(data);
    }

    /** @deprecated Use {@link LogisticClassifierFactory}. Equivalent to {@code train(data, 0.0, 1e-4)}. */
    @Deprecated
    public void train(GeneralDataset<L, F> data) {
        train(data, 0.0, 1.0e-4);
    }

    /**
     * Fits the weights to {@code data}.
     *
     * <p>When {@code l1reg > 0} an OWL-QN minimizer is loaded reflectively with that
     * regularisation strength; otherwise plain QN is used. If this classifier was built as
     * {@code biased}, a {@link BiasedLogisticObjectiveFunction} is minimised instead of the
     * standard objective.
     *
     * @deprecated Use {@link LogisticClassifierFactory}.
     * @throws RuntimeException if the dataset does not have exactly two labels
     * @throws IllegalArgumentException if unbiased training is requested on a dataset that is
     *     neither a {@link Dataset} nor an {@link RVFDataset}
     */
    @Deprecated
    public void train(GeneralDataset<L, F> data, double l1reg, double tol) {
        requireBinary(data);
        if (biased) {
            BiasedLogisticObjectiveFunction objective = new BiasedLogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getLabelsArray(), prior);
            Minimizer<? super BiasedLogisticObjectiveFunction> minimizer;
            if (l1reg > 0.0) {
                minimizer = ReflectionLoading.loadByReflection(OWLQN_MINIMIZER, l1reg);
            } else {
                minimizer = new QNMinimizer(objective);
            }
            weights = minimizer.minimize(objective, tol, new double[data.numFeatureTypes()]);
        } else {
            LogisticObjectiveFunction objective;
            if (data instanceof Dataset<?, ?>) {
                objective = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getLabelsArray(), prior);
            } else if (data instanceof RVFDataset<?, ?>) {
                objective = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getValuesArray(), data.getLabelsArray(), prior);
            } else {
                throw unsupportedDataset(data);
            }
            Minimizer<? super LogisticObjectiveFunction> minimizer;
            if (l1reg > 0.0) {
                minimizer = ReflectionLoading.loadByReflection(OWLQN_MINIMIZER, l1reg);
            } else {
                minimizer = new QNMinimizer(objective);
            }
            weights = minimizer.minimize(objective, tol, new double[data.numFeatureTypes()]);
        }
        adoptIndices(data);
    }

    /** Rejects datasets that do not have exactly two labels. */
    private static void requireBinary(GeneralDataset<?, ?> data) {
        if (data.labelIndex.size() != 2) {
            throw new RuntimeException("LogisticClassifier is only for binary classification!");
        }
    }

    /** Builds the error for dataset types that have no matching objective-function constructor. */
    private static IllegalArgumentException unsupportedDataset(GeneralDataset<?, ?> data) {
        return new IllegalArgumentException("LogisticClassifier cannot train on dataset type " + data.getClass().getName());
    }

    /** Copies the feature index and both labels from {@code data} after training. */
    private void adoptIndices(GeneralDataset<L, F> data) {
        featureIndex = data.featureIndex;
        classes[0] = data.labelIndex.get(0);
        classes[1] = data.labelIndex.get(1);
    }

    /**
     * Command-line entry point.
     *
     * <p>Trains on {@code -trainFile}, where each line holds a label followed by
     * whitespace-separated features. It then prints {@code "<predicted label>\t<line>"} for every
     * line of {@code -testFile}, whose lines use the same layout.
     *
     * <p>Optional flags: {@code -l1reg} (default {@code 0.0}) and {@code -biased true}.
     */
    public static void main(String[] args) throws Exception {
        Properties props = StringUtils.argsToProperties(args);
        double l1reg = Double.parseDouble(props.getProperty("l1reg", "0.0"));

        Dataset<String, String> trainData = new Dataset<>();
        for (String line : ObjectBank.getLineIterator(new File(props.getProperty("trainFile")))) {
            String[] bits = line.split("\\s+");
            trainData.add(featuresOf(bits), bits[0]);
        }
        trainData.summaryStatistics();

        boolean biased = props.getProperty("biased", "false").equals("true");
        LogisticClassifier<String, String> classifier =
            new LogisticClassifierFactory<String, String>().trainClassifier(trainData, l1reg, 1.0e-4, biased);

        for (String line : ObjectBank.getLineIterator(new File(props.getProperty("testFile")))) {
            String[] bits = line.split("\\s+");
            System.out.println(classifier.classOf(featuresOf(bits)) + "\t" + line);
        }
    }

    /** Returns every token after the first (the label) as a mutable feature list. */
    private static Collection<String> featuresOf(String[] bits) {
        return new LinkedList<>(Arrays.asList(bits).subList(1, bits.length));
    }
}
