package ir.Experiments.index;

import dm.algorithms.CVClassifier;
import dm.algorithms.PrecRecClassifier;
import dm.algorithms.WilcoxonClassification;
import dm.data.DataObject;
import dm.data.DistanceMeasure;
import dm.data.MIObjects.MIDistanceMeasure;
import dm.data.MIObjects.MultiInstanceObject;
import dm.data.MIObjects.PWilcoxMIDM;
import dm.data.MIObjects.WeightedSumMIDM;
import dm.data.MIObjects.WilcoxMIDM2;
import dm.data.database.Database;
import dm.data.database.SequDB;
import dm.data.featureVector.FeatureVector;
import dm.data.featureVector.SqEuclidianDistance;
import dm.util.math.WilcoxonTest;
import ir.utils.tools.Zeit;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;

/* loaded from: input_file:ir/Experiments/index/ClassificationTests.class */
/**
 * Experiment driver for distance-based classification of (multi-instance) databases.
 *
 * <p>{@link #runLOO(Database)} performs a leave-one-out evaluation. {@link #instanceRun()} and
 * {@link #main(String[])} are hard-wired experiment entry points that load fixed dataset paths.
 * Configuration is passed through the mutable public static fields below, so this class is not
 * safe to use from several threads at once.
 */
public class ClassificationTests {
    /**
     * Neighbour selection for the Wilcoxon branch of {@link #runLOO(Database)}.
     *
     * <ul>
     *   <li>Values {@code >= 1} keep at most {@code (int) K} smallest distances per side.
     *   <li>Values in {@code (0, 1)} keep the smallest {@code ceil(size * K)} distances.
     *   <li>Values {@code <= 0} keep all distances.
     * </ul>
     */
    public static double K = -1.0d;

    /** When non-null, used instead of the database's own distance measure in the Wilcoxon branch. */
    public static DistanceMeasure defaultDM = null;

    /**
     * When non-null, {@link #runLOO(Database)} predicts with {@code classifier.classify(...)} and
     * reports {@code classifier.getCertainty()} instead of running the Wilcoxon ranking.
     */
    public static CVClassifier classifier = null;

    /**
     * Runs a leave-one-out evaluation over {@code database}.
     *
     * <p>Each object in turn is treated as the query. Its class is predicted in one of two ways:
     *
     * <ul>
     *   <li>If {@link #classifier} is set, that classifier predicts it.
     *   <li>Otherwise, for every class label, {@code WilcoxonTest.wilcoxon} is computed between
     *       the query's distances to the other objects of that class and its distances to all
     *       other objects. The label with the smallest result wins.
     * </ul>
     *
     * <p>The distance vectors (in R {@code c(...)} syntax), the per-class values and every
     * prediction are printed to stdout.
     *
     * @param database the database to evaluate. As a side effect, each query's weight (and, for
     *     multi-instance objects, each instance's weight) is reset to 1.0. {@code WilcoxMIDM2} and
     *     {@code PWilcoxMIDM} distance measures get their weights re-assigned per query.
     * @param <T> the object type stored in the database
     * @return map from class number to the number of correctly predicted members of that class,
     *     divided by {@code database.getMemberCount(class)}
     */
    // Raw types are needed where the project API's exact generic signatures are not visible here.
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static <T extends DataObject> Map<Integer, Double> runLOO(Database<T> database) {
        DistanceMeasure distanceMeasure = database.getDistanceMeasure();

        // Sized by the labels actually present. Sizing by getNumClasses() could overflow the
        // array or leave zero-padded slots that get evaluated as a bogus class 0.
        int[] classLabels = distinctSortedClassLabels(database);
        assert classLabels.length == database.getNumClasses()
                : "found " + classLabels.length + " distinct class labels, database reports "
                        + database.getNumClasses();

        Map<Integer, Double> correctByClass = new HashMap<Integer, Double>();
        // Holds only the current query's primary key. The key type is not visible here.
        // Presumably WilcoxMIDM2 treats it as the held-out set (TODO confirm).
        HashSet heldOut = new HashSet();

        Iterator<T> queries = database.objectIterator();
        while (queries.hasNext()) {
            T query = queries.next();
            heldOut.clear();
            heldOut.add(query.getPrimaryKey());
            if (distanceMeasure instanceof WilcoxMIDM2) {
                ((WilcoxMIDM2) distanceMeasure).calculateWeights(database, heldOut);
                ((WilcoxMIDM2) distanceMeasure).assignObjectWeights(database);
            }
            // Deliberately after the WilcoxMIDM2 weight assignment, as in the original flow.
            resetWeights(query);

            int predicted;
            double score;
            if (classifier == null) {
                predicted = -1;
                score = Double.MAX_VALUE;
                for (int label : classLabels) {
                    if (distanceMeasure instanceof PWilcoxMIDM) {
                        // Re-applied before every class, as before. Hoisting it out of the loop
                        // is only safe if distance computations never modify object weights,
                        // which is not visible here.
                        ((PWilcoxMIDM) distanceMeasure).assignObjectWeight((MultiInstanceObject) query);
                    }
                    double value = wilcoxonScore(database, distanceMeasure, query, label);
                    if (value < score) {
                        score = value;
                        predicted = label;
                    }
                }
                assert predicted != -1 : "no class produced a Wilcoxon value below Double.MAX_VALUE";
            } else {
                predicted = classifier.classify(query);
                score = classifier.getCertainty();
            }

            int actual = query.getClassNr();
            System.out.println("pred: " + predicted + " for " + actual + ", certainty: " + score);
            Double hits = correctByClass.get(Integer.valueOf(actual));
            double previous = hits == null ? 0.0d : hits.doubleValue();
            correctByClass.put(Integer.valueOf(actual),
                    Double.valueOf(previous + (predicted == actual ? 1.0d : 0.0d)));
        }

        // Turn hit counts into per-class accuracies. Only values change, so iterating is safe.
        for (Map.Entry<Integer, Double> entry : correctByClass.entrySet()) {
            int classNr = entry.getKey().intValue();
            entry.setValue(Double.valueOf(entry.getValue().doubleValue() / database.getMemberCount(classNr)));
        }
        return correctByClass;
    }

    /** Returns the distinct class numbers of all objects in {@code database}, in ascending order. */
    private static <T extends DataObject> int[] distinctSortedClassLabels(Database<T> database) {
        HashSet<Integer> labels = new HashSet<Integer>();
        Iterator<T> objects = database.objectIterator();
        while (objects.hasNext()) {
            labels.add(Integer.valueOf(objects.next().getClassNr()));
        }
        Integer[] sorted = labels.toArray(new Integer[labels.size()]);
        Arrays.sort(sorted);
        int[] result = new int[sorted.length];
        for (int i = 0; i < sorted.length; i++) {
            result[i] = sorted[i].intValue();
        }
        return result;
    }

    /**
     * Sets the weight of {@code object} to 1.0. If it is a multi-instance object, each of its
     * instances is also set to 1.0.
     */
    private static void resetWeights(DataObject object) {
        if (object instanceof MultiInstanceObject) {
            Iterator<?> instances = ((MultiInstanceObject<?>) object).instances().iterator();
            while (instances.hasNext()) {
                ((DataObject) instances.next()).setWeight(1.0d);
            }
        }
        object.setWeight(1.0d);
    }

    /**
     * Computes {@code WilcoxonTest.wilcoxon} for one candidate class.
     *
     * <p>The test compares two lists of distances from {@code query}:
     *
     * <ul>
     *   <li>to every other object labelled {@code label};
     *   <li>to every other object with a different label.
     * </ul>
     *
     * <p>Both lists are truncated to the nearest neighbours according to {@link #K} when
     * {@code K > 0}. The two vectors and the result are printed to stdout.
     *
     * @return the value returned by {@code WilcoxonTest.wilcoxon} (printed as a p-value)
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static <T extends DataObject> double wilcoxonScore(
            Database<T> database, DistanceMeasure distanceMeasure, T query, int label) {
        DistanceMeasure measure = defaultDM != null ? defaultDM : distanceMeasure;
        ArrayList<Double> sameClass = new ArrayList<Double>();
        ArrayList<Double> otherClass = new ArrayList<Double>();

        Iterator<T> candidates = database.objectIterator();
        while (candidates.hasNext()) {
            T candidate = candidates.next();
            if (candidate.getPrimaryKey().equals(query.getPrimaryKey())) {
                continue; // leave-one-out: never compare the query with itself
            }
            double distance = measure.distance(query, candidate);
            if (candidate.getClassNr() == label) {
                sameClass.add(Double.valueOf(distance));
            } else {
                otherClass.add(Double.valueOf(distance));
            }
        }

        if (K > 0.0d) {
            sameClass = nearest(sameClass);
            otherClass = nearest(otherClass);
        }

        double value = WilcoxonTest.wilcoxon(sameClass, otherClass, -1);
        // Printed after the test, exactly as before, in case the test reorders its inputs.
        System.out.print("p <- c(" + joinValues(sameClass) + "); n<-c(" + joinValues(otherClass) + ")\n");
        System.out.println("class " + label + ", pval: " + value);
        return value;
    }

    /**
     * Sorts {@code distances} in place. Returns a new list holding its first
     * {@link #neighbourCount(int)} elements.
     */
    private static ArrayList<Double> nearest(ArrayList<Double> distances) {
        Collections.sort(distances);
        return new ArrayList<Double>(distances.subList(0, neighbourCount(distances.size())));
    }

    /**
     * Returns how many of {@code available} sorted distances to keep for the current {@link #K}.
     * Only meaningful when {@code K > 0}.
     */
    private static int neighbourCount(int available) {
        if (K >= 1.0d) {
            return K > available ? available : (int) K;
        }
        // 0 < K < 1: keep a fraction, rounded up (never exceeds available).
        return (int) Math.ceil(available * K);
    }

    /** Formats {@code values} as a comma-separated list, each via {@code String.valueOf}. */
    private static String joinValues(Iterable<Double> values) {
        StringBuilder joined = new StringBuilder();
        for (Double value : values) {
            if (joined.length() > 0) {
                joined.append(',');
            }
            joined.append(value);
        }
        return joined.toString();
    }

    /**
     * Prints the per-class accuracies, their unweighted mean, and the time elapsed since
     * {@code start}.
     */
    private static void printAccuracySummary(Map<Integer, Double> accuracyByClass, Date start) {
        double sum = 0.0d;
        System.out.println("class:\taccuracy");
        for (Map.Entry<Integer, Double> entry : accuracyByClass.entrySet()) {
            System.out.println(entry.getKey() + ":\t" + entry.getValue());
            sum += entry.getValue().doubleValue();
        }
        System.out.println("mean accuracy: " + (sum / accuracyByClass.size()));
        System.out.println("took " + Zeit.wieLange(start));
    }

    /**
     * Leave-one-out experiment on a fixed Caltech ARFF file, using a {@link PrecRecClassifier}.
     *
     * <p>NOTE(review): the classifier is constructed on the full resolved database. Confirm
     * whether PrecRecClassifier excludes the held-out object; otherwise the LOO estimate may be
     * optimistic.
     *
     * @throws IOException if the dataset cannot be read
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static void instanceRun() throws IOException {
        SequDB<MultiInstanceObject<FeatureVector>> loadMISeqDB = MMDTests.loadMISeqDB("P:/nfs/infdbs/WissProj/Theseus/Data/Caltech_Benchmark/Arff/caltech_5_of_each_class.arff", 12300, 166, new WilcoxMIDM2(new SqEuclidianDistance()));
        K = 3.0d;
        SequDB resolveDB = MultiInstanceObject.resolveDB(loadMISeqDB);
        classifier = new PrecRecClassifier(resolveDB);
        Date start = new Date();
        Map<Integer, Double> accuracies = runLOO(resolveDB);
        printAccuracySummary(accuracies, start);
    }

    /**
     * 5-fold cross-validation on the MUSK1 multi-instance dataset.
     *
     * <p>Uses {@link WilcoxonClassification} with a weighted-sum distance and a verbose
     * {@link PrecRecClassifier}.
     *
     * @param strArr ignored
     * @throws IOException if the dataset cannot be read
     */
    public static void main(String[] strArr) throws IOException {
        WilcoxMIDM2 wilcoxMIDM2 = new WilcoxMIDM2(new SqEuclidianDistance());
        MMDTests.ARFF_FEATURE_OFFSET = 1;
        MMDTests.ARFF_FEATURE_SEP = ",";
        SequDB<MultiInstanceObject<FeatureVector>> loadMISeqDB = MMDTests.loadMISeqDB("P:/nfs/infdbs/WissProj/Theseus/Data/musk1_mi.arff", Integer.MAX_VALUE, 166, wilcoxMIDM2);
        // Thresholds are set after loading, in the same order as before.
        wilcoxMIDM2.threshold = 0.1d;
        wilcoxMIDM2.upperThreshold = true;
        wilcoxMIDM2.useSMD = false;
        defaultDM = new WeightedSumMIDM(new SqEuclidianDistance());
        ((WeightedSumMIDM) defaultDM).setMAXIMIZE_WEIGHTS(false);
        ((WeightedSumMIDM) defaultDM).setOBJECT_SPECIFIC_THRESHOLD(true);
        ((WeightedSumMIDM) defaultDM).setThreshold(0.1d);
        WilcoxonClassification.defaultDM = (MIDistanceMeasure) defaultDM;
        WilcoxonClassification.K = 3.0d;
        WilcoxonClassification.classifier = new PrecRecClassifier(loadMISeqDB);
        PrecRecClassifier precRec = (PrecRecClassifier) WilcoxonClassification.classifier;
        precRec.VERBOSE = true;
        precRec.setDecisionMethod(1);
        precRec.setRecallBins(new double[]{0.1d, 0.2d, 0.3d, 0.4d, 0.5d, 1.0d});
        precRec.setRecallSelection(0.1d);
        Date start = new Date();
        Map<Integer, Double> accuracies = WilcoxonClassification.runCVWithWilcoxon(loadMISeqDB, 5);
        printAccuracySummary(accuracies, start);
    }
}
