package ir.Experiments.index;

import dm.algorithms.KNNClassifier;
import dm.data.DataObject;
import dm.data.DistanceMeasure;
import dm.data.MIObjects.MIDistanceMeasure;
import dm.data.MIObjects.MultiInstanceObject;
import dm.data.MIObjects.ParallelizedMIDM;
import dm.data.database.Database;
import dm.data.database.MultiDistanceSequDB;
import dm.data.featureVector.SqEuclidianDistance;
import dm.data.kernels.RBFKernel;
import dm.util.math.Sampler;
import ir.utils.statistics.SummaryItem;
import ir.utils.tools.Zeit;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;

/* loaded from: input_file:ir/Experiments/index/ParallelizedMIDMTests.class */
/**
 * Stratified cross-validation of a {@link KNNClassifier} on a database of multi-instance objects.
 *
 * <p>When the database's distance measure is a {@link ParallelizedMIDM} (in which case the database
 * itself must be a {@link MultiDistanceSequDB}), every test object is additionally classified once
 * per combination of configured k value and distance option. Per-class accuracies are then tracked
 * separately for each combination.
 *
 * @param <T> instance type contained in the multi-instance objects
 */
public class ParallelizedMIDMTests<T extends DataObject> {
    private static long SAMPLING_SEED = 13L;

    /** When {@code true}, per-object predictions and summary tables are printed to stdout. */
    public boolean VERBOSE = true;

    private final Database<MultiInstanceObject<T>> db;
    private final KNNClassifier classifier;
    /** k values to evaluate in parallel mode, sorted ascending; never empty. */
    private int[] ks;
    /** Per-setting accuracies of the last {@link #runCV} call: [k index][distance option] -> class -> accuracy. */
    private Map<Integer, Double>[][] accuracySumsPerClassPerDistPerK = null;
    /** Per-setting accuracy summaries over the repetitions of the last {@link #repeatedRun} call. */
    private Map<Integer, SummaryItem>[][] accumulatedDistanceSummaries = null;

    /**
     * Creates a tester that evaluates k in {1, 2, 3, 5} in parallel mode.
     *
     * @param database the objects to cross-validate
     * @param kNNClassifier the classifier to evaluate
     */
    public ParallelizedMIDMTests(Database<MultiInstanceObject<T>> database, KNNClassifier kNNClassifier) {
        this(database, kNNClassifier, new int[]{1, 2, 3, 5});
    }

    /**
     * Creates a tester with a custom set of k values.
     *
     * @param database the objects to cross-validate
     * @param kNNClassifier the classifier to evaluate
     * @param iArr k values to evaluate in parallel mode; copied and sorted, the caller's array is untouched
     * @throws NullPointerException if {@code iArr} is null
     * @throws IllegalArgumentException if {@code iArr} is empty
     */
    public ParallelizedMIDMTests(Database<MultiInstanceObject<T>> database, KNNClassifier kNNClassifier, int[] iArr) {
        this.db = database;
        this.classifier = kNNClassifier;
        this.ks = sortedCopy(iArr);
    }

    /**
     * Runs one {@code folds}-fold cross-validation, stratified by class.
     *
     * <p>The objects of each class are reordered by {@link Sampler#sample}. Fold {@code f} then takes
     * the positions {@code f, f + folds, f + 2*folds, ...} of every class. A class with fewer than
     * {@code folds} objects reuses object {@code f % size}, so every fold contains each class.
     *
     * @param folds number of folds, at least 2 and at most the number of objects in the database
     * @return class number -> fraction of correct classifications for every class that was tested
     *     (0.0 if none were correct). In parallel mode this is averaged over all (k, distance option)
     *     combinations.
     * @throws IllegalArgumentException if {@code folds <= 1}, if {@code folds} exceeds the database
     *     size, or if every class has fewer than {@code folds} objects
     * @throws ClassCastException if the database's distance measure is not a {@link MIDistanceMeasure}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Map<Integer, Double> runCV(int folds) {
        if (folds <= 1) {
            throw new IllegalArgumentException("cross validation must have fold > 1 - is " + folds);
        }
        if (folds > this.db.getCount()) {
            throw new IllegalArgumentException("no use in running more training rounds (" + folds
                    + ") than there are objects to be classified (" + this.db.getCount() + ")");
        }
        if (this.db.getDistanceMeasure() instanceof ParallelizedMIDM) {
            int numDistances = ((ParallelizedMIDM) this.db.getDistanceMeasure()).getNumDistances();
            this.accuracySumsPerClassPerDistPerK = new HashMap[this.ks.length][numDistances];
            for (int kIdx = 0; kIdx < this.ks.length; kIdx++) {
                for (int dist = 0; dist < numDistances; dist++) {
                    this.accuracySumsPerClassPerDistPerK[kIdx][dist] = new HashMap<>();
                }
            }
            // The primary classification of each object is booked under setting [0][0]. Make sure it
            // really runs with the smallest k and distance option 0, also for the very first object.
            this.classifier.setK(this.ks[0]);
            ((MultiDistanceSequDB) this.db).setDOption(0);
        } else {
            // Don't leave stale per-setting results from an earlier run lying around.
            this.accuracySumsPerClassPerDistPerK = null;
        }

        List<List<MultiInstanceObject<T>>> shuffledClasses = groupAndShuffleByClass(folds);
        // NOTE(review): the result is discarded. The only effect visible here is failing fast with a
        // ClassCastException for non-multi-instance measures. Confirm getInstanceDistance() has no
        // needed side effect before removing this call.
        ((MIDistanceMeasure) this.db.getDistanceMeasure()).getInstanceDistance();

        Map<Integer, Double> correctPerClass = new HashMap<>();
        Map<Integer, Integer> testedPerClass = new HashMap<>();
        HashSet<String> everTested = new HashSet<>();
        int classifications = 0;
        for (int fold = 0; fold < folds; fold++) {
            HashSet<String> testKeys = buildTestFold(shuffledClasses, folds, fold);
            everTested.addAll(testKeys);
            for (String key : testKeys) {
                MultiInstanceObject<T> query = this.db.getInstance(key);
                int actual = query.getClassNr();
                int predicted = classifyAndReport(query, testKeys, actual);
                if (predicted == actual) {
                    increment(correctPerClass, actual);
                }
                Integer tested = testedPerClass.get(actual);
                testedPerClass.put(actual, tested == null ? 1 : tested + 1);
                classifications++;
                if (this.accuracySumsPerClassPerDistPerK != null) {
                    if (predicted == actual) {
                        increment(this.accuracySumsPerClassPerDistPerK[0][0], actual);
                    }
                    evaluateRemainingSettings(query, testKeys, actual, correctPerClass);
                }
                if (this.VERBOSE) {
                    System.out.println("for " + query.getPrimaryKey());
                }
            }
            if (this.VERBOSE) {
                System.out.println("==========" + fold + "==========");
            }
        }
        // Small classes are reused across folds, so there can be more classifications than objects,
        // but every object must have been tested at least once.
        assert classifications >= this.db.getCount();
        assert everTested.size() == this.db.getCount();

        int settingsPerObject = this.accuracySumsPerClassPerDistPerK == null
                ? 1
                : this.accuracySumsPerClassPerDistPerK.length * this.accuracySumsPerClassPerDistPerK[0].length;
        Map<Integer, Double> accuracyPerClass = new HashMap<>();
        for (Map.Entry<Integer, Integer> entry : testedPerClass.entrySet()) {
            // Classes without a single correct prediction must still be reported (as 0.0); otherwise
            // they vanish from the mean accuracy and break the per-class bins in repeatedRun.
            Double correct = correctPerClass.get(entry.getKey());
            double correctCount = correct == null ? 0.0d : correct;
            accuracyPerClass.put(entry.getKey(), correctCount / (entry.getValue() * settingsPerObject));
        }
        if (this.accuracySumsPerClassPerDistPerK != null) {
            for (Map<Integer, Double>[] perK : this.accuracySumsPerClassPerDistPerK) {
                for (Map<Integer, Double> perSetting : perK) {
                    for (Map.Entry<Integer, Double> entry : perSetting.entrySet()) {
                        entry.setValue(entry.getValue() / testedPerClass.get(entry.getKey()));
                    }
                }
            }
        }
        return accuracyPerClass;
    }

    /**
     * Groups the database's objects by class number (ascending) and reorders each group via
     * {@link Sampler#sample}. Warns on stderr about every class smaller than {@code folds}.
     *
     * @throws IllegalArgumentException if every class is smaller than {@code folds}
     */
    private List<List<MultiInstanceObject<T>>> groupAndShuffleByClass(int folds) {
        TreeMap<Integer, List<MultiInstanceObject<T>>> byClass = new TreeMap<>();
        Iterator<MultiInstanceObject<T>> objects = this.db.objectIterator();
        while (objects.hasNext()) {
            MultiInstanceObject<T> object = objects.next();
            List<MultiInstanceObject<T>> members = byClass.get(object.getClassNr());
            if (members == null) {
                members = new ArrayList<>();
                byClass.put(object.getClassNr(), members);
            }
            members.add(object);
        }
        List<List<MultiInstanceObject<T>>> shuffled = new ArrayList<>(byClass.size());
        int undersizedClasses = 0;
        int largestUndersized = 0;
        for (Map.Entry<Integer, List<MultiInstanceObject<T>>> entry : byClass.entrySet()) {
            List<MultiInstanceObject<T>> members = entry.getValue();
            shuffled.add(shuffle(members));
            if (members.size() < folds) {
                System.err.println("Warning: running " + folds + "-fold cross validation with class "
                        + entry.getKey() + " of size " + members.size());
                undersizedClasses++;
                largestUndersized = Math.max(largestUndersized, members.size());
            }
        }
        assert byClass.size() == this.db.getNumClasses();
        if (undersizedClasses == shuffled.size()) {
            throw new IllegalArgumentException("\n\t\"No use in running more training rounds (" + folds
                    + ") than there are\n\t\tobjects per class (at maximum " + largestUndersized
                    + ") to be classified\"");
        }
        return shuffled;
    }

    /**
     * Returns the reordering of {@code members} produced by {@link Sampler#sample}. This is presumably
     * a random permutation drawn from {@code Sampler.RANDOM}; confirm in Sampler.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private List<MultiInstanceObject<T>> shuffle(List<MultiInstanceObject<T>> members) {
        // Sampler's generic signature is not visible from here. Going through the raw type keeps this
        // call compiling however it is declared, exactly like the original raw call.
        List raw = members;
        return (List<MultiInstanceObject<T>>) Sampler.sample(raw.size(), raw);
    }

    /**
     * Collects the primary keys of the objects tested in fold {@code fold}. Every class contributes
     * at least one object.
     */
    private HashSet<String> buildTestFold(List<List<MultiInstanceObject<T>>> shuffledClasses, int folds, int fold) {
        HashSet<String> testKeys = new HashSet<>();
        int previousSize = 0;
        for (List<MultiInstanceObject<T>> members : shuffledClasses) {
            int size = members.size();
            int fullRounds = size / folds;
            for (int round = 0; round < fullRounds; round++) {
                testKeys.add(members.get(folds * round + fold).getPrimaryKey());
            }
            int remainderIndex = fullRounds * folds + fold;
            if (remainderIndex < size) {
                testKeys.add(members.get(remainderIndex).getPrimaryKey());
            }
            if (fold >= size) {
                // Class smaller than the number of folds: reuse an object so the class is still tested.
                testKeys.add(members.get(fold % size).getPrimaryKey());
            }
            assert testKeys.size() > previousSize;
            previousSize = testKeys.size();
        }
        return testKeys;
    }

    /**
     * Classifies {@code query} with the classifier's current settings and, when verbose, prints the
     * prediction. {@code testKeys} is passed through to the classifier; presumably those objects are
     * excluded from the neighbour search. Verify in KNNClassifier.
     *
     * @return the predicted class number
     */
    private int classifyAndReport(MultiInstanceObject<T> query, HashSet<String> testKeys, int actual) {
        int predicted = this.classifier.classify(query, testKeys);
        double certainty = this.classifier.getCertainty();
        assert predicted != -1;
        if (this.VERBOSE) {
            System.out.println((predicted == actual ? "  " : "X ") + "pred: " + predicted + " for "
                    + actual + ", certainty: " + certainty);
        }
        return predicted;
    }

    /**
     * Classifies {@code query} under every (k, distance option) combination except [0][0], which the
     * primary classification already covered. Correct predictions are added to
     * {@code correctPerClass} and to the matching per-setting map.
     *
     * <p>On return, the distance option is reset to 0 and k to {@code ks[0]}.
     */
    private void evaluateRemainingSettings(MultiInstanceObject<T> query, HashSet<String> testKeys, int actual,
            Map<Integer, Double> correctPerClass) {
        ParallelizedMIDM measure = (ParallelizedMIDM) this.db.getDistanceMeasure();
        MultiDistanceSequDB multiDb = (MultiDistanceSequDB) this.db;
        while (multiDb.getDOption() < measure.getNumDistances()) {
            int dist = multiDb.getDOption();
            for (int kIdx = 0; kIdx < this.ks.length; kIdx++) {
                if (kIdx == 0 && dist == 0) {
                    continue;
                }
                this.classifier.setK(this.ks[kIdx]);
                multiDb.resetKsForDist();
                int predicted = classifyAndReport(query, testKeys, actual);
                if (predicted == actual) {
                    increment(correctPerClass, actual);
                    increment(this.accuracySumsPerClassPerDistPerK[kIdx][dist], actual);
                }
            }
            multiDb.setDOption(dist + 1);
            this.classifier.setK(this.ks[0]);
        }
        multiDb.setDOption(0);
    }

    /** Adds 1.0 to the count stored for {@code key}, treating a missing entry as 0.0. */
    private static void increment(Map<Integer, Double> counts, int key) {
        Double current = counts.get(key);
        counts.put(key, (current == null ? 0.0d : current) + 1.0d);
    }

    /**
     * Repeats {@link #runCV} {@code repetitions} times and returns each class's mean accuracy.
     *
     * <p>In parallel mode ({@link ParallelizedMIDM}), per-(k, distance option) accuracies are also
     * accumulated; they can be read via {@link #getAccumulatedDistanceSummaries()}.
     *
     * @param folds number of cross-validation folds, at least 2
     * @param repetitions how often the cross-validation is repeated
     * @return class number -> mean accuracy over all repetitions
     * @throws IllegalArgumentException if {@code folds <= 1} or the distance measure is not a
     *     {@link MIDistanceMeasure}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Map<Integer, Double> repeatedRun(int folds, int repetitions) {
        if (folds <= 1) {
            throw new IllegalArgumentException("LOO not yet implemented");
        }
        if (!(this.db.getDistanceMeasure() instanceof MIDistanceMeasure)) {
            throw new IllegalArgumentException("can only treat multi-instance DBs");
        }
        Date start = new Date();
        if (this.db.getDistanceMeasure() instanceof ParallelizedMIDM) {
            int numDistances = ((ParallelizedMIDM) this.db.getDistanceMeasure()).getNumDistances();
            this.accumulatedDistanceSummaries = new Map[this.ks.length][numDistances];
            for (int kIdx = 0; kIdx < this.ks.length; kIdx++) {
                for (int dist = 0; dist < numDistances; dist++) {
                    this.accumulatedDistanceSummaries[kIdx][dist] = new HashMap<>();
                }
            }
        } else {
            this.accumulatedDistanceSummaries = null;
        }
        int[] classes = this.classifier.getClasses();
        Map<Integer, SummaryItem> overallPerClass = new HashMap<>();
        for (int rep = 0; rep < repetitions; rep++) {
            Map<Integer, Double> accuracies = runCV(folds);
            if (this.VERBOSE) {
                System.out.println("class:\taccuracy\tstdev");
            }
            double accuracySum = 0.0d;
            for (Map.Entry<Integer, Double> entry : accuracies.entrySet()) {
                SummaryItem item = overallPerClass.get(entry.getKey());
                if (item == null) {
                    if (rep > 0) {
                        throw new IllegalArgumentException("empty accuracy bin in iteration " + rep
                                + ": class " + entry.getKey() + " not yet entered");
                    }
                    item = new SummaryItem();
                    overallPerClass.put(entry.getKey(), item);
                }
                item.add(entry.getValue().doubleValue());
                if (this.VERBOSE) {
                    System.out.println(entry.getKey() + ":\t" + entry.getValue());
                }
                accuracySum += entry.getValue().doubleValue();
            }
            double meanAccuracy = accuracySum / accuracies.size();
            if (this.VERBOSE) {
                System.out.println("mean accuracy r" + rep + ": " + meanAccuracy);
            }
            if (this.accumulatedDistanceSummaries != null) {
                accumulatePerSettingAccuracies(classes, rep);
            }
        }
        if (this.VERBOSE) {
            System.out.println();
        }
        Map<Integer, Double> meanPerClass = new HashMap<>();
        double meanSum = 0.0d;
        for (Map.Entry<Integer, SummaryItem> entry : overallPerClass.entrySet()) {
            assert entry.getValue().getCount() == repetitions;
            double mean = entry.getValue().getMean();
            meanPerClass.put(entry.getKey(), mean);
            if (this.VERBOSE) {
                System.out.println(entry.getKey() + ":\t" + mean + "\t" + entry.getValue().getStdD());
            }
            meanSum += mean;
        }
        if (this.VERBOSE) {
            System.out.println("mean accuracy: " + (meanSum / meanPerClass.size()));
            System.out.println("took " + Zeit.wieLange(start));
        }
        if (this.VERBOSE && this.accumulatedDistanceSummaries != null) {
            printPerSettingTable(classes, repetitions);
        }
        return meanPerClass;
    }

    /**
     * Adds the per-setting accuracies of the last {@link #runCV} call to the accumulated summaries.
     * A class missing from a setting's map is counted as accuracy 0.0.
     */
    private void accumulatePerSettingAccuracies(int[] classes, int repetition) {
        for (int dist = 0; dist < this.accumulatedDistanceSummaries[0].length; dist++) {
            for (int kIdx = 0; kIdx < this.ks.length; kIdx++) {
                for (int classNr : classes) {
                    Double accuracy = this.accuracySumsPerClassPerDistPerK[kIdx][dist].get(classNr);
                    SummaryItem item = this.accumulatedDistanceSummaries[kIdx][dist].get(classNr);
                    if (item == null) {
                        if (repetition > 0) {
                            throw new IllegalArgumentException("empty accuracy bin in iteration " + repetition
                                    + ": class " + classNr + " not yet entered");
                        }
                        item = new SummaryItem();
                        this.accumulatedDistanceSummaries[kIdx][dist].put(classNr, item);
                    }
                    item.add(accuracy == null ? 0.0d : accuracy.doubleValue());
                }
            }
        }
    }

    /**
     * Prints a tab-separated table to stdout. It has one row per class, a final "total" row, and a
     * mean/stdev column pair per (distance option, k) setting.
     */
    private void printPerSettingTable(int[] classes, int repetitions) {
        int numKs = this.accumulatedDistanceSummaries.length;
        int numDistances = this.accumulatedDistanceSummaries[0].length;
        SummaryItem[][] totals = new SummaryItem[numKs][numDistances];
        System.out.println("Format: class\t{accuracy\tstdev}+");
        System.out.print("class");
        for (int dist = 0; dist < numDistances; dist++) {
            for (int kIdx = 0; kIdx < numKs; kIdx++) {
                totals[kIdx][dist] = new SummaryItem();
                System.out.print("\td" + dist + "k" + this.ks[kIdx] + "acc\td" + dist + "k" + this.ks[kIdx] + "sd");
            }
        }
        System.out.println();
        for (int classNr : classes) {
            System.out.print(classNr);
            for (int dist = 0; dist < numDistances; dist++) {
                for (int kIdx = 0; kIdx < numKs; kIdx++) {
                    SummaryItem item = this.accumulatedDistanceSummaries[kIdx][dist].get(classNr);
                    if (item == null) {
                        System.out.print("\t0\t0");
                    } else {
                        assert item.getCount() == repetitions;
                        double mean = item.getMean();
                        System.out.print("\t" + mean + "\t" + item.getStdD());
                        totals[kIdx][dist].add(mean);
                    }
                }
            }
            System.out.println();
        }
        System.out.print("total");
        for (int dist = 0; dist < numDistances; dist++) {
            for (int kIdx = 0; kIdx < numKs; kIdx++) {
                System.out.print("\t" + totals[kIdx][dist].getMean() + "\t" + totals[kIdx][dist].getStdD());
            }
        }
        System.out.println();
    }

    /** Returns a sorted copy of the configured k values. */
    public final int[] getKs() {
        return this.ks.clone();
    }

    /**
     * Replaces the k values to evaluate. The array is copied and sorted; the caller's array is untouched.
     *
     * @throws NullPointerException if {@code iArr} is null
     * @throws IllegalArgumentException if {@code iArr} is empty
     */
    public final void setKs(int[] iArr) {
        this.ks = sortedCopy(iArr);
    }

    /**
     * Returns the per-setting summaries of the last {@link #repeatedRun} call, indexed
     * [k index][distance option] -> class number. Returns {@code null} unless that run used a
     * {@link ParallelizedMIDM}. The returned arrays are live internal state.
     */
    public final Map<Integer, SummaryItem>[][] getAccumulatedDistanceSummaries() {
        return this.accumulatedDistanceSummaries;
    }

    /**
     * Sets the sampling seed and replaces {@code Sampler.RANDOM} with a {@link Random} seeded with it.
     * This mutates shared static state without synchronization.
     *
     * @return the previous seed
     */
    public static long setSamplingSeed(long j) {
        long previous = SAMPLING_SEED;
        SAMPLING_SEED = j;
        Sampler.RANDOM = new Random(SAMPLING_SEED);
        return previous;
    }

    /** Returns a sorted copy of {@code values}, rejecting null and empty input. */
    private static int[] sortedCopy(int[] values) {
        if (values == null) {
            throw new NullPointerException("ks");
        }
        if (values.length == 0) {
            throw new IllegalArgumentException("at least one k value is required");
        }
        int[] copy = values.clone();
        Arrays.sort(copy);
        return copy;
    }

    /** Experiment driver: 10 repetitions of 5-fold cross-validation on a hard-coded SIFT dataset. */
    public static void main(String[] strArr) throws IOException {
        setSamplingSeed(11111982L);
        ParallelizedMIDM parallelizedMIDM = new ParallelizedMIDM();
        MMDTests.ARFF_FEATURE_OFFSET = 1;
        MMDTests.FILTER_4 = false;
        MultiDistanceSequDB multiDistanceSequDB = new MultiDistanceSequDB(MMDTests.loadMISeqDB("/nfs/infdbs/WissProj/Theseus/Data/Stock4B/stock4B_0.25_sift.arff", Integer.MAX_VALUE, 128, parallelizedMIDM));
        parallelizedMIDM.kernelRunsOnly();
        parallelizedMIDM.addKernel(new RBFKernel(16.798d, new SqEuclidianDistance()), multiDistanceSequDB);
        multiDistanceSequDB.addDistanceMeasures(parallelizedMIDM);
        new ParallelizedMIDMTests(multiDistanceSequDB, new KNNClassifier(multiDistanceSequDB, 1)).repeatedRun(5, 10);
    }
}
