package de.lmu.ifi.dbs.dm.algorithms;

import de.lmu.ifi.dbs.dm.DistanceMeasure;
import de.lmu.ifi.dbs.dm.data.DataObject;
import de.lmu.ifi.dbs.dm.data.featureVector.FeatureVector;
import de.lmu.ifi.dbs.dm.data.generators.DataGenerator;
import de.lmu.ifi.dbs.dm.database.Database;
import de.lmu.ifi.dbs.dm.database.SequDB;
import de.lmu.ifi.dbs.dm.distance.SqEuclideanDistance;
import de.lmu.ifi.dbs.utilities.math.Sampler;
import de.lmu.ifi.dbs.utilities.statistics.SummaryItem;
import de.lmu.ifi.dbs.utilities.tools.Zeit;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import java.util.logging.Logger;

/* loaded from: input_file:de/lmu/ifi/dbs/dm/algorithms/CVClassificationRuns.class */
/**
 * Evaluates a {@link CVClassifier} on a {@link Database} with (repeated) k-fold cross validation.
 *
 * <p>The database objects are grouped by their class number. At the start of every
 * cross-validation run, each class is shuffled in place with {@code Sampler.permute}. Each class
 * is then split into folds on its own, so every test fold receives members of every class.
 * A class with fewer members than folds re-uses members, so such objects are classified more than
 * once per run.
 *
 * <p>Note: {@link #setSamplingSeed(long)} replaces the shared {@code Sampler.RANDOM} generator.
 *
 * @param <T> type of the objects stored in the database
 */
public class CVClassificationRuns<T extends DataObject> {
    /** When {@code true}, per-object predictions and run summaries are printed to stdout. */
    public static boolean VERBOSE = true;

    private static final Logger LOG = Logger.getLogger(CVClassificationRuns.class.getName());

    /** The classifier under evaluation. */
    protected CVClassifier<T> classifier;
    /** The database providing all objects to classify. */
    protected Database<T> db;
    /**
     * Class numbers reported by the classifier; {@link #getAccuracies()} reports confusion-matrix
     * row {@code k} under {@code classes[k]} (presumably consistent with {@link #class2Index}).
     */
    protected int[] classes;
    /** Maps a class number to its row/column index in {@link #confusionMatrix}. */
    protected Map<Integer, Integer> class2Index;
    /**
     * Confusion matrix of the most recent {@link #runCV(int)}. Rows are indexed by the true class
     * and columns by the predicted class.
     */
    protected long[][] confusionMatrix;
    /** If non-null, pushed into the classifier at the start of every {@link #runCV(int)}. */
    protected DistanceMeasure<T> dm = null;
    /** Database objects grouped by class number, iterated in ascending class-number order. */
    protected Map<Integer, List<T>> classMap = new TreeMap<Integer, List<T>>();

    /** Last seed passed to {@link #setSamplingSeed(long)} (initially 13). */
    private long samplingSeed = 13;

    /**
     * Creates a cross-validation driver and groups all objects of {@code database} by class number.
     *
     * @param database   the objects to cross-validate
     * @param classifier the classifier to evaluate; it also supplies the class numbers and their
     *                   confusion-matrix indices
     */
    public <D extends Database<T>> CVClassificationRuns(D database, CVClassifier<T> classifier) {
        this.db = database;
        this.classifier = classifier;
        this.classes = classifier.getClasses();
        this.class2Index = classifier.getClass2Index();
        Iterator<?> objects = database.objectIterator();
        while (objects.hasNext()) {
            // The database of a Database<T> yields T instances; the iterator type is not visible here.
            @SuppressWarnings("unchecked")
            T object = (T) objects.next();
            Integer classNr = Integer.valueOf(object.getClassNr());
            List<T> members = this.classMap.get(classNr);
            if (members == null) {
                members = new ArrayList<T>();
                this.classMap.put(classNr, members);
            }
            members.add(object);
        }
    }

    /**
     * Reseeds the shared {@code Sampler.RANDOM} generator that is used to shuffle the classes.
     *
     * <p>Until this method is called, {@code Sampler.RANDOM} is not seeded by this class. The
     * initial value 13 returned by the first call is therefore only nominal.
     *
     * @param seed the new seed
     * @return the seed stored before this call
     */
    public long setSamplingSeed(long seed) {
        long previous = this.samplingSeed;
        this.samplingSeed = seed;
        Sampler.RANDOM = new Random(seed);
        return previous;
    }

    /**
     * Runs one {@code folds}-fold cross validation and stores its confusion matrix in
     * {@link #confusionMatrix}.
     *
     * <p>A test object for which the classifier returns {@code -1} (no prediction) is logged. It
     * is counted as misclassified and not entered into the confusion matrix.
     *
     * @param folds number of folds; must be greater than 1 and at most the number of objects
     * @return fraction of test classifications that matched the object's true class
     * @throws IllegalArgumentException if {@code folds} is out of range, or if every class has
     *         fewer than {@code folds} members
     */
    public double runCV(int folds) {
        if (folds <= 1) {
            throw new IllegalArgumentException("cross validation must have fold > 1 - is " + folds);
        }
        if (folds > this.db.getCount()) {
            throw new IllegalArgumentException("no use in running more training rounds (" + folds + ") than there are objects to be classified (" + this.db.getCount() + ")");
        }
        this.confusionMatrix = new long[this.classes.length][this.classes.length];
        if (this.classifier != null && this.dm != null) {
            this.classifier.setDistanceMeasure(this.dm);
        }
        shuffleClasses(folds);

        // One set is reused for all folds so the per-fold iteration order is reproducible.
        Set<String> testKeys = new HashSet<String>();
        Set<String> testedKeys = new HashSet<String>();
        long correct = 0;
        long tested = 0;
        for (int fold = 0; fold < folds; fold++) {
            fillTestFold(folds, fold, testKeys);
            testedKeys.addAll(testKeys);
            train(testKeys);
            for (String key : testKeys) {
                T object = this.db.getInstance(key);
                int predicted = this.classifier.classify(object, testKeys);
                double certainty = this.classifier.getCertainty();
                if (predicted == -1) {
                    LOG.warning("no hypotheses resulted in a matching prediction for " + object.getPrimaryKey());
                    certainty = 0.0d;
                }
                int actual = object.getClassNr();
                if (VERBOSE) {
                    System.out.println(String.valueOf(predicted == actual ? "  " : "X ") + "pred: " + predicted + " for " + actual + ", certainty: " + certainty);
                }
                tested++;
                if (predicted == -1) {
                    // No column exists for "no prediction": count it as a miss only.
                    continue;
                }
                if (predicted == actual) {
                    correct++;
                }
                int row = this.class2Index.get(Integer.valueOf(actual)).intValue();
                int column = this.class2Index.get(Integer.valueOf(predicted)).intValue();
                this.confusionMatrix[row][column]++;
            }
            if (VERBOSE) {
                System.out.println("==========" + fold + "==========");
            }
        }
        assert tested >= this.db.getCount();
        assert testedKeys.size() == this.db.getCount();
        // Cast before dividing: long / long would truncate the accuracy to 0 or 1.
        return (double) correct / tested;
    }

    /**
     * Shuffles every class in place and warns about classes with fewer than {@code folds} members.
     *
     * @throws IllegalArgumentException if every class has fewer than {@code folds} members
     */
    private void shuffleClasses(int folds) {
        int undersized = 0;
        int largestUndersized = 0;
        for (Map.Entry<Integer, List<T>> entry : this.classMap.entrySet()) {
            List<T> members = entry.getValue();
            Sampler.permute(members);
            if (members.size() < folds) {
                LOG.warning("Warning: running " + folds + "-fold cross validation with class " + entry.getKey() + " of size " + members.size());
                undersized++;
                largestUndersized = Math.max(largestUndersized, members.size());
            }
        }
        if (undersized == this.classMap.size()) {
            throw new IllegalArgumentException("\n\t\"No use in running more training rounds (" + folds + ") than there are\n\t\tobjects per class (at maximum " + largestUndersized + ") to be classified\"");
        }
    }

    /**
     * Replaces the contents of {@code testKeys} with the primary keys of test fold {@code fold}.
     *
     * <p>From each class, the members at positions {@code fold, fold + folds, fold + 2*folds, ...}
     * are taken. A class with at most {@code fold} members contributes member {@code fold % size}
     * instead, so every class is represented in every fold.
     */
    private void fillTestFold(int folds, int fold, Set<String> testKeys) {
        testKeys.clear();
        for (List<T> members : this.classMap.values()) {
            int before = testKeys.size();
            int size = members.size();
            if (fold >= size) {
                testKeys.add(members.get(fold % size).getPrimaryKey());
            } else {
                for (int index = fold; index < size; index += folds) {
                    testKeys.add(members.get(index).getPrimaryKey());
                }
            }
            // Each class must add at least one new key; this fails e.g. on duplicate primary keys.
            assert testKeys.size() > before;
        }
    }

    /**
     * Returns the per-class accuracy (recall) of the most recent {@link #runCV(int)}.
     *
     * @return map from class number to the fraction of that class's recorded test objects that
     *         were classified correctly; a class without recorded test objects maps to {@code NaN}
     * @throws IllegalStateException if no cross validation has been run yet
     */
    public Map<Integer, Double> getAccuracies() {
        if (this.confusionMatrix == null) {
            throw new IllegalStateException("no cross validation has been run yet");
        }
        return accuraciesOf(this.confusionMatrix);
    }

    /** Computes diagonal / row-sum for every row of {@code matrix}, keyed by {@code classes[row]}. */
    private Map<Integer, Double> accuraciesOf(long[][] matrix) {
        Map<Integer, Double> accuracies = new HashMap<Integer, Double>();
        for (int row = 0; row < matrix.length; row++) {
            double rowTotal = 0.0d;
            for (long count : matrix[row]) {
                rowTotal += count;
            }
            accuracies.put(Integer.valueOf(this.classes[row]), Double.valueOf(matrix[row][row] / rowTotal));
        }
        return accuracies;
    }

    /**
     * Runs {@code runs} independent {@code folds}-fold cross validations.
     *
     * @param folds      number of folds per run
     * @param runs       number of runs
     * @param accuracies if non-null, cleared and then filled with the per-class accuracies
     *                   accumulated over all runs
     * @param confusion  if non-null, a {@code classes.length x classes.length} matrix to which the
     *                   confusion counts of all runs are added in place; if null, a fresh matrix
     *                   is used internally
     * @return summary of the overall accuracy of each run
     * @throws IllegalArgumentException if {@code confusion} has the wrong dimensions
     */
    public SummaryItem repeatedRun(int folds, int runs, Map<Integer, Double> accuracies, long[][] confusion) {
        int n = this.classes.length;
        long[][] total = confusion;
        if (total == null) {
            total = new long[n][n];
        } else if (!isSquare(total, n)) {
            int columns = total.length > 0 && total[0] != null ? total[0].length : 0;
            throw new IllegalArgumentException("Require a confusion matrix of dimension " + n + "x" + n + "; got: " + total.length + "x" + columns);
        }
        if (accuracies != null) {
            accuracies.clear();
        }
        Date start = new Date();
        SummaryItem summary = new SummaryItem();
        for (int run = 0; run < runs; run++) {
            summary.add(runCV(folds));
            for (int row = 0; row < n; row++) {
                for (int column = 0; column < n; column++) {
                    total[row][column] += this.confusionMatrix[row][column];
                }
            }
        }
        if (accuracies != null) {
            accuracies.putAll(accuraciesOf(total));
        }
        if (VERBOSE) {
            System.out.println("mean accuracy: " + summary.getMean());
            System.out.println("took " + Zeit.wieLange(start));
        }
        return summary;
    }

    /** Returns whether {@code matrix} has exactly {@code n} non-null rows of length {@code n}. */
    private static boolean isSquare(long[][] matrix, int n) {
        if (matrix.length != n) {
            return false;
        }
        for (long[] row : matrix) {
            if (row == null || row.length != n) {
                return false;
            }
        }
        return true;
    }

    /**
     * Hook invoked once per fold, before classification, with the primary keys of the held-out
     * test fold. It does nothing here; subclasses may override it (presumably to train on the
     * remaining objects — confirm against overriding classes).
     *
     * @param set primary keys of the current test fold
     */
    protected void train(Set<String> set) {
    }

    /**
     * Demo. Generates two synthetic data sets with {@code DataGenerator} and labels them class 0
     * and class 1. It then evaluates a {@code KNNClassifier} (constructed with parameter 2,
     * presumably k) by 10 runs of 10-fold cross validation.
     */
    public static void main(String[] strArr) {
        DataGenerator firstGenerator = new DataGenerator(0.0d, 0.75d, 1000, 2, 11L);
        firstGenerator.setType(0);
        SequDB sequDB = new SequDB(new SqEuclideanDistance());
        for (FeatureVector vector : firstGenerator.generate()) {
            vector.setClassNr(0);
            sequDB.insert(vector);
        }
        DataGenerator secondGenerator = new DataGenerator(0.25d, 1.0d, 1000, 2, 12L);
        secondGenerator.setType(0);
        for (FeatureVector vector : secondGenerator.generate()) {
            vector.setClassNr(1);
            // Re-key the second set, presumably to avoid clashing with the first set's keys.
            vector.setPrimaryKey(String.valueOf(sequDB.getCount() + 1));
            sequDB.insert(vector);
        }
        System.out.println(String.valueOf(sequDB.getMemberCount(0)) + ", " + sequDB.getMemberCount(1) + " => " + sequDB.getCount());
        CVClassificationRuns cvRuns = new CVClassificationRuns(sequDB, new KNNClassifier(sequDB, 2));
        cvRuns.setSamplingSeed(13L);
        Map<Integer, Double> accuracies = new HashMap<Integer, Double>();
        System.out.println("accuracy: " + cvRuns.repeatedRun(10, 10, accuracies, null).toString() + "\naccuracies:");
        System.out.println(accuracies);
    }
}
