package de.dfki.km.perspecting.obie.model;

import cc.mallet.classify.Classification;
import cc.mallet.classify.MaxEnt;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.classify.Trial;
import cc.mallet.classify.evaluate.ConfusionMatrix;
import cc.mallet.pipe.CharSequence2TokenSequence;
import cc.mallet.pipe.FeatureCountPipe;
import cc.mallet.pipe.FeatureSequence2FeatureVector;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Target2Label;
import cc.mallet.pipe.TokenSequence2FeatureSequence;
import cc.mallet.pipe.TokenSequenceLowercase;
import cc.mallet.pipe.TokenSequenceRemoveStopwords;
import cc.mallet.pipe.iterator.CsvIterator;
import cc.mallet.types.Alphabet;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.RankedFeatureVector;
import java.io.File;
import java.io.PrintWriter;
import java.io.Reader;
import java.io.StringReader;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import java.util.regex.Pattern;

/* loaded from: input_file:de/dfki/km/perspecting/obie/model/EntityClassifier.class */
/**
 * Thin wrapper around a MALLET {@link MaxEnt} classifier that trains, evaluates and
 * inspects a model built from line-oriented text input.
 *
 * <p>Every input line is matched against {@link #CSV_LINE_PATTERN}: the first
 * whitespace/comma separated token is used as the instance name, the second as the
 * target label and the remainder of the line as the instance data.
 *
 * <p>Evaluation results are reported as {@code Double[4]} arrays per label, holding
 * precision, recall, F1 and accuracy at indices 0 to 3.
 */
public class EntityClassifier {
    private static final Logger log = Logger.getLogger(EntityClassifier.class.getName());

    /** Line format shared by all readers: group 1 = name, group 2 = label, group 3 = data. */
    private static final Pattern CSV_LINE_PATTERN = Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$");
    private static final int DATA_GROUP = 3;
    private static final int TARGET_GROUP = 2;
    private static final int NAME_GROUP = 1;

    /** Number of folds used by cross validation; also the divisor used to average fold metrics. */
    private static final int CROSS_VALIDATION_FOLDS = 10;

    /** Slots in a per-label metrics array: precision, recall, F1, accuracy. */
    private static final int METRIC_COUNT = 4;

    private MaxEnt classifier;

    /**
     * Creates a wrapper around an existing classifier.
     *
     * @param maxEnt the classifier to wrap; replaced by {@link #train} and {@link #evaluate}
     */
    public EntityClassifier(MaxEnt maxEnt) {
        this.classifier = maxEnt;
    }

    /**
     * Counts the features of the given input and writes two word lists via
     * {@link FeatureCountPipe#writeCommonWords} and {@link FeatureCountPipe#writePrunedWords}.
     * The reader is always closed, even when processing fails.
     *
     * @param reader input lines in the {@link #CSV_LINE_PATTERN} format
     * @param file   destination passed to {@code writeCommonWords}
     * @param file2  destination passed to {@code writePrunedWords}
     * @param i      numeric argument passed to {@code writeCommonWords}
     * @param i2     numeric argument passed to {@code writePrunedWords}
     * @throws Exception if reading, piping or writing fails
     */
    public void generateStoppwordLists(Reader reader, File file, File file2, int i, int i2) throws Exception {
        try (Reader input = reader) {
            InstanceList instances = new InstanceList(buildFeatureSequencePipe());
            instances.addThruPipe(new CsvIterator(input, CSV_LINE_PATTERN, DATA_GROUP, TARGET_GROUP, NAME_GROUP));
            FeatureCountPipe counter = new FeatureCountPipe(instances.getDataAlphabet(), instances.getTargetAlphabet());
            // Running the instances through the counting pipe is what populates the counts.
            new InstanceList(counter).addThruPipe(instances.iterator());
            counter.writeCommonWords(file, i);
            counter.writePrunedWords(file2, i2);
        }
    }

    /**
     * Reads all lines from {@code reader} through {@code pipe} and reports the number of
     * instances and features on standard output. The reader is not closed here.
     *
     * @param reader input lines in the {@link #CSV_LINE_PATTERN} format
     * @param file   unused; kept for signature compatibility
     * @param file2  unused; kept for signature compatibility
     * @param pipe   pipe used to turn lines into instances
     * @return the populated instance list
     */
    private InstanceList createInstanceList(Reader reader, File file, File file2, Pipe pipe) throws Exception {
        InstanceList instances = new InstanceList(pipe);
        instances.addThruPipe(new CsvIterator(reader, CSV_LINE_PATTERN, DATA_GROUP, TARGET_GROUP, NAME_GROUP));
        System.out.println("Read " + instances.size() + " instances");
        System.out.println("Feature Count: " + instances.getDataAlphabet().size());
        return instances;
    }

    /**
     * Classifies the content of an internally created buffer with the current classifier.
     *
     * <p>NOTE(review): neither {@code tokenSequence} nor {@code tokenSequence2} is read and
     * the buffer is never written to, so the classifier is always applied to empty input.
     * The code that serialised the token sequences appears to be missing - confirm
     * against the original sources before relying on this method.
     *
     * @param tokenSequence  currently unused
     * @param tokenSequence2 currently unused
     * @return the classifications produced for the (empty) buffer content
     * @throws Exception if building the instance list fails
     */
    public List<Classification> test(TokenSequence<String> tokenSequence, TokenSequence<Integer> tokenSequence2) throws Exception {
        StringWriter buffer = new StringWriter();
        System.out.println(this.classifier.getLabelAlphabet().lookupLabel(0).toString());
        InstanceList instances =
                createInstanceList(new StringReader(buffer.toString()), null, null, this.classifier.getInstancePipe());
        return this.classifier.classify(instances);
    }

    /**
     * Trains a new classifier on the given input and stores it in this wrapper.
     * The reader is always closed, even when training fails.
     *
     * @param reader training lines in the {@link #CSV_LINE_PATTERN} format
     * @param file   optional stop word file, may be {@code null}
     * @param file2  optional second stop word file, may be {@code null}
     * @return the newly trained classifier
     * @throws Exception if reading or training fails
     */
    public MaxEnt train(Reader reader, File file, File file2) throws Exception {
        try (Reader input = reader) {
            InstanceList instances = createInstanceList(input, file, file2, buildPipe(file, file2));
            this.classifier = new MaxEntTrainer().train(instances);
        }
        return this.classifier;
    }

    /**
     * Trains and evaluates a classifier, returning per-label metrics.
     *
     * <p>If {@code reader2} is given, its content is used as test data. Otherwise either
     * cross validation is run ({@code z == true}) or the training data itself is used as
     * test data. {@code reader2} is closed by this method; {@code reader} is left open
     * for the caller.
     *
     * @param reader  training lines in the {@link #CSV_LINE_PATTERN} format
     * @param file    optional stop word file, may be {@code null}
     * @param file2   optional second stop word file, may be {@code null}
     * @param reader2 optional test lines, may be {@code null}
     * @param z       whether to cross validate when no test data is given
     * @param d       proportion of the training data actually used for training
     * @return map from label to {precision, recall, F1, accuracy}
     * @throws Exception if reading or training fails
     */
    public Map<String, Double[]> evaluate(Reader reader, File file, File file2, Reader reader2, boolean z, double d) throws Exception {
        Pipe pipe = buildPipe(file, file2);
        InstanceList trainingInstances = createInstanceList(reader, file, file2, pipe);

        InstanceList testInstances = null;
        if (reader2 != null) {
            try (Reader testInput = reader2) {
                testInstances = createInstanceList(testInput, null, null, pipe);
            }
        }

        Map<String, Double[]> results = new HashMap<>();
        if (testInstances != null) {
            this.classifier = evaluateWithTestData(trainingInstances, testInstances, results, d);
        } else if (z) {
            this.classifier = evaluateCrossValidation(trainingInstances, results);
        } else {
            // No separate test set: the evaluation runs on the full training list.
            this.classifier = evaluateWithTestData(trainingInstances, trainingInstances, results, d);
        }
        return results;
    }

    /**
     * Returns the metrics array registered for {@code label}, creating and registering a
     * zero-filled one on first use so that accumulated values are never lost.
     */
    private static Double[] metricsFor(Map<String, Double[]> results, Object label) {
        String key = label.toString();
        Double[] metrics = results.get(key);
        if (metrics == null) {
            metrics = new Double[METRIC_COUNT];
            for (int slot = 0; slot < METRIC_COUNT; slot++) {
                metrics[slot] = Double.valueOf(0.0d);
            }
            results.put(key, metrics);
        }
        return metrics;
    }

    /**
     * Trains on the first part of a {@code d} / {@code 1 - d} split of {@code instanceList}
     * and adds the metrics measured on {@code instanceList2} to {@code results}.
     *
     * @return the trained classifier
     */
    private MaxEnt evaluateWithTestData(InstanceList instanceList, InstanceList instanceList2, Map<String, Double[]> results, double d) {
        InstanceList trainingPart = instanceList.split(new double[]{d, 1.0d - d})[0];
        System.out.println("Read " + trainingPart.size() + " instances");
        System.out.println("Feature Count: " + trainingPart.getDataAlphabet().size());
        this.classifier = new MaxEntTrainer().train(trainingPart);
        System.out.println(new ConfusionMatrix(new Trial(this.classifier, instanceList2)).toString());

        // Accuracy does not depend on the label, so compute it once.
        double accuracy = this.classifier.getAccuracy(instanceList2);
        for (Object label : this.classifier.getLabelAlphabet().toArray()) {
            Double[] metrics = metricsFor(results, label);
            metrics[0] += this.classifier.getPrecision(instanceList2, label);
            metrics[1] += this.classifier.getRecall(instanceList2, label);
            metrics[2] += this.classifier.getF1(instanceList2, label);
            metrics[3] += accuracy;
        }
        return this.classifier;
    }

    /**
     * Runs {@link #CROSS_VALIDATION_FOLDS}-fold cross validation and adds the fold-averaged
     * metrics to {@code results}.
     *
     * @return the classifier trained on the last fold
     */
    private MaxEnt evaluateCrossValidation(InstanceList instanceList, Map<String, Double[]> results) {
        InstanceList.CrossValidationIterator folds = instanceList.crossValidationIterator(CROSS_VALIDATION_FOLDS);
        while (folds.hasNext()) {
            InstanceList[] fold = folds.next();
            InstanceList trainingPart = fold[0];
            InstanceList testPart = fold[1];

            System.out.println("Cross Training on " + trainingPart.size() + " instances");
            this.classifier = new MaxEntTrainer().train(trainingPart);
            System.out.println("Cross Testing on " + testPart.size() + " instances");

            double accuracy = this.classifier.getAccuracy(testPart);
            for (Object label : this.classifier.getLabelAlphabet().toArray()) {
                Double[] metrics = metricsFor(results, label);
                // Each fold contributes 1/folds of its value, yielding the mean over all folds.
                metrics[0] += this.classifier.getPrecision(testPart, label) / CROSS_VALIDATION_FOLDS;
                metrics[1] += this.classifier.getRecall(testPart, label) / CROSS_VALIDATION_FOLDS;
                metrics[2] += this.classifier.getF1(testPart, label) / CROSS_VALIDATION_FOLDS;
                metrics[3] += accuracy / CROSS_VALIDATION_FOLDS;
            }
        }
        return this.classifier;
    }

    /**
     * Prints, for every label, the {@code i} highest ranked feature weights of the
     * current classifier.
     *
     * @param printWriter destination; flushed but not closed
     * @param i           maximum number of features to print per label
     */
    public void printRank(PrintWriter printWriter, int i) {
        Alphabet features = this.classifier.getAlphabet();
        LabelAlphabet labels = this.classifier.getLabelAlphabet();
        // Parameters are read with a per-label stride of one more than the feature count.
        int stride = features.size() + 1;
        int labelCount = labels.size();
        int defaultFeatureIndex = this.classifier.getDefaultFeatureIndex();
        double[] parameters = this.classifier.getParameters();
        double[] weights = new double[stride - 1];

        for (int label = 0; label < labelCount; label++) {
            printWriter.println();
            printWriter.println("FEATURES FOR CLASS " + labels.lookupObject(label) + " ");
            for (int feature = 0; feature < defaultFeatureIndex; feature++) {
                weights[feature] = parameters[(label * stride) + feature];
            }
            printTopK(new RankedFeatureVector(features, weights), printWriter, i);
        }
        printWriter.println();
        printWriter.flush();
    }

    /**
     * Prints the first {@code i} entries of the ranked vector, or all of them if it has
     * fewer, as {@code feature<TAB>weight} lines.
     */
    private void printTopK(RankedFeatureVector rankedFeatureVector, PrintWriter printWriter, int i) {
        int limit = Math.min(i, rankedFeatureVector.numLocations());
        for (int rank = 0; rank < limit; rank++) {
            Object feature = rankedFeatureVector.getAlphabet().lookupObject(rankedFeatureVector.getIndexAtRank(rank));
            String weight = String.format("%f", Double.valueOf(rankedFeatureVector.getValueAtRank(rank)));
            printWriter.println(feature + "\t" + weight);
        }
    }

    /**
     * Builds the pipe used for feature counting: tokenise on non-whitespace runs,
     * lowercase, map the target to a label and convert tokens to a feature sequence.
     */
    private Pipe buildFeatureSequencePipe() {
        ArrayList<Pipe> pipes = new ArrayList<>();
        pipes.add(new CharSequence2TokenSequence("[\\S]+"));
        pipes.add(new TokenSequenceLowercase());
        pipes.add(new Target2Label());
        pipes.add(new TokenSequence2FeatureSequence());
        return new SerialPipes(pipes);
    }

    /**
     * Builds the pipe used for training and evaluation. It extends the feature counting
     * pipe with optional stop word removal and a final conversion to feature vectors.
     *
     * @param file  optional stop word file, applied second; may be {@code null}
     * @param file2 optional stop word file, applied first; may be {@code null}
     */
    private Pipe buildPipe(File file, File file2) {
        ArrayList<Pipe> pipes = new ArrayList<>();
        pipes.add(new CharSequence2TokenSequence("[\\S]+"));
        pipes.add(new TokenSequenceLowercase());
        pipes.add(new Target2Label());
        if (file2 != null) {
            pipes.add(new TokenSequenceRemoveStopwords(file2, "utf-8", false, true, false));
        }
        if (file != null) {
            pipes.add(new TokenSequenceRemoveStopwords(file, "utf-8", false, true, false));
        }
        pipes.add(new TokenSequence2FeatureSequence());
        pipes.add(new FeatureSequence2FeatureVector());
        return new SerialPipes(pipes);
    }

    /**
     * Returns the wrapped classifier.
     *
     * @return the classifier passed to the constructor or produced by the most recent
     *         call to {@link #train} or {@link #evaluate}
     */
    public MaxEnt getClassifier() {
        return this.classifier;
    }
}
