package dm.algorithms.yale;

import dm.data.DataObject;
import dm.data.DistanceMeasure;
import dm.data.RegressionDataGenerator;
import dm.data.database.Database;
import dm.data.database.SequDB;
import dm.data.featureVector.FeatureVector;
import dm.util.MathUtil;
import edu.udo.cs.yale.Yale;
import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.AttributeFactory;
import edu.udo.cs.yale.example.AttributeWeights;
import edu.udo.cs.yale.example.DoubleArrayDataRow;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.example.MemoryExampleTable;
import edu.udo.cs.yale.operator.OperatorCreationException;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.learner.kernel.AbstractMySVMModel;
import edu.udo.cs.yale.operator.learner.kernel.JMySVMLearner;
import edu.udo.cs.yale.tools.OperatorService;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:dm/algorithms/yale/MySVMRegression.class */
/**
 * Wrapper around YALE's {@code JMySVMLearner} that trains an SVM regression model on a
 * {@link Database} of {@link FeatureVector}s and predicts numeric values for new objects.
 *
 * <p>Typical use: construct with training data, optionally call {@link #setParameter}, call
 * {@link #trainModel()}, then query {@link #getPrediction(Database)} or {@link #getWeights()}.
 *
 * <p>No synchronisation is performed: the one-time YALE initialisation is guarded only by the
 * plain static {@link #init} flag, and instances hold mutable state.
 */
public class MySVMRegression {
    /**
     * YALE value-type code used for every attribute created here. Presumably
     * {@code Ontology.REAL} in the YALE version in use — TODO confirm.
     */
    private static final int ATTRIBUTE_VALUE_TYPE = 4;

    /** Name prefix of regular attributes; {@link #getWeights()} looks weights up by it. */
    private static final String ATTRIBUTE_PREFIX = "att";

    /** Name of the label attribute created by the labelled {@code convertDatabase} overloads. */
    private static final String LABEL_NAME = "label";

    /** Name of the attribute read back after applying the model. */
    private static final String PREDICTION_NAME = "prediction";

    /** The YALE learner; {@code null} if operator creation failed in the constructor. */
    JMySVMLearner regressor;
    /** Training examples built by the constructor. */
    ExampleSet trainSet;
    /** The trained model; {@code null} until {@link #trainModel()} succeeds. */
    AbstractMySVMModel model;

    /** Whether {@code Yale.init()} has already completed successfully in this JVM. */
    public static boolean init = false;

    /**
     * Converts {@code database} into a labelled YALE example set. Every vector component becomes
     * a regular attribute {@code att0..attN-1}, and the label is looked up in {@code map} by the
     * object's primary key.
     *
     * @param database feature vectors to convert; the attribute count is taken from the first one
     * @param map label value per primary key
     * @return example set whose label attribute is {@code "label"}
     * @throws IllegalArgumentException if {@code map} has no label for some object
     */
    public static ExampleSet convertDatabase(Database database, Map<String, Double> map) {
        List<Attribute> attributes = createRegularAttributes(firstVectorLength(database));
        Attribute label = AttributeFactory.createAttribute(LABEL_NAME, ATTRIBUTE_VALUE_TYPE);
        attributes.add(label);
        MemoryExampleTable table = new MemoryExampleTable(attributes);
        Iterator<?> objects = database.objectIterator();
        while (objects.hasNext()) {
            FeatureVector vector = (FeatureVector) objects.next();
            Double labelValue = map.get(vector.getPrimaryKey());
            if (labelValue == null) {
                throw new IllegalArgumentException(
                        "No label value for object " + vector.getPrimaryKey());
            }
            int featureCount = vector.values.length;
            double[] row = new double[featureCount + 1];
            System.arraycopy(vector.values, 0, row, 0, featureCount);
            row[featureCount] = labelValue.doubleValue();
            table.addDataRow(new DoubleArrayDataRow(row));
        }
        return table.createCompleteExampleSet(label, (Attribute) null, (Attribute) null, (Attribute) null);
    }

    /**
     * Converts {@code database} into an unlabelled YALE example set, e.g. for prediction. Every
     * vector component becomes a regular attribute {@code att0..attN-1}.
     *
     * @param database feature vectors to convert; the attribute count is taken from the first one
     * @return example set without special attributes
     */
    public static ExampleSet convertDatabase(Database database) {
        MemoryExampleTable table =
                new MemoryExampleTable(createRegularAttributes(firstVectorLength(database)));
        addRowsVerbatim(table, database);
        return table.createCompleteExampleSet((Attribute) null, (Attribute) null, (Attribute) null, (Attribute) null);
    }

    /**
     * Converts {@code database} into a labelled YALE example set. The LAST component of each
     * vector is the label and the remaining components become {@code att0..attN-2}.
     *
     * @param database feature vectors to convert; the attribute count is taken from the first one
     * @param i currently ignored — the label is always the last component (NOTE(review): was this
     *     meant to be the label index? confirm with callers)
     * @return example set whose label attribute is {@code "label"}
     */
    public static ExampleSet convertDatabase(Database database, int i) {
        List<Attribute> attributes = createRegularAttributes(firstVectorLength(database) - 1);
        Attribute label = AttributeFactory.createAttribute(LABEL_NAME, ATTRIBUTE_VALUE_TYPE);
        attributes.add(label);
        MemoryExampleTable table = new MemoryExampleTable(attributes);
        addRowsVerbatim(table, database);
        return table.createCompleteExampleSet(label, (Attribute) null, (Attribute) null, (Attribute) null);
    }

    /**
     * Creates a regressor whose training labels come from {@code map}.
     *
     * <p>Initialisation and operator-creation failures are printed to stderr rather than thrown.
     * In that case {@link #setParameter} later throws {@link IllegalStateException}.
     *
     * @param database training feature vectors
     * @param map label value per primary key
     */
    public MySVMRegression(Database database, Map<String, Double> map) {
        ensureYaleInitialized();
        this.trainSet = convertDatabase(database, map);
        this.regressor = createRegressor();
    }

    /**
     * Creates a regressor whose training label is the last component of each feature vector.
     *
     * @param database training feature vectors (label in the last component)
     * @param i currently ignored; see {@link #convertDatabase(Database, int)}
     */
    public MySVMRegression(Database database, int i) {
        ensureYaleInitialized();
        this.trainSet = convertDatabase(database, i);
        this.regressor = createRegressor();
    }

    /**
     * Sets a parameter of the underlying {@code JMySVMLearner}.
     *
     * @throws IllegalStateException if the learner could not be created
     */
    public void setParameter(String str, String str2) {
        requireRegressor().setParameter(str, str2);
    }

    /**
     * Trains the SVM on the training set. This is best-effort: any failure is printed to stderr
     * and leaves {@link #model} unchanged.
     */
    public void trainModel() {
        try {
            this.model = (AbstractMySVMModel) this.regressor.learn(this.trainSet);
        } catch (Exception e) {
            // Message (German): "SVM could not learn anything!"
            System.err.println(" SVM konnte nichts lernen !");
            e.printStackTrace();
        }
    }

    /**
     * Predicts the value for a single object by wrapping it in a one-element database.
     *
     * @param dataObject object to predict; presumably must be a {@link FeatureVector}, since
     *     {@link #convertDatabase(Database)} casts to it
     * @return predicted value
     */
    public double getPrediction(DataObject dataObject) {
        SequDB sequDB = new SequDB((DistanceMeasure) null);
        sequDB.insert(dataObject);
        return getPrediction(sequDB).get(dataObject.getPrimaryKey()).doubleValue();
    }

    /**
     * Returns the learner's attribute weights in attribute order {@code att0, att1, ...}.
     *
     * @return weights, or {@code null} if they could not be computed (the error is printed)
     */
    public double[] getWeights() {
        try {
            AttributeWeights weights = this.regressor.getWeights(this.trainSet);
            int size = weights.getSize();
            double[] result = new double[size];
            for (int index = 0; index < size; index++) {
                result[index] = weights.getWeight(ATTRIBUTE_PREFIX + index);
            }
            return result;
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Applies the trained model to every object in {@code database}.
     *
     * @param database objects to predict; their full vectors are used as features
     * @return predicted value per primary key
     * @throws IllegalStateException if no model is trained, or applying it fails
     */
    public Map<String, Double> getPrediction(Database database) {
        if (this.model == null) {
            throw new IllegalStateException(
                    "No trained model available; call trainModel() first and check its error output.");
        }
        ExampleSet examples = convertDatabase(database);
        try {
            this.model.apply(examples);
        } catch (OperatorException e) {
            throw new IllegalStateException("Applying the SVM model failed", e);
        }
        Map<String, Double> predictions = new HashMap<String, Double>();
        Iterator<?> exampleIt = examples.iterator();
        Iterator<?> objects = database.objectIterator();
        Attribute prediction = examples.getAttribute(PREDICTION_NAME);
        if (prediction == null && exampleIt.hasNext()) {
            throw new IllegalStateException(
                    "Model did not produce a '" + PREDICTION_NAME + "' attribute");
        }
        // Examples were created in database iteration order, so both iterators advance in lockstep.
        while (exampleIt.hasNext() && objects.hasNext()) {
            DataObject object = (DataObject) objects.next();
            double value = ((Example) exampleIt.next()).getDataRow().get(prediction);
            predictions.put(object.getPrimaryKey(), Double.valueOf(value));
        }
        return predictions;
    }

    /**
     * Demo: fits a regression on generated data and prints the normalised learned weights next
     * to the generating coefficients. Honours {@code -Dyale.home=...}; otherwise it falls back to
     * a local Windows install path.
     */
    public static void main(String[] strArr) {
        double[] coefficients = {-1.0d, 3.0d, 5.0d};
        Database database = new RegressionDataGenerator(coefficients, 0.5d, 100).db;
        if (System.getProperty("yale.home") == null) {
            System.setProperty("yale.home", "C:\\Programme\\yale-3.4");
        }
        MySVMRegression mySVMRegression = new MySVMRegression(database, firstVectorLength(database) - 1);
        System.out.println("Start Training !!!!!!");
        mySVMRegression.setParameter("kernel_type", "squaredDot");
        mySVMRegression.setParameter("scale", "false");
        mySVMRegression.setParameter("epsilon", "0.1");
        mySVMRegression.trainModel();
        double[] weights = mySVMRegression.getWeights();
        if (weights == null) {
            System.err.println("Could not retrieve attribute weights.");
            return;
        }
        MathUtil.quadnormalize(weights);
        for (int i = 0; i < weights.length; i++) {
            String expected = i < coefficients.length ? String.valueOf(coefficients[i]) : "?";
            System.out.println("Weight " + i + " = " + weights[i] + " ( " + expected + " )");
        }
    }

    /** Runs {@code Yale.init()} once per JVM; failures are printed and retried on the next call. */
    private static void ensureYaleInitialized() {
        if (!init) {
            try {
                Yale.init();
                init = true;
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /** Creates the YALE SVM operator, or returns {@code null} (after printing) if that fails. */
    private static JMySVMLearner createRegressor() {
        try {
            // Explicit downcast: the operator factory is not guaranteed to return the subtype.
            return (JMySVMLearner) OperatorService.createOperator("JMySVMLearner");
        } catch (OperatorCreationException e) {
            System.err.println("JMySVMLearner was not accepted as Operator.");
            e.printStackTrace();
            return null;
        }
    }

    /** Returns the learner or fails with a clear message if its creation failed. */
    private JMySVMLearner requireRegressor() {
        if (this.regressor == null) {
            throw new IllegalStateException(
                    "JMySVMLearner could not be created; see earlier error output.");
        }
        return this.regressor;
    }

    /** Returns the component count of the first vector in {@code database}, or 0 if it is empty. */
    private static int firstVectorLength(Database database) {
        Iterator<?> objects = database.objectIterator();
        return objects.hasNext() ? ((FeatureVector) objects.next()).values.length : 0;
    }

    /** Creates {@code count} regular attributes {@code att0..att(count-1)}; none if count <= 0. */
    private static List<Attribute> createRegularAttributes(int count) {
        List<Attribute> attributes = new LinkedList<Attribute>();
        for (int index = 0; index < count; index++) {
            attributes.add(AttributeFactory.createAttribute(ATTRIBUTE_PREFIX + index, ATTRIBUTE_VALUE_TYPE));
        }
        return attributes;
    }

    /**
     * Adds one data row per vector, containing a copy of its components. The copy keeps YALE
     * from aliasing, and possibly mutating, the database's own arrays.
     */
    private static void addRowsVerbatim(MemoryExampleTable table, Database database) {
        Iterator<?> objects = database.objectIterator();
        while (objects.hasNext()) {
            table.addDataRow(new DoubleArrayDataRow(((FeatureVector) objects.next()).values.clone()));
        }
    }
}
