/*
 * Decompiled with CFR 0.152.
 */
package de.dfki.sds.kecs.ml;

import de.dfki.sds.hephaistos.storage.assertion.Assertion;
import de.dfki.sds.hephaistos.storage.assertion.AssertionPool;
import de.dfki.sds.hephaistos.storage.file.FileInfoStorage;
import de.dfki.sds.kecs.util.ExceptionUtility;
import de.dfki.sds.kecs.util.Prediction;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.functions.Logistic;
import weka.classifiers.functions.SMO;
import weka.classifiers.lazy.IBk;
import weka.classifiers.trees.J48;
import weka.classifiers.trees.RandomForest;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.UnsupportedAttributeTypeException;

/**
 * Trains a Weka classifier on a set of assertions and converts its class distributions into
 * {@link Prediction}s.
 *
 * <p>Typical use: select a classifier ({@link #svm()}, {@link #knn(int)}, {@link #j48()},
 * {@link #randomForest()}, {@link #naiveBayes()}, {@link #logistic()}), register the callbacks that
 * define the attribute schema, the per-assertion features, the class label of a training assertion,
 * the train/test sets and the predictions consumer, then call {@link #train}, {@link #predict} or
 * {@link #trainAndPredict}.
 *
 * <p>The manager holds mutable model state and performs no synchronization.
 */
public class PredictionManager {
    /** Selected Weka classifier; {@code null} until one of the selector methods has been called. */
    private AbstractClassifier classifier;
    /** Sorted class labels of the last training run; label {@code i} is Weka class value {@code i}. */
    private List<String> classLabels;
    /** Data the current model was built on; {@code null} means "no usable model". */
    private Instances dataset;
    private double confidenceThreshold = 0.1;
    private Consumer<List<Attribute>> schemaDefinition;
    private Consumer<Context> featureDefinition;
    private Function<Assertion, String> classProvider;
    /** @deprecated never invoked by this class; register a {@link #setPredictionsConsumer} instead. */
    @Deprecated
    private Consumer<Context> predictionConsumer;
    private Consumer<Context> predictionsConsumer;
    private Consumer<Context> trainSetProvider;
    private Consumer<Context> testSetProvider;
    /** Synthetic class label used to pad a training set that contains only one real class. */
    public static final String AVOID_UNARY_CLASS = "urn:ml:avoidUnaryClass";
    private boolean printDataset;
    private boolean printEvaluation;
    private boolean printClassifier;

    /** Selects a support vector machine ({@link SMO}) and discards any trained model. */
    public void svm() {
        this.useClassifier(new SMO());
    }

    /**
     * Selects a k-nearest-neighbour classifier ({@link IBk}) and discards any trained model.
     *
     * @param k number of neighbours to consult
     */
    public void knn(int k) {
        IBk ibk = new IBk();
        ibk.setKNN(k);
        this.useClassifier(ibk);
    }

    /** Selects a C4.5 decision tree ({@link J48}) and discards any trained model. */
    public void j48() {
        this.useClassifier(new J48());
    }

    /** Selects a {@link RandomForest} and discards any trained model. */
    public void randomForest() {
        this.useClassifier(new RandomForest());
    }

    /** Selects a {@link NaiveBayes} classifier and discards any trained model. */
    public void naiveBayes() {
        this.useClassifier(new NaiveBayes());
    }

    /** Selects a {@link Logistic} regression classifier and discards any trained model. */
    public void logistic() {
        this.useClassifier(new Logistic());
    }

    /** Installs {@code newClassifier} and invalidates the current model so it must be retrained. */
    private void useClassifier(AbstractClassifier newClassifier) {
        this.classifier = newClassifier;
        this.dataset = null;
    }

    /**
     * Trains the selected classifier on the assertions supplied by the train-set provider.
     *
     * <p>Any previously trained model is discarded first. Afterwards the manager counts as trained
     * only if the classifier was built successfully. It stays untrained if the training set is
     * empty, if the classifier rejects an attribute type (a warning and the dataset are printed),
     * or if building fails.
     *
     * @param fileInfoStorage passed through to the callbacks
     * @param pool passed through to the callbacks
     * @throws RuntimeException wrapping the cause if building or (when enabled) evaluating the
     *     classifier fails
     */
    public void train(FileInfoStorage fileInfoStorage, AssertionPool pool) {
        // Invalidate the old model up front; this.dataset is assigned again only after a successful build.
        this.dataset = null;
        this.classLabels = null;

        List<Assertion> assertions =
                new ArrayList<>(this.collectAssertions(this.trainSetProvider, fileInfoStorage, pool));
        // Ask the class provider once per assertion and reuse the answer when building instances.
        List<String> labels = new ArrayList<>(assertions.size());
        for (Assertion assertion : assertions) {
            labels.add(this.classProvider.apply(assertion));
        }

        Set<String> classes = new HashSet<>(labels);
        // Remember this BEFORE padding: after the synthetic label is added the set has two entries.
        // Padding presumably keeps classifiers from seeing a single-valued class attribute.
        boolean unary = classes.size() == 1;
        if (unary) {
            classes.add(AVOID_UNARY_CLASS);
        }
        List<String> sortedLabels = new ArrayList<>(classes);
        sortedLabels.sort(String::compareTo);
        this.classLabels = sortedLabels;

        ArrayList<Attribute> attributes = new ArrayList<>();
        this.schemaDefinition.accept(attributes);
        // The class attribute is always the last one; its value order equals sortedLabels.
        attributes.add(new Attribute("class", sortedLabels));
        Instances data = new Instances("tmp", attributes, 0);
        data.setClassIndex(data.numAttributes() - 1);
        for (int i = 0; i < assertions.size(); i++) {
            data.add(this.toInstance(data, assertions.get(i), labels.get(i), fileInfoStorage, pool));
        }
        if (data.isEmpty()) {
            return;
        }
        if (unary) {
            // One instance with all features missing, so the synthetic label actually occurs in the data.
            DenseInstance padding = new DenseInstance(data.numAttributes());
            padding.setDataset(data);
            padding.setValue(data.classIndex(), AVOID_UNARY_CLASS);
            data.add(padding);
        }

        try {
            this.classifier.buildClassifier(data);
        } catch (UnsupportedAttributeTypeException ex) {
            ExceptionUtility.save(ex);
            System.out.println("[Prediction Manager Warning] " + ex.getMessage());
            System.out.println(data);
            return;
        } catch (Exception ex) {
            ExceptionUtility.save(ex);
            throw new RuntimeException(ex);
        }
        // Only now is the model usable.
        this.dataset = data;

        if (this.printDataset) {
            System.out.println(data);
        }
        if (this.printClassifier) {
            System.out.println(this.classifier);
        }
        if (this.printEvaluation) {
            try {
                // Evaluates on the training data itself.
                Evaluation evaluation = new Evaluation(data);
                evaluation.evaluateModel(this.classifier, data);
                System.out.println(evaluation.toSummaryString());
            } catch (Exception ex) {
                ExceptionUtility.save(ex);
                throw new RuntimeException(ex);
            }
        }
    }

    /**
     * Classifies every assertion supplied by the test-set provider and hands the resulting
     * predictions to the predictions consumer, then commits {@code pool}.
     *
     * <p>One prediction is created per assertion and real class label. The synthetic
     * {@link #AVOID_UNARY_CLASS} label is never reported. Predictions below
     * {@link #getConfidenceThreshold()} are dropped, and the rest are sorted by descending
     * confidence. Does nothing if no model has been trained.
     *
     * @param fileInfoStorage passed through to the callbacks
     * @param pool passed through to the callbacks and committed at the end
     * @throws RuntimeException wrapping the cause if classification fails
     */
    public void predict(FileInfoStorage fileInfoStorage, AssertionPool pool) {
        if (!this.classifierIsTrained()) {
            return;
        }
        Set<Assertion> testSet = this.collectAssertions(this.testSetProvider, fileInfoStorage, pool);
        List<Prediction> predictions = new ArrayList<>();
        for (Assertion assertion : testSet) {
            Instance inst = this.toInstance(this.dataset, assertion, null, fileInfoStorage, pool);
            try {
                double[] distribution = this.classifier.distributionForInstance(inst);
                for (int i = 0; i < distribution.length; ++i) {
                    String classLbl = this.classLabels.get(i);
                    if (AVOID_UNARY_CLASS.equals(classLbl)) {
                        // Synthetic padding label: never a meaningful prediction.
                        continue;
                    }
                    Prediction prediction = new Prediction(classLbl, distribution[i]);
                    prediction.setAssertion(assertion);
                    prediction.setInstance(inst);
                    predictions.add(prediction);
                }
            } catch (Exception ex) {
                ExceptionUtility.save(ex);
                throw new RuntimeException(ex);
            }
        }
        predictions.removeIf(p -> p.getConfidence() < this.confidenceThreshold);
        predictions.sort((a, b) -> Double.compare(b.getConfidence(), a.getConfidence()));
        Context predCtx = new Context(fileInfoStorage, pool);
        predCtx.predictions = predictions;
        this.predictionsConsumer.accept(predCtx);
        pool.commit();
    }

    /** Runs {@link #train} followed by {@link #predict} with the same arguments. */
    public void trainAndPredict(FileInfoStorage fileInfoStorage, AssertionPool pool) {
        this.train(fileInfoStorage, pool);
        this.predict(fileInfoStorage, pool);
    }

    /** Lets {@code provider} fill a fresh, empty assertion set and returns that set. */
    private Set<Assertion> collectAssertions(Consumer<Context> provider, FileInfoStorage fileInfoStorage,
            AssertionPool pool) {
        Context ctx = new Context(fileInfoStorage, pool);
        ctx.assertionSet = new HashSet<>();
        provider.accept(ctx);
        return ctx.assertionSet;
    }

    /**
     * Builds a Weka instance for {@code assertion} against {@code header}. The feature definition
     * fills the feature values.
     *
     * @param header dataset that defines the attributes; its class index must be set
     * @param classLabel class value to set, or {@code null} to leave the class missing (prediction)
     */
    private Instance toInstance(Instances header, Assertion assertion, String classLabel,
            FileInfoStorage fileInfoStorage, AssertionPool pool) {
        DenseInstance inst = new DenseInstance(header.numAttributes());
        inst.setDataset(header);
        this.featureDefinition.accept(new Context(fileInfoStorage, pool, assertion, inst));
        if (classLabel != null) {
            inst.setValue(header.classIndex(), classLabel);
        }
        return inst;
    }

    /** Returns whether the last training run produced a usable model. */
    private boolean classifierIsTrained() {
        return this.dataset != null;
    }

    /** Returns the minimum confidence a prediction needs to be kept (default {@code 0.1}). */
    public double getConfidenceThreshold() {
        return this.confidenceThreshold;
    }

    /** Sets the minimum confidence a prediction needs to be kept. */
    public void setConfidenceThreshold(double confidenceThreshold) {
        this.confidenceThreshold = confidenceThreshold;
    }

    /** Returns the callback that appends the feature attributes (without the class attribute). */
    public Consumer<List<Attribute>> getSchemaDefinition() {
        return this.schemaDefinition;
    }

    /** Sets the callback that appends the feature attributes; the class attribute is added afterwards. */
    public void setSchemaDefinition(Consumer<List<Attribute>> schemaDefinition) {
        this.schemaDefinition = schemaDefinition;
    }

    /** Returns the callback that fills the feature values of {@link Context#getInstance()}. */
    public Consumer<Context> getFeatureDefinition() {
        return this.featureDefinition;
    }

    /** Sets the callback that fills the feature values of {@link Context#getInstance()}. */
    public void setFeatureDefinition(Consumer<Context> featureDefinition) {
        this.featureDefinition = featureDefinition;
    }

    /** Returns the callback that fills {@link Context#getAssertionSet()} with training assertions. */
    public Consumer<Context> getTrainSetProvider() {
        return this.trainSetProvider;
    }

    /** Sets the callback that fills {@link Context#getAssertionSet()} with training assertions. */
    public void setTrainSetProvider(Consumer<Context> trainSetProvider) {
        this.trainSetProvider = trainSetProvider;
    }

    /** Returns the callback that fills {@link Context#getAssertionSet()} with assertions to classify. */
    public Consumer<Context> getTestSetProvider() {
        return this.testSetProvider;
    }

    /** Sets the callback that fills {@link Context#getAssertionSet()} with assertions to classify. */
    public void setTestSetProvider(Consumer<Context> testSetProvider) {
        this.testSetProvider = testSetProvider;
    }

    /** Returns the function that yields the class label of a training assertion. */
    public Function<Assertion, String> getClassProvider() {
        return this.classProvider;
    }

    /** Sets the function that yields the class label of a training assertion. */
    public void setClassProvider(Function<Assertion, String> classProvider) {
        this.classProvider = classProvider;
    }

    /** @deprecated this consumer is never invoked; use {@link #getPredictionsConsumer()}. */
    @Deprecated
    public Consumer<Context> getPredictionConsumer() {
        return this.predictionConsumer;
    }

    /** @deprecated this consumer is never invoked; use {@link #setPredictionsConsumer(Consumer)}. */
    @Deprecated
    public void setPredictionConsumer(Consumer<Context> predictionConsumer) {
        this.predictionConsumer = predictionConsumer;
    }

    /** Returns the callback that receives all kept predictions via {@link Context#getPredictions()}. */
    public Consumer<Context> getPredictionsConsumer() {
        return this.predictionsConsumer;
    }

    /** Sets the callback that receives all kept predictions via {@link Context#getPredictions()}. */
    public void setPredictionsConsumer(Consumer<Context> predictionsConsumer) {
        this.predictionsConsumer = predictionsConsumer;
    }

    /** Returns whether the training dataset is printed after a successful build. */
    public boolean isPrintDataset() {
        return this.printDataset;
    }

    /** Sets whether the training dataset is printed after a successful build. */
    public void setPrintDataset(boolean printDataset) {
        this.printDataset = printDataset;
    }

    /** Returns whether an evaluation on the training data is printed after a successful build. */
    public boolean isPrintEvaluation() {
        return this.printEvaluation;
    }

    /** Sets whether an evaluation on the training data is printed after a successful build. */
    public void setPrintEvaluation(boolean printEvaluation) {
        this.printEvaluation = printEvaluation;
    }

    /** Returns whether the built classifier is printed after a successful build. */
    public boolean isPrintClassifier() {
        return this.printClassifier;
    }

    /** Sets whether the built classifier is printed after a successful build. */
    public void setPrintClassifier(boolean printClassifier) {
        this.printClassifier = printClassifier;
    }

    /**
     * Argument object handed to the callbacks. Which fields are populated depends on the callback:
     * <ul>
     *   <li>train/test-set providers: storage, pool and an empty assertion set to fill;</li>
     *   <li>feature definition: storage, pool, the assertion and the instance to fill;</li>
     *   <li>predictions consumer: storage, pool and the kept predictions.</li>
     * </ul>
     * Kept as an inner (non-static) class so existing {@code manager.new Context(...)} expressions
     * keep compiling.
     */
    public class Context {
        private FileInfoStorage fileInfoStorage;
        private AssertionPool assertionPool;
        private Assertion assertion;
        private Instance instance;
        private Set<Assertion> assertionSet;
        /** @deprecated never set by {@link PredictionManager}; use {@link #getPredictions()}. */
        @Deprecated
        private Prediction prediction;
        private List<Prediction> predictions;

        public Context() {
        }

        public Context(FileInfoStorage fileInfoStorage, AssertionPool assertionPool) {
            this.fileInfoStorage = fileInfoStorage;
            this.assertionPool = assertionPool;
        }

        public Context(FileInfoStorage fileInfoStorage, AssertionPool assertionPool, Assertion assertion, Instance instance) {
            this.fileInfoStorage = fileInfoStorage;
            this.assertionPool = assertionPool;
            this.assertion = assertion;
            this.instance = instance;
        }

        public FileInfoStorage getFileInfoStorage() {
            return this.fileInfoStorage;
        }

        public AssertionPool getAssertionPool() {
            return this.assertionPool;
        }

        /** Returns the assertion whose features are being defined (feature definition only). */
        public Assertion getAssertion() {
            return this.assertion;
        }

        /** Returns the instance whose feature values should be set (feature definition only). */
        public Instance getInstance() {
            return this.instance;
        }

        /** Returns the mutable set a train/test-set provider should add its assertions to. */
        public Set<Assertion> getAssertionSet() {
            return this.assertionSet;
        }

        /** @deprecated never set by {@link PredictionManager}; use {@link #getPredictions()}. */
        @Deprecated
        public Prediction getPrediction() {
            return this.prediction;
        }

        /** Returns the kept predictions, sorted by descending confidence (predictions consumer only). */
        public List<Prediction> getPredictions() {
            return this.predictions;
        }
    }
}

