package cc.mallet.classify;

import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import cc.mallet.types.Multinomial;
import cc.mallet.types.RankedFeatureVector;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;
import org.apache.commons.lang3.StringUtils;

/* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/classify/NaiveBayes.class */
/**
 * Multinomial naive Bayes classifier.
 *
 * <p>Each class is scored as its log prior plus, for every non-zero feature of the instance's
 * {@link FeatureVector}, the feature value times that class's log feature probability. The
 * scores are then normalized into a {@link LabelVector}.
 */
public class NaiveBayes extends Classifier implements Serializable {
    /** Log prior over classes; added to every class score in {@link #classify}. */
    Multinomial.Logged prior;

    /**
     * Per-class log multinomials over data-alphabet feature indices, indexed by class index.
     * The scoring code tolerates this array being shorter than the current label alphabet, and
     * each multinomial being smaller than the current data alphabet (both alphabets may grow
     * after training).
     */
    Multinomial.Logged[] p;

    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 1;

    /**
     * Creates a classifier from already-logged multinomials.
     *
     * @param pipe the pipe associated with this classifier
     * @param logged the log class prior
     * @param loggedArr the per-class log feature multinomials, indexed by class index
     */
    public NaiveBayes(Pipe pipe, Multinomial.Logged logged, Multinomial.Logged[] loggedArr) {
        super(pipe);
        this.prior = logged;
        this.p = loggedArr;
    }

    /** Wraps each multinomial in a {@link Multinomial.Logged}, preserving array order. */
    private static Multinomial.Logged[] logMultinomials(Multinomial[] multinomials) {
        Multinomial.Logged[] logged = new Multinomial.Logged[multinomials.length];
        for (int ci = 0; ci < multinomials.length; ci++) {
            logged[ci] = new Multinomial.Logged(multinomials[ci]);
        }
        return logged;
    }

    /**
     * Creates a classifier from plain multinomials, converting each to its logged form.
     *
     * @param pipe the pipe associated with this classifier
     * @param multinomial the class prior
     * @param multinomialArr the per-class feature multinomials, indexed by class index
     */
    public NaiveBayes(Pipe pipe, Multinomial multinomial, Multinomial[] multinomialArr) {
        this(pipe, new Multinomial.Logged(multinomial), logMultinomials(multinomialArr));
    }

    /** Returns the per-class log feature multinomials (the internal array, not a copy). */
    public Multinomial.Logged[] getMultinomials() {
        return this.p;
    }

    /** Returns the log class prior (the internal object, not a copy). */
    public Multinomial.Logged getPriors() {
        return this.prior;
    }

    /**
     * Prints, for each class that has a trained multinomial, the first {@code i} features in
     * {@link RankedFeatureVector} rank order, each with the value accumulated for it by
     * {@code addProbabilities}. At most the whole data alphabet is printed per class.
     *
     * @param i maximum number of features to print per class
     */
    public void printWords(int i) {
        Alphabet dataAlphabet = this.instancePipe.getDataAlphabet();
        Alphabet targetAlphabet = this.instancePipe.getTargetAlphabet();
        int numFeatures = dataAlphabet.size();
        // Labels added to the target alphabet after training have no multinomial in p.
        int numClasses = Math.min(targetAlphabet.size(), this.p.length);
        int numToPrint = Math.min(i, numFeatures);
        double[] probabilities = new double[numFeatures];
        for (int ci = 0; ci < numClasses; ci++) {
            Arrays.fill(probabilities, 0.0d);
            this.p[ci].addProbabilities(probabilities);
            RankedFeatureVector ranked = new RankedFeatureVector(dataAlphabet, probabilities);
            System.out.println("\nFeature probabilities " + targetAlphabet.lookupObject(ci));
            for (int rank = 0; rank < numToPrint; rank++) {
                System.out.println(ranked.getObjectAtRank(rank) + StringUtils.SPACE + ranked.getValueAtRank(rank));
            }
        }
    }

    /**
     * Computes the posterior distribution over the current label alphabet for an instance
     * whose data is a {@link FeatureVector}.
     *
     * <p>Feature indices beyond a class multinomial's size are ignored. Labels beyond the
     * trained multinomials receive probability 0.
     *
     * @param instance the instance to classify; its data must be a {@code FeatureVector}
     * @return the classification carrying the normalized label distribution
     */
    @Override // cc.mallet.classify.Classifier
    public Classification classify(Instance instance) {
        int numClasses = getLabelAlphabet().size();
        // The label alphabet may have grown since training; only the first p.length classes
        // have a trained multinomial.
        int numTrainedClasses = Math.min(numClasses, this.p.length);
        double[] scores = new double[numClasses];
        FeatureVector featureVector = (FeatureVector) instance.getData();
        assert this.instancePipe == null
                || featureVector.getAlphabet() == this.instancePipe.getDataAlphabet();

        this.prior.addLogProbabilities(scores);
        // Untrained classes never accumulate the (negative) per-feature log probabilities
        // that trained classes do, so left as-is they could outscore every trained class.
        // Give them zero probability instead.
        Arrays.fill(scores, numTrainedClasses, numClasses, Double.NEGATIVE_INFINITY);

        int numLocations = featureVector.numLocations();
        for (int loc = 0; loc < numLocations; loc++) {
            int featureIndex = featureVector.indexAtLocation(loc);
            double featureValue = featureVector.valueAtLocation(loc);
            for (int ci = 0; ci < numTrainedClasses; ci++) {
                // Ignore features added to the data alphabet after training.
                if (featureIndex < this.p[ci].size()) {
                    scores[ci] += featureValue * this.p[ci].logProbability(featureIndex);
                }
            }
        }
        normalizeLogScores(scores);
        return new Classification(instance, this, new LabelVector(getLabelAlphabet(), scores));
    }

    /**
     * Turns unnormalized log scores, in place, into probabilities that sum to 1.
     *
     * <p>The maximum is subtracted before exponentiating to avoid overflow. If every score is
     * negative infinity, the subtraction would yield NaN everywhere, so a uniform distribution
     * is produced instead.
     */
    private static void normalizeLogScores(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double score : scores) {
            if (score > max) {
                max = score;
            }
        }
        if (max == Double.NEGATIVE_INFINITY) {
            Arrays.fill(scores, 1.0d / scores.length);
            return;
        }
        double sum = 0.0d;
        for (int ci = 0; ci < scores.length; ci++) {
            scores[ci] = Math.exp(scores[ci] - max);
            sum += scores[ci];
        }
        for (int ci = 0; ci < scores.length; ci++) {
            scores[ci] /= sum;
        }
    }

    /**
     * Returns the sum over the instance's feature locations of feature value times the given
     * class's log feature probability. Feature indices beyond that class's multinomial are
     * ignored, matching {@link #classify}.
     */
    private double dataLogProbability(Instance instance, int i) {
        FeatureVector featureVector = (FeatureVector) instance.getData();
        Multinomial.Logged classMultinomial = this.p[i];
        int classSize = classMultinomial.size();
        int numLocations = featureVector.numLocations();
        double logProbability = 0.0d;
        for (int loc = 0; loc < numLocations; loc++) {
            int featureIndex = featureVector.indexAtLocation(loc);
            if (featureIndex < classSize) {
                logProbability += featureVector.valueAtLocation(loc) * classMultinomial.logProbability(featureIndex);
            }
        }
        return logProbability;
    }

    /**
     * Computes the weighted data log likelihood of an instance list under this model.
     *
     * <p>A labeled instance contributes its weight times the data log probability under its
     * best label. An unlabeled instance contributes, for each class with non-zero posterior
     * from {@link #classify}, its weight times that posterior times the class's data log
     * probability.
     *
     * @param instanceList the instances to score
     * @return the summed weighted data log likelihood
     */
    public double dataLogLikelihood(InstanceList instanceList) {
        double total = 0.0d;
        for (int idx = 0; idx < instanceList.size(); idx++) {
            double weight = instanceList.getInstanceWeight(idx);
            Instance instance = instanceList.get(idx);
            Labeling labeling = instance.getLabeling();
            if (labeling != null) {
                total += weight * dataLogProbability(instance, labeling.getBestIndex());
                continue;
            }
            Labeling posterior = classify(instance).getLabeling();
            for (int loc = 0; loc < posterior.numLocations(); loc++) {
                double mass = posterior.valueAtLocation(loc);
                if (mass != 0.0d) {
                    total += weight * mass * dataLogProbability(instance, posterior.indexAtLocation(loc));
                }
            }
        }
        return total;
    }

    /**
     * Computes the weighted log likelihood of the instances' labels under this classifier's
     * posterior.
     *
     * <p>For a single-location labeling, the log posterior of its best label is used. For a
     * multi-location labeling, each label's log posterior is weighted by its labeling value.
     * Unlabeled instances are skipped.
     *
     * @param instanceList the instances to score
     * @return the summed weighted label log likelihood
     */
    public double labelLogLikelihood(InstanceList instanceList) {
        double total = 0.0d;
        for (int idx = 0; idx < instanceList.size(); idx++) {
            double weight = instanceList.getInstanceWeight(idx);
            Instance instance = instanceList.get(idx);
            Labeling labeling = instance.getLabeling();
            if (labeling == null) {
                continue;
            }
            Labeling posterior = classify(instance).getLabeling();
            if (labeling.numLocations() == 1) {
                total += weight * Math.log(posterior.value(labeling.getBestIndex()));
            } else {
                for (int loc = 0; loc < labeling.numLocations(); loc++) {
                    double labelValue = labeling.valueAtLocation(loc);
                    if (labelValue != 0.0d) {
                        total += weight * labelValue * Math.log(posterior.value(labeling.indexAtLocation(loc)));
                    }
                }
            }
        }
        return total;
    }

    /** Writes the serial version, the instance pipe, the prior and the per-class multinomials. */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.writeInt(CURRENT_SERIAL_VERSION);
        objectOutputStream.writeObject(getInstancePipe());
        objectOutputStream.writeObject(this.prior);
        objectOutputStream.writeObject(this.p);
    }

    /**
     * Reads the fields in the order written by {@code writeObject}.
     *
     * @throws ClassNotFoundException if the stored serial version does not match
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        int version = objectInputStream.readInt();
        if (version != CURRENT_SERIAL_VERSION) {
            throw new ClassNotFoundException("Mismatched NaiveBayes versions: wanted " + CURRENT_SERIAL_VERSION + ", got " + version);
        }
        this.instancePipe = (Pipe) objectInputStream.readObject();
        this.prior = (Multinomial.Logged) objectInputStream.readObject();
        this.p = (Multinomial.Logged[]) objectInputStream.readObject();
    }
}
