package de.dfki.km.perspecting.obie.model;

import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFTrainerByLabelLikelihood;
import cc.mallet.fst.MultiSegmentationEvaluator;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.TokenSequence2FeatureVectorSequence;
import cc.mallet.pipe.iterator.LineGroupIterator;
import cc.mallet.pipe.tsf.FeaturesInWindow;
import cc.mallet.pipe.tsf.TokenText;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.Sequence;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Reader;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.regex.Pattern;
import org.apache.log4j.Logger;

/* loaded from: input_file:de/dfki/km/perspecting/obie/model/NounPhraseChunker.class */
/**
 * Trains, evaluates and applies a MALLET {@link CRF} sequence labeller. Going by the class name,
 * it is used for noun-phrase chunking.
 *
 * <p>Input files are split into instances at blank (whitespace-only) lines. The per-line format
 * is whatever {@code NounPhraseChunkerPipe} expects (not visible here; NOTE(review): confirm).
 *
 * <p>Typical use: construct with training/test paths, then {@link #init()} and {@link #train()}.
 * Alternatively, construct from a serialized model file and call {@link #test(Reader)}.
 * Instances are not thread-safe (all state lives in mutable fields).
 */
public class NounPhraseChunker {
    private static final Logger log = Logger.getLogger(NounPhraseChunker.class);

    /** Instance separator: a line that is empty or contains only whitespace. */
    private static final Pattern BLANK_LINE = Pattern.compile("^\\s*$");

    /** During training, the evaluator only runs on iterations divisible by this value. */
    private static final int EVALUATION_INTERVAL = 20;

    private String crfFile = "";
    private Pipe p;
    private InstanceList testData;
    private String testFile = "";
    private String trainFile = "";
    private InstanceList trainingData;
    private int windowSize = 3;
    private CRF crf;

    /**
     * Creates a chunker configured for training. No data is read until {@link #init()}.
     *
     * @param str  path of the training file
     * @param str2 path of the test file; if equal to {@code str}, {@link #init()} splits the
     *             training data 80/20 into training and test sets instead
     * @param str3 path the trained model is written to and read from; when {@code null} it
     *             defaults to {@code <str>-<i>.crf}
     * @param i    half-width of the token feature window (features from offsets -i..i)
     */
    public NounPhraseChunker(String str, String str2, String str3, int i) {
        this.trainFile = str;
        this.testFile = str2;
        this.crfFile = str3 != null ? str3 : defaultCrfFile(str, i);
        this.windowSize = i;
    }

    /**
     * Creates a chunker from a serialized model, ready for {@link #test(Reader)}.
     *
     * @param str path of a serialized {@link CRF}; when {@code null}, the default name derived
     *            from the (empty) training file and default window size is used
     * @throws Exception if the model cannot be read or deserialized
     */
    public NounPhraseChunker(String str) throws Exception {
        this.crfFile = str != null ? str : defaultCrfFile(this.trainFile, this.windowSize);
        // Read crfFile rather than str, so that the null fallback above actually takes effect.
        this.crf = readCrf(this.crfFile);
        this.p = this.crf.getInputPipe();
    }

    /** @return path of the serialized model file */
    public String getCrfFile() {
        return this.crfFile;
    }

    /** @return the test instances loaded by {@link #init()}, or {@code null} before that */
    public InstanceList getTestData() {
        return this.testData;
    }

    /** @return path of the test file */
    public String getTestFile() {
        return this.testFile;
    }

    /** @return path of the training file */
    public String getTrainFile() {
        return this.trainFile;
    }

    /** @return the training instances loaded by {@link #init()}, or {@code null} before that */
    public InstanceList getTrainingData() {
        return this.trainingData;
    }

    /** @return half-width of the token feature window */
    public int getWindowSize() {
        return this.windowSize;
    }

    /** @param str path of the serialized model file */
    public void setCrfFile(String str) {
        this.crfFile = str;
    }

    /** @param str path of the test file (takes effect on the next {@link #init()}) */
    public void setTestFile(String str) {
        this.testFile = str;
    }

    /** @param str path of the training file (takes effect on the next {@link #init()}) */
    public void setTrainFile(String str) {
        this.trainFile = str;
    }

    /** @param i half-width of the token feature window (takes effect on the next {@link #init()}) */
    public void setWindowSize(int i) {
        this.windowSize = i;
    }

    /**
     * Reloads the model from {@link #getCrfFile()} and logs its token accuracy on the loaded
     * training and test data (whichever are non-null).
     *
     * @throws Exception if the model cannot be read or deserialized
     */
    public void test() throws Exception {
        CRF model = readCrf(this.crfFile);
        if (this.trainingData != null) {
            log.info("test on training data");
            testCrf(model, this.trainingData);
        }
        if (this.testData != null) {
            log.info("test on test data");
            testCrf(model, this.testData);
        }
    }

    /**
     * Labels every token read from {@code reader} with the current model.
     *
     * @param reader text in the same format as the training files; not closed by this method
     * @return predicted labels of all tokens, instance after instance, in input order
     * @throws IllegalStateException if no model is available (neither {@link #train()} nor the
     *                               model-file constructor has been used)
     * @throws Exception             if reading or piping the input fails
     */
    public List<String> test(Reader reader) throws Exception {
        if (this.crf == null) {
            throw new IllegalStateException(
                    "no CRF model available: call train() first or use NounPhraseChunker(String)");
        }
        return testCrf(getData(reader));
    }

    /**
     * Trains a new CRF on the training data loaded by {@link #init()}, evaluating on training
     * and test data. The model is written to {@link #getCrfFile()} and kept for
     * {@link #test(Reader)}.
     *
     * @throws Exception if training fails or the model cannot be written
     */
    public void train() throws Exception {
        log.info("Instance Data (train): " + this.trainingData.size());
        log.info("Instance Data (test): " + this.testData.size());
        CRF model = new CRF(this.p, (Pipe) null);
        model.addStatesForLabelsConnectedAsIn(this.trainingData);
        CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(model);
        // Every target label is passed as both a segment-start and a segment-continue tag.
        MultiSegmentationEvaluator evaluator = new MultiSegmentationEvaluator(
                new InstanceList[] {this.trainingData, this.testData},
                new String[] {"Training", "Testing"},
                this.p.getTargetAlphabet().toArray(),
                this.p.getTargetAlphabet().toArray()) {
            public boolean precondition(TransducerTrainer transducerTrainer) {
                return transducerTrainer.getIteration() % EVALUATION_INTERVAL == 0;
            }
        };
        trainer.addEvaluator(evaluator);
        trainer.train(this.trainingData, Integer.MAX_VALUE);
        evaluator.evaluate(trainer);
        writeCrf(model, this.crfFile);
        // Keep the trained model so that test(Reader) works without reloading it from disk.
        this.crf = model;
    }

    /** Builds the feature pipeline: chunker tokens, word features, and windowed T=/W= features. */
    private Pipe buildPipe() {
        List<Pipe> pipes = new ArrayList<Pipe>();
        pipes.add(new NounPhraseChunkerPipe());
        pipes.add(new TokenText("W="));
        pipes.add(new FeaturesInWindow("WINDOW=", -this.windowSize, this.windowSize, Pattern.compile("T=.*"), true));
        pipes.add(new FeaturesInWindow("WINDOW=", -this.windowSize, this.windowSize, Pattern.compile("W=.*"), true));
        pipes.add(new TokenSequence2FeatureVectorSequence(true, true));
        return new SerialPipes(pipes);
    }

    /** Reads and pipes all instances from the file at {@code str}, closing the file afterwards. */
    private InstanceList getData(String str) throws Exception {
        Reader reader = new FileReader(new File(str));
        try {
            // addThruPipe consumes the whole iterator, so closing afterwards is safe.
            return getData(reader);
        } finally {
            reader.close();
        }
    }

    /** Reads and pipes all blank-line-separated instances from {@code reader} (left open). */
    private InstanceList getData(Reader reader) throws Exception {
        InstanceList instances = new InstanceList(this.p);
        instances.addThruPipe(new LineGroupIterator(reader, BLANK_LINE, true));
        return instances;
    }

    /**
     * Builds the pipe and loads the data. Training and test files are read separately. If they
     * are the same path, the training data is randomly split 80/20 (unseeded) into training
     * and test sets.
     *
     * @throws Exception if a data file cannot be read
     */
    public void init() throws Exception {
        this.p = buildPipe();
        this.trainingData = getData(this.trainFile);
        if (this.trainFile.equals(this.testFile)) {
            InstanceList[] split = this.trainingData.split(new Random(), new double[] {0.8d, 0.2d, 0.0d});
            this.trainingData = split[0];
            this.testData = split[1];
        } else {
            this.testData = getData(this.testFile);
        }
    }

    /** Logs the fraction of tokens whose predicted label equals the gold label. */
    private void testCrf(CRF model, InstanceList instanceList) {
        int correct = 0;
        int total = 0;
        for (Instance instance : instanceList) {
            Sequence predicted = model.transduce((FeatureVectorSequence) instance.getData());
            LabelSequence gold = (LabelSequence) instance.getTarget();
            total += predicted.size();
            for (int pos = 0; pos < predicted.size(); pos++) {
                if (gold.getLabelAtPosition(pos).getEntry().equals(predicted.get(pos))) {
                    correct++;
                }
            }
        }
        // Floating-point division: integer division always reported 0 (or 1) and threw on 0 tokens.
        log.info("accuracy (correct / total): " + correct + " / " + total + " = " + ((double) correct / total));
    }

    /** Applies the current model to every instance and flattens the predicted labels. */
    private List<String> testCrf(InstanceList instanceList) {
        List<String> labels = new ArrayList<String>();
        for (Instance instance : instanceList) {
            Sequence predicted = this.crf.transduce((FeatureVectorSequence) instance.getData());
            for (int pos = 0; pos < predicted.size(); pos++) {
                // Sequence.get returns Object; convert explicitly rather than polluting List<String>.
                labels.add(String.valueOf(predicted.get(pos)));
            }
        }
        return labels;
    }

    /** Default model path for a training file and window size: {@code <trainFile>-<windowSize>.crf}. */
    private static String defaultCrfFile(String trainFile, int windowSize) {
        return trainFile + "-" + Integer.toString(windowSize) + ".crf";
    }

    /**
     * Deserializes a {@link CRF} from {@code path}, always closing the file.
     * NOTE(review): this is Java native deserialization; only load model files from trusted
     * sources (consider an ObjectInputFilter if that cannot be guaranteed).
     */
    private static CRF readCrf(String path) throws IOException, ClassNotFoundException {
        FileInputStream fileIn = new FileInputStream(path);
        try {
            ObjectInputStream in = new ObjectInputStream(fileIn);
            return (CRF) in.readObject();
        } finally {
            fileIn.close();
        }
    }

    /** Serializes {@code model} to {@code path}, always closing the file. */
    private static void writeCrf(CRF model, String path) throws IOException {
        FileOutputStream fileOut = new FileOutputStream(path);
        try {
            ObjectOutputStream out = new ObjectOutputStream(fileOut);
            out.writeObject(model);
            // Closing the object stream flushes its buffer; the outer close below is then a no-op.
            out.close();
        } finally {
            fileOut.close();
        }
    }
}
