package de.dfki.km.perspecting.obie.crf.apps;

import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFCacheStaleIndicator;
import cc.mallet.fst.CRFOptimizableByBatchLabelLikelihood;
import cc.mallet.fst.CRFTrainerByValueGradients;
import cc.mallet.fst.MaxLatticeDefault;
import cc.mallet.fst.MultiSegmentationEvaluator;
import cc.mallet.fst.ThreadedOptimizable;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.optimize.Optimizable;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.iterator.LineGroupIterator;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.Sequence;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Arrays;
import java.util.Random;
import java.util.logging.Logger;
import java.util.regex.Pattern;

/* loaded from: input_file:de/dfki/km/perspecting/obie/crf/apps/SCOOBIE2CRF.class */
public class SCOOBIE2CRF {
    private final Logger log = Logger.getLogger(SCOOBIE2CRF.class.getName());
    private final String path;

    /* loaded from: input_file:de/dfki/km/perspecting/obie/crf/apps/SCOOBIE2CRF$SimpleTaggerSentence2FeatureVectorSequence.class */
    public static class SimpleTaggerSentence2FeatureVectorSequence extends Pipe {
        private static final long serialVersionUID = 1;

        public SimpleTaggerSentence2FeatureVectorSequence() {
            super(new Alphabet(), new LabelAlphabet());
        }

        /* JADX WARN: Type inference failed for: r0v4, types: [java.lang.String[], java.lang.String[][]] */
        private String[][] parseSentence(String str) {
            String[] split = str.split("\n");
            ?? r0 = new String[split.length];
            for (int i = 0; i < split.length; i++) {
                r0[i] = split[i].split(" ");
            }
            return r0;
        }

        public Instance pipe(Instance instance) {
            String[][] strArr;
            int length;
            Object data = instance.getData();
            Alphabet dataAlphabet = getDataAlphabet();
            if (data instanceof String) {
                strArr = parseSentence((String) data);
            } else {
                if (!(data instanceof String[][])) {
                    throw new IllegalArgumentException("Not a String or String[][]; got " + data);
                }
                strArr = (String[][]) data;
            }
            FeatureVector[] featureVectorArr = new FeatureVector[strArr.length];
            LabelSequence labelSequence = isTargetProcessing() ? new LabelSequence(getTargetAlphabet(), strArr.length) : null;
            for (int i = 0; i < strArr.length; i++) {
                if (!isTargetProcessing()) {
                    length = strArr[i].length;
                } else {
                    if (strArr[i].length < 1) {
                        throw new IllegalStateException("Missing label at line " + i + " instance " + instance.getName());
                    }
                    length = strArr[i].length - 1;
                    labelSequence.add(strArr[i][length]);
                }
                int[] iArr = new int[length];
                for (int i2 = 0; i2 < length; i2++) {
                    iArr[i2] = dataAlphabet.lookupIndex(strArr[i][i2]);
                }
                featureVectorArr[i] = new FeatureVector(dataAlphabet, iArr);
            }
            instance.setData(new FeatureVectorSequence(featureVectorArr));
            if (isTargetProcessing()) {
                instance.setTarget(labelSequence);
            } else {
                instance.setTarget(new LabelSequence(getTargetAlphabet()));
            }
            return instance;
        }
    }

    /**
     * Creates a trainer/tagger bound to one model file.
     *
     * @param str path of the serialized CRF model; written by training and
     *            read by testing
     */
    public SCOOBIE2CRF(String str) {
        path = str;
    }

    /**
     * Trains a fully connected CRF on {@code instanceList}, evaluates it on
     * both lists and serializes the trained model to {@code path}.
     *
     * @param instanceList  training instances; also supplies the pipe and the label alphabet
     * @param instanceList2 held-out instances, used for evaluation only
     * @throws Exception if training fails or the model file cannot be written
     */
    private void train(InstanceList instanceList, InstanceList instanceList2) throws Exception {
        CRF crf = new CRF(instanceList.getPipe(), (Pipe) null);
        crf.addFullyConnectedStatesForLabels();
        crf.setWeightsDimensionAsIn(instanceList, false);
        this.log.info("Training " + this.path + " on " + instanceList.size() + " instances");
        this.log.info("Testing " + this.path + " on " + instanceList2.size() + " instances");

        // Declared with the concrete type: shutdown() is not part of
        // Optimizable.ByGradientValue.
        ThreadedOptimizable threadedOptimizable = new ThreadedOptimizable(
                new CRFOptimizableByBatchLabelLikelihood(crf, instanceList, 32),
                instanceList,
                crf.getParameters().getNumFactors(),
                new CRFCacheStaleIndicator(crf));
        try {
            CRFTrainerByValueGradients trainer =
                    new CRFTrainerByValueGradients(crf, new Optimizable.ByGradientValue[]{threadedOptimizable});
            Object[] labels = instanceList.getTargetAlphabet().toArray();
            this.log.info("Labels to be predicted by " + this.path + ":  " + Arrays.toString(labels));

            // Evaluate on every fifth iteration and once more when training has finished.
            MultiSegmentationEvaluator evaluator = new MultiSegmentationEvaluator(
                    new InstanceList[]{instanceList, instanceList2},
                    new String[]{"train", "test"},
                    labels,
                    labels) {
                public boolean precondition(TransducerTrainer transducerTrainer) {
                    return transducerTrainer.getIteration() % 5 == 0 || transducerTrainer.isFinishedTraining();
                }
            };
            trainer.addEvaluator(evaluator);
            trainer.setMaxResets(0);
            trainer.train(instanceList, Integer.MAX_VALUE);
            evaluator.evaluate(trainer);
        } finally {
            // The worker threads must stay alive for the whole training run,
            // so they are released only once training is over (or has failed).
            threadedOptimizable.shutdown();
        }

        // Closing the streams flushes the buffered object data; without it the
        // model file may be left incomplete.
        try (FileOutputStream fileOut = new FileOutputStream(this.path);
                ObjectOutputStream objectOut = new ObjectOutputStream(fileOut)) {
            objectOut.writeObject(crf);
        }
    }

    /**
     * Loads the serialized CRF from {@code path}, tags every instance read
     * from {@code file} and prints one line per token to standard output:
     * the token's feature vector followed by the predicted label.
     *
     * <p>Instances in {@code file} are separated by blank lines. If no model
     * file exists at {@code path}, a warning is logged and nothing is tagged.
     *
     * @param file input file with the instances to tag
     * @throws Exception if the model or the input file cannot be read
     */
    public void test(File file) throws Exception {
        File modelFile = new File(this.path);
        if (!modelFile.exists()) {
            this.log.warning("No model file found at " + this.path + "; nothing to test.");
            return;
        }

        CRF crf;
        // NOTE(review): this is native Java deserialization; only load model
        // files that come from a trusted source.
        try (ObjectInputStream modelIn = new ObjectInputStream(new FileInputStream(modelFile))) {
            crf = (CRF) modelIn.readObject();
        }

        Pipe inputPipe = crf.getInputPipe();
        // The input carries no labels, so every field of a line is a feature.
        inputPipe.setTargetProcessing(false);
        InstanceList instanceList = new InstanceList(inputPipe);
        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            instanceList.addThruPipe(new LineGroupIterator(reader, Pattern.compile("^\\s*$"), true));
        }

        this.log.info("Testing with" + Arrays.toString(instanceList.getDataAlphabet().toArray()) + " tokens.");
        this.log.info("Testing " + instanceList.size() + " instances.");
        for (int i = 0; i < instanceList.size(); i++) {
            Sequence input = (Sequence) ((Instance) instanceList.get(i)).getData();
            Sequence output = apply(crf, input, 1)[0];
            for (int j = 0; j < output.size(); j++) {
                System.out.println(((FeatureVector) input.get(j)).toString(true) + " " + output.get(j));
            }
        }
    }

    /**
     * Runs {@code transducer} over {@code sequence}.
     *
     * @param transducer the model to apply
     * @param sequence   the input sequence
     * @param i          number of output sequences requested
     * @return for {@code i == 1} a one-element array holding the result of
     *         {@code transducer.transduce(sequence)}; otherwise the result of
     *         {@code MaxLatticeDefault.bestOutputSequences(i)} as an array
     */
    public Sequence[] apply(Transducer transducer, Sequence sequence, int i) {
        if (i != 1) {
            MaxLatticeDefault lattice = new MaxLatticeDefault(transducer, sequence);
            return (Sequence[]) lattice.bestOutputSequences(i).toArray(new Sequence[0]);
        }
        Sequence single = transducer.transduce(sequence);
        return new Sequence[]{single};
    }

    /**
     * Reads labelled instances from {@code file}, splits them randomly into
     * 90% training and 10% held-out data and trains a CRF on the split.
     *
     * <p>Instances in {@code file} are separated by blank lines. Failures
     * during training are propagated to the caller rather than being printed
     * and discarded, so a failed run is not mistaken for a successful one.
     *
     * @param file file with the labelled training instances
     * @throws Exception if the file cannot be read or training fails
     */
    public void train(File file) throws Exception {
        SimpleTaggerSentence2FeatureVectorSequence pipe = new SimpleTaggerSentence2FeatureVectorSequence();
        pipe.setTargetProcessing(true);
        InstanceList allInstances = new InstanceList(pipe);
        this.log.info("Training with : " + file.getName());
        // Close the reader even when an instance fails to go through the pipe.
        try (FileReader reader = new FileReader(file)) {
            allInstances.addThruPipe(new LineGroupIterator(reader, Pattern.compile("^\\s*$"), true));
        }
        InstanceList[] split = allInstances.split(new Random(), new double[]{0.9d, 0.1d});
        InstanceList trainingInstances = split[0];
        InstanceList heldOutInstances = split[1];
        this.log.info("Number of features in training data: " + pipe.getDataAlphabet().size());
        train(trainingInstances, heldOutInstances);
    }

    /**
     * Command-line entry point: {@code <modelPath> <train|test> <filepath>}.
     *
     * <p>Prints the usage line and returns when the argument count is wrong
     * or the mode is neither {@code train} nor {@code test}.
     *
     * @param strArr command-line arguments
     * @throws Exception if training or testing fails
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length != 3) {
            System.out.println("usage <modelPath> <train|test> <filepath>");
            return;
        }
        String modelPath = strArr[0];
        String mode = strArr[1];
        File input = new File(strArr[2]);
        SCOOBIE2CRF scoobie2crf = new SCOOBIE2CRF(modelPath);
        if ("train".equals(mode)) {
            scoobie2crf.train(input);
        } else if ("test".equals(mode)) {
            scoobie2crf.test(input);
        } else {
            // A mistyped mode must not silently fall through to test mode.
            System.out.println("usage <modelPath> <train|test> <filepath>");
        }
    }
}
