package cc.mallet.examples;

import cc.mallet.pipe.CharSequence2TokenSequence;
import cc.mallet.pipe.CharSequenceLowercase;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.TokenSequence2FeatureSequence;
import cc.mallet.pipe.TokenSequenceRemoveStopwords;
import cc.mallet.pipe.iterator.CsvIterator;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelSequence;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Formatter;
import java.util.Iterator;
import java.util.Locale;
import java.util.TreeSet;
import java.util.regex.Pattern;
import org.apache.commons.lang3.StringUtils;

/* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/examples/TopicModel.class */
/**
 * Command-line example that trains an LDA topic model on a text corpus and prints a short report.
 *
 * <p>The program performs these steps:
 * <ol>
 *   <li>builds a preprocessing pipeline (lowercase, tokenize, remove stopwords, map to features);</li>
 *   <li>reads the corpus named by the first command-line argument, one document per line;</li>
 *   <li>trains a {@link ParallelTopicModel} with {@link #NUM_TOPICS} topics;</li>
 *   <li>prints the word/topic assignment of every token of the first document;</li>
 *   <li>prints, for every topic, its proportion in the first document and its top words;</li>
 *   <li>builds a test document from the top words of topic 0 and prints the inferred
 *       probability of topic 0 for it.</li>
 * </ol>
 */
public class TopicModel {
    /** Number of topics to train; also bounds the per-topic report loop. */
    private static final int NUM_TOPICS = 100;

    /** Maximum number of top-ranked words printed per topic and used for the test document. */
    private static final int TOP_WORDS = 5;

    /**
     * Runs the example.
     *
     * @param args {@code args[0]} is the path of the UTF-8 input file; the stoplist is read from
     *             the relative path {@code stoplists/en.txt}
     * @throws IllegalArgumentException if no input file is given
     * @throws IllegalStateException if no instance could be read from the input file
     * @throws Exception on any I/O or training failure
     */
    public static void main(String[] args) throws Exception {
        if (args == null || args.length < 1) {
            throw new IllegalArgumentException("Usage: TopicModel <input-file>");
        }

        // Preprocessing pipeline. The list is left raw so that no additional MALLET type has to
        // be imported; every element added below is a MALLET pipe.
        ArrayList pipeList = new ArrayList();
        pipeList.add(new CharSequenceLowercase());
        // A token starts and ends with a letter and has letters/punctuation in between,
        // so every token is at least three characters long.
        pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\\p{L}[\\p{L}\\p{P}]+\\p{L}")));
        pipeList.add(new TokenSequenceRemoveStopwords(new File("stoplists/en.txt"), "UTF-8", false, false, false));
        pipeList.add(new TokenSequence2FeatureSequence());

        InstanceList instances = new InstanceList(new SerialPipes(pipeList));

        // addThruPipe drains the iterator before returning, so the stream can be closed
        // immediately afterwards instead of being leaked until JVM exit.
        FileInputStream input = new FileInputStream(new File(args[0]));
        try {
            // Each line is split into three regex groups; the indices (3, 2, 1) are passed in
            // CsvIterator's (data, target, name) order per the MALLET API.
            instances.addThruPipe(new CsvIterator(new InputStreamReader(input, "UTF-8"),
                    Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"), 3, 2, 1));
        } finally {
            input.close();
        }

        // Everything below indexes document 0; fail early with a clear message rather than
        // with an IndexOutOfBoundsException after a pointless training run.
        if (instances.isEmpty()) {
            throw new IllegalStateException("No instances could be read from " + args[0]);
        }

        // Arguments after the topic count are the alpha sum and beta hyperparameters
        // (see the ParallelTopicModel constructor documentation).
        ParallelTopicModel model = new ParallelTopicModel(NUM_TOPICS, 1.0d, 0.01d);
        model.addInstances(instances);
        model.setNumThreads(2);
        model.setNumIterations(50);
        model.estimate();

        Alphabet dataAlphabet = instances.getDataAlphabet();

        // Report 1: "word-topic" pairs for every token of the first document.
        FeatureSequence tokens = (FeatureSequence) model.getData().get(0).instance.getData();
        LabelSequence topics = model.getData().get(0).topicSequence;
        Formatter assignments = new Formatter(new StringBuilder(), Locale.US);
        for (int position = 0; position < tokens.getLength(); position++) {
            assignments.format("%s-%d ", dataAlphabet.lookupObject(tokens.getIndexAtPosition(position)),
                    topics.getIndexAtPosition(position));
        }
        System.out.println(assignments);

        // Report 2: one line per topic - index, proportion in the first document, top words.
        double[] topicDistribution = model.getTopicProbabilities(0);
        ArrayList<TreeSet<IDSorter>> topicSortedWords = model.getSortedWords();
        for (int topic = 0; topic < NUM_TOPICS; topic++) {
            Iterator<IDSorter> words = topicSortedWords.get(topic).iterator();
            Formatter line = new Formatter(new StringBuilder(), Locale.US);
            line.format("%d\t%.3f\t", topic, topicDistribution[topic]);
            for (int rank = 0; words.hasNext() && rank < TOP_WORDS; rank++) {
                IDSorter word = words.next();
                line.format("%s (%.0f) ", dataAlphabet.lookupObject(word.getID()), word.getWeight());
            }
            System.out.println(line);
        }

        // Build a synthetic document out of the top words of topic 0 (space separated,
        // including a trailing space).
        StringBuilder topicZeroText = new StringBuilder();
        Iterator<IDSorter> topicZeroWords = topicSortedWords.get(0).iterator();
        for (int rank = 0; topicZeroWords.hasNext() && rank < TOP_WORDS; rank++) {
            topicZeroText.append(dataAlphabet.lookupObject(topicZeroWords.next().getID())).append(' ');
        }

        // The test document must pass through the same pipe as the training data so that it
        // shares the training alphabet.
        InstanceList testing = new InstanceList(instances.getPipe());
        testing.addThruPipe(new Instance(topicZeroText.toString(), null, "test instance", null));

        // Report 3: inferred probability of topic 0 for the synthetic document. The numeric
        // arguments are passed in the inferencer's (iterations, thinning, burn-in) order.
        double[] testProbabilities = model.getInferencer().getSampledDistribution(testing.get(0), 10, 1, 5);
        System.out.println("0\t" + testProbabilities[0]);
    }
}
