package de.lmu.ifi.dbs.dm.algorithms.libsvm.tools;

import de.lmu.ifi.dbs.dm.Kernel;
import de.lmu.ifi.dbs.dm.algorithms.libsvm.lib.ModelStorage;
import de.lmu.ifi.dbs.dm.algorithms.libsvm.lib.svm;
import de.lmu.ifi.dbs.dm.algorithms.libsvm.lib.svm_model;
import de.lmu.ifi.dbs.dm.algorithms.libsvm.lib.svm_parameter;
import de.lmu.ifi.dbs.dm.algorithms.libsvm.lib.svm_problem;
import de.lmu.ifi.dbs.dm.data.featureVector.FeatureVector;
import de.lmu.ifi.dbs.dm.database.Database;
import de.lmu.ifi.dbs.dm.database.SequDB;
import de.lmu.ifi.dbs.dm.kernels.LinearKernel;
import de.lmu.ifi.dbs.utilities.statistics.SummaryItem;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Support vector regression (SVR) on top of the project's libsvm port.
 *
 * <p>Each instance holds exactly one libsvm problem. It is either a kernel matrix, built by the
 * three {@code Kernel}-based constructors (see {@link #getKm()}), or a feature-vector problem,
 * built by the kernel-type constructor (see {@link #getProblem()}). Label assignment, training,
 * prediction and cross-validation operate on whichever of the two is present, and instances are
 * addressed by their database primary key.
 *
 * <p>Instances hold mutable state without any synchronization and are therefore not
 * thread-safe.
 *
 * @param <T> feature-vector type stored in the source database
 */
public class SVR<T extends FeatureVector> {
    /** Database the problem was built from; used by the class-wise helpers. */
    private Database<T> db;
    /** Model produced by the last {@link #train(List)} call, or {@code null} before training. */
    private svm_model model;
    /** Feature-vector problem; {@code null} when this SVR is kernel-matrix based. */
    private svm_problem problem;
    /** Kernel-matrix problem; {@code null} when this SVR is feature-vector based. */
    private svm_problem km;
    /** Number of rows in the active problem. */
    private int problemSize;
    /** libsvm parameters; set by the kernel-type constructor or by {@link #train(List)}. */
    private svm_parameter param;
    /** Maps each primary key to its row index in the active problem. */
    private Map<String, Integer> key2index;
    /** Primary keys in row order, as reported by {@code SVM_Manip#getPrimaryKeys()}. */
    private String[] primaryKeys;
    /** Kernel used to build the kernel matrix; a {@link LinearKernel} unless one is supplied. */
    private Kernel<T> kernel;

    /**
     * Builds a kernel-matrix SVR with no instance restriction.
     *
     * <p>This is equivalent to {@code SVR(database, kernel, null, null)}.
     *
     * @param database source of the instances
     * @param kernel   kernel used to build the matrix; {@code null} selects a {@link LinearKernel}
     */
    public SVR(Database<T> database, Kernel<T> kernel) {
        this(database, kernel, null, null);
    }

    /**
     * Builds a kernel-matrix SVR from {@code database} using {@code kernel}.
     *
     * <p>NOTE(review): {@code list} and {@code collection} are forwarded unchanged to
     * {@code SVM_Manip.createKernelMatrix}. Their exact meaning is defined there (presumably
     * instance ids to restrict the matrix to), so confirm in that class.
     *
     * @param database   source of the instances
     * @param kernel     kernel used to build the matrix; {@code null} selects a {@link LinearKernel}
     * @param list       forwarded to {@code createKernelMatrix}; may be {@code null}
     * @param collection forwarded to {@code createKernelMatrix}; may be {@code null}
     */
    public SVR(Database<T> database, Kernel<T> kernel, List<Integer> list, Collection<Integer> collection) {
        initSetup(database, kernel);
        SVM_Manip manip = new SVM_Manip(this.kernel);
        this.km = manip.createKernelMatrix(database, list, collection, false, false);
        registerPrimaryKeys(manip.getPrimaryKeys());
        this.problemSize = this.km.l;
    }

    /**
     * Builds a kernel-matrix SVR, optionally extracting regression labels from the data first.
     *
     * <p>When {@code labelIndex > 0}, the following happens:
     * <ol>
     *   <li>{@code SVM_Manip.removeLabels} fills a label array and returns a label-free copy of
     *       the database.</li>
     *   <li>The kernel matrix is built from that copy.</li>
     *   <li>The extracted labels are assigned to the matrix rows.</li>
     * </ol>
     *
     * <p>Otherwise the matrix is built directly from {@code database}.
     *
     * <p>NOTE(review): in the {@code labelIndex <= 0} branch, {@code kernel}, {@code list} and
     * {@code collection} are not used. {@code SVM_Manip} is created without a kernel, so confirm
     * this is intended.
     *
     * @param database   source of the instances
     * @param kernel     kernel used to build the matrix; {@code null} selects a {@link LinearKernel}
     * @param list       forwarded to {@code createKernelMatrix}; may be {@code null}
     * @param collection forwarded to {@code createKernelMatrix}; may be {@code null}
     * @param labelIndex forwarded to {@code SVM_Manip.removeLabels} when positive (presumably
     *                   identifies where the label is stored; confirm in {@code SVM_Manip})
     */
    public SVR(final Database<T> database, Kernel<T> kernel, List<Integer> list, Collection<Integer> collection, int labelIndex) {
        initSetup(database, kernel);
        if (labelIndex > 0) {
            double[] labels = new double[database.getCount()];
            SequDB<FeatureVector> unlabeled = SVM_Manip.removeLabels(database, labelIndex, labels, null);
            SVM_Manip manip = new SVM_Manip(this.kernel);
            this.km = manip.createKernelMatrix(unlabeled, list, collection, false, false);
            // NOTE(review): the boolean result (label assignment success) is ignored here.
            assignClassLabels(labels);
            registerPrimaryKeys(manip.getPrimaryKeys());
        } else {
            SVM_Manip manip = new SVM_Manip(null);
            this.km = manip.createKernelMatrix(asFeatureVectors(database), database.getCount(), null, false);
            registerPrimaryKeys(manip.getPrimaryKeys());
        }
        this.problemSize = this.km.l;
    }

    /**
     * Builds a feature-vector SVR that uses one of libsvm's built-in kernel types.
     *
     * <p>Which entries of {@code kernelParams} are read depends on {@code kernelType}:
     * <ul>
     *   <li>{@code [0]} is gamma, used by types 1, 2, 3, 6 and 8.</li>
     *   <li>{@code [1]} is coef0, used by types 1, 3, 6 and 8.</li>
     *   <li>{@code [2]} is the degree (truncated to {@code int}), used by types 1 and 6.</li>
     * </ul>
     *
     * @param database     source of the instances
     * @param kernelType   libsvm kernel type; must be one of {0, 1, 2, 3, 5, 6, 8}
     * @param collection   forwarded to {@code SVM_Manip.removeLabels} when {@code labelIndex > 0}
     * @param labelIndex   forwarded to {@code SVM_Manip.removeLabels} when positive; otherwise the
     *                     problem is built directly from {@code database}
     * @param kernelParams kernel parameters as described above
     * @throws IllegalArgumentException if {@code kernelType} is unsupported, or if
     *                                  {@code kernelParams} is {@code null} or too short for it
     */
    public SVR(final Database<T> database, int kernelType, Collection<Integer> collection, int labelIndex, double[] kernelParams) {
        initSetup(database, null);
        if (kernelType < 0 || kernelType == 4 || kernelType == 7 || kernelType > 8) {
            throw new IllegalArgumentException("kernel_type must be in {0,1,2,3,5,6,8}");
        }
        this.param = SVM_Manip.create_svm_parameter(1, kernelType);
        if (kernelType == 1 || kernelType == 2 || kernelType == 3 || kernelType == 6 || kernelType == 8) {
            this.param.gamma = requireKernelParam(kernelParams, 0, kernelType, "gamma");
        }
        if (kernelType == 1 || kernelType == 3 || kernelType == 6 || kernelType == 8) {
            this.param.coef0 = requireKernelParam(kernelParams, 1, kernelType, "coef0");
        }
        if (kernelType == 1 || kernelType == 6) {
            this.param.degree = (int) requireKernelParam(kernelParams, 2, kernelType, "degree");
        }
        // 3 is libsvm's EPSILON_SVR svm_type (presumably unchanged in this port).
        this.param.svm_type = 3;
        if (labelIndex > 0) {
            double[] labels = new double[database.getCount()];
            SequDB<FeatureVector> unlabeled = SVM_Manip.removeLabels(database, labelIndex, labels, collection);
            SVM_Manip manip = new SVM_Manip(null);
            this.problem = manip.createSVMProblem(unlabeled, unlabeled.getCount(), null, false);
            // NOTE(review): the boolean result (label assignment success) is ignored here.
            assignClassLabels(labels);
            registerPrimaryKeys(manip.getPrimaryKeys());
        } else {
            SVM_Manip manip = new SVM_Manip(null);
            this.problem = manip.createSVMProblem(asFeatureVectors(database), database.getCount(), null, false);
            registerPrimaryKeys(manip.getPrimaryKeys());
        }
        this.problemSize = this.problem.l;
    }

    /**
     * Assigns labels to the rows of the active problem via {@code SVM_Manip.assignClassLabels}.
     *
     * @param labels one label per row, in row order
     * @return {@code true} on success, {@code false} if {@code SVM_Manip} rejected the labels
     *         with an {@link IllegalArgumentException}
     */
    public boolean assignClassLabels(double[] labels) {
        try {
            SVM_Manip.assignClassLabels(activeProblem(), labels);
            return true;
        } catch (IllegalArgumentException rejected) {
            // Part of the contract: rejection is reported through the return value.
            return false;
        }
    }

    /**
     * Assigns labels to the rows of the active problem, looked up by primary key.
     *
     * @param map label for every primary key of this SVR
     * @return always {@code true}
     * @throws IllegalArgumentException if {@code map} has no label for one of the primary keys
     */
    public boolean assignClassLabels(Map<String, Double> map) {
        svm_problem target = activeProblem();
        for (Map.Entry<String, Integer> entry : this.key2index.entrySet()) {
            Double label = map.get(entry.getKey());
            if (label == null) {
                throw new IllegalArgumentException("Entry map contains key '" + entry.getKey() + "' not mapped to a label position in the given label map");
            }
            target.y[entry.getValue().intValue()] = label.doubleValue();
        }
        return true;
    }

    /**
     * Stores the database, resolves the kernel (defaulting to linear) and resets the key map.
     */
    private void initSetup(Database<T> database, Kernel<T> kernel) {
        this.db = database;
        if (kernel != null) {
            this.kernel = kernel;
        } else {
            this.kernel = new LinearKernel();
        }
        this.key2index = new HashMap<String, Integer>();
    }

    /**
     * Records {@code keys} as the row order of the active problem and indexes them by key.
     */
    private void registerPrimaryKeys(String[] keys) {
        this.primaryKeys = keys;
        for (int row = 0; row < keys.length; row++) {
            this.key2index.put(keys[row], Integer.valueOf(row));
        }
    }

    /**
     * Returns the kernel matrix if present, otherwise the feature-vector problem.
     */
    private svm_problem activeProblem() {
        return this.km != null ? this.km : this.problem;
    }

    /**
     * Reads {@code params[index]}, reporting a missing value as an {@link IllegalArgumentException}.
     */
    private static double requireKernelParam(double[] params, int index, int kernelType, String name) {
        if (params == null || params.length <= index) {
            throw new IllegalArgumentException("Kernel type " + kernelType + " requires a " + name + " parameter");
        }
        return params[index];
    }

    /**
     * Exposes the objects of {@code database} as an {@code Iterable<FeatureVector>}.
     */
    private static <V extends FeatureVector> Iterable<FeatureVector> asFeatureVectors(final Database<V> database) {
        return new Iterable<FeatureVector>() {
            @Override
            @SuppressWarnings("unchecked")
            public Iterator<FeatureVector> iterator() {
                // Every element is a V and hence a FeatureVector. Iterator has no method that
                // accepts an element, so viewing Iterator<V> as Iterator<FeatureVector> is safe.
                return (Iterator<FeatureVector>) (Iterator<?>) database.objectIterator();
            }
        };
    }

    /**
     * Tunes C and nu with {@code SVM_Manip.optimize_4Regression} and trains the model.
     *
     * <p>For a kernel-matrix SVR, every call normalizes the matrix in place via
     * {@code SVM_Manip.normalize_kernel_matrix}. It also replaces the parameters with
     * precomputed-kernel (type 4) nu-SVR (svm_type 4) settings.
     *
     * <p>NOTE(review): repeated calls normalize the matrix again; confirm that is idempotent.
     *
     * @param excludedKeys primary keys to leave out of optimization and training, or {@code null}
     *                     to use every row. Keys unknown to this SVR are ignored.
     */
    public void train(List<String> excludedKeys) {
        if (this.km != null) {
            SVM_Manip.normalize_kernel_matrix(this.km);
            this.param = SVM_Manip.create_svm_parameter(1, 4);
            this.param.svm_type = 4;
        } else {
            // Feature-vector problems must never use the precomputed kernel type.
            assert this.param.kernel_type != 4;
        }
        double[] best = {-1.0d, -1.0d};
        HashSet<Integer> trainingRows = null;
        if (excludedKeys != null) {
            Set<Integer> excludedRows = new HashSet<Integer>();
            for (String key : excludedKeys) {
                Integer row = this.key2index.get(key);
                if (row != null) {
                    excludedRows.add(row);
                }
            }
            // The default capacity avoids a negative size hint when excludedKeys holds
            // duplicates or unknown keys.
            trainingRows = new HashSet<Integer>();
            for (int row = 0; row < this.problemSize; row++) {
                if (!excludedRows.contains(Integer.valueOf(row))) {
                    trainingRows.add(Integer.valueOf(row));
                }
            }
        }
        svm_problem active = activeProblem();
        // NOTE(review): the meaning of the constant 10 is defined by SVM_Manip (likely a fold count).
        SVM_Manip.optimize_4Regression(active, this.param, best, trainingRows, 10, true);
        this.param.C = best[0];
        this.param.nu = best[1];
        svm_problem trainingProblem = active;
        if (trainingRows != null) {
            trainingProblem = SVM_Manip.filter_svm_problem_instances(active, trainingRows);
        }
        this.model = svm.svm_train(trainingProblem, this.param);
    }

    /**
     * Predicts the regression value of one instance of the active problem.
     *
     * @param str primary key of the instance
     * @return the model's prediction
     * @throws IllegalStateException    if {@link #train(List)} has not been called
     * @throws IllegalArgumentException if {@code str} is not a known primary key
     */
    public double predict(String str) {
        ensureTrained();
        return predictRow(str);
    }

    /**
     * Predicts the regression values of several instances of the active problem.
     *
     * @param list primary keys of the instances
     * @return predictions in the order of {@code list}
     * @throws IllegalStateException    if {@link #train(List)} has not been called
     * @throws IllegalArgumentException if any key is not a known primary key
     */
    public double[] predict(List<String> list) {
        ensureTrained();
        double[] predictions = new double[list.size()];
        int position = 0;
        for (String key : list) {
            predictions[position++] = predictRow(key);
        }
        return predictions;
    }

    /**
     * Fails with {@link IllegalStateException} when no model has been trained yet.
     */
    private void ensureTrained() {
        if (this.model == null) {
            throw new IllegalStateException("You must first train a model in order to ask for a prediction");
        }
    }

    /**
     * Runs the trained model on the row identified by {@code key}.
     */
    private double predictRow(String key) {
        Integer row = this.key2index.get(key);
        if (row == null) {
            throw new IllegalArgumentException("primary key '" + key + "' unknown");
        }
        return svm.svm_predict(this.model, activeProblem().x[row.intValue()]);
    }

    /** @return the trained model, or {@code null} before {@link #train(List)} */
    public svm_model getModel() {
        return this.model;
    }

    /** @return the kernel-matrix problem, or {@code null} for a feature-vector SVR */
    public svm_problem getKm() {
        return this.km;
    }

    /** @return the feature-vector problem, or {@code null} for a kernel-matrix SVR */
    public svm_problem getProblem() {
        return this.problem;
    }

    /** @return the current libsvm parameters; may be {@code null} before {@link #train(List)} */
    public svm_parameter getParam() {
        return this.param;
    }

    /**
     * Returns the row indices of all database objects whose class number is {@code classNr}.
     *
     * <p>Objects whose primary key is not part of the active problem are skipped.
     *
     * @param classNr class number to select
     * @return a new mutable set of row indices
     */
    public Set<Integer> getInstancesOfClass(int classNr) {
        HashSet<Integer> rows = new HashSet<Integer>();
        Iterator<T> objects = this.db.objectIterator();
        while (objects.hasNext()) {
            T object = objects.next();
            if (object.getClassNr() == classNr) {
                Integer row = this.key2index.get(object.getPrimaryKey());
                if (row != null) {
                    rows.add(row);
                }
            }
        }
        return rows;
    }

    /** @return the database this SVR was built from */
    public Database<T> getDB() {
        return this.db;
    }

    /**
     * Runs class-wise cross-validation over the active problem.
     *
     * <p>For every class id of the database, the rows of that class are passed to
     * {@code SVM_Manip.optimize_and_validate}. The returned errors are summarized, and their mean
     * and standard deviation are printed to standard output.
     *
     * @param verbose      if {@code true}, prints each class id; also forwarded to
     *                     {@code optimize_and_validate}
     * @param modelStorage optional storage whose {@code incrementFile()} is called once per
     *                     class; may be {@code null}
     * @return mean of the per-class errors
     * @throws IllegalStateException if no SVM parameters exist yet (a kernel-matrix SVR before
     *                               {@link #train(List)})
     */
    public double cvByClassLabel(boolean verbose, ModelStorage modelStorage) {
        svm_problem active = activeProblem();
        if (this.param == null) {
            throw new IllegalStateException("No SVM parameters available; call train() before cross-validating a kernel-matrix based SVR");
        }
        SummaryItem errors = new SummaryItem();
        for (Iterator<Integer> classIds = this.db.getClassIDs().iterator(); classIds.hasNext(); ) {
            int classId = classIds.next().intValue();
            if (modelStorage != null) {
                modelStorage.incrementFile();
            }
            Set<Integer> classRows = getInstancesOfClass(classId);
            if (verbose) {
                System.out.print("Class " + classId + " ");
            }
            errors.add(SVM_Manip.optimize_and_validate(active, this.param, classRows, verbose, modelStorage));
        }
        System.out.println(String.valueOf(this.db.getNumClasses()) + "-fold (class-wise) CV ERR: " + errors.getMean() + ", SD: " + errors.getStdD());
        return errors.getMean();
    }
}
