package de.lmu.ifi.dbs.dm.algorithms;

import de.lmu.ifi.dbs.dm.DistanceMeasure;
import de.lmu.ifi.dbs.dm.data.DataObject;
import de.lmu.ifi.dbs.dm.data.MultiInstanceObject;
import de.lmu.ifi.dbs.dm.database.Database;
import de.lmu.ifi.dbs.utilities.math.Sampler;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:de/lmu/ifi/dbs/dm/algorithms/KNNClassifier.class */
/**
 * k-nearest-neighbour classifier backed by a {@link Database}.
 *
 * <p>A query is assigned the class label that occurs most often among its {@code k} nearest
 * neighbours in the database. Unless {@link #ignoreClassDistribution()} has been called, the
 * per-class neighbour counts are first divided by the number of training objects of that class,
 * so that large classes do not win merely by being large.
 *
 * <p>After a {@code classify} call, {@link #getCertainty()} reports how clearly the winner beat
 * the runner-up (ratio of the best to the second-best score).
 *
 * <p>Instances are not safe for concurrent use: classification mutates {@code certainty},
 * {@code classAccess}, the database cursor and distance measure, and (temporarily) {@code k}.
 */
public class KNNClassifier<T extends DataObject> implements CVClassifier<T> {
    private static final Logger log = Logger.getLogger(KNNClassifier.class.getName());

    /** Training database; {@code null} when created through the protected constructor. */
    protected Database<T> db;
    /** Number of neighbours consulted per query; always positive. */
    protected int k;
    /** Optional distance measure temporarily installed on {@link #db} for each query. */
    protected DistanceMeasure<T> dm = null;
    /** Certainty of the most recent classification; -1 before the first one. */
    protected double certainty = -1.0d;
    /** Whether neighbour counts are divided by the size of their class. */
    protected boolean adjustClassDistribution = true;
    /** Maps a class label to its index in {@link #classes}. */
    protected Map<Integer, Integer> class2Index = new HashMap<Integer, Integer>();
    /** Number of training objects per class, indexed like {@link #classes}. */
    protected int[] numObjectsPerClass;
    /** Class labels in ascending order. */
    protected int[] classes;
    /** Order in which class indices are inspected when picking a winner (ties go to the first). */
    protected int[] classAccess;
    /** Certainty reported for a winner when no other class received a vote, per class index. */
    protected double[] maxCertainties;

    /**
     * Creates a classifier without a database; subclasses are expected to set up the remaining
     * state themselves.
     *
     * @param k number of neighbours to consult
     * @throws IllegalArgumentException if {@code k <= 0}
     */
    protected KNNClassifier(int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("k must be > 0");
        }
        this.k = k;
        this.db = null;
    }

    /**
     * Creates a classifier over the given training database.
     *
     * @param d training database; its objects' class labels define the set of classes
     * @param k number of neighbours to consult
     * @throws IllegalArgumentException if {@code k <= 0}
     */
    public <D extends Database<T>> KNNClassifier(D d, int k) {
        this(k);
        this.db = d;

        int numClasses = d.getNumClasses();
        this.numObjectsPerClass = new int[numClasses];
        this.classes = new int[numClasses];

        Integer[] labels = collectSortedLabels(d);
        assert labels.length == numClasses
                : "found " + labels.length + " labels but database reports " + numClasses;
        for (int c = 0; c < labels.length; c++) {
            classes[c] = labels[c].intValue();
        }
        for (int c = 0; c < classes.length; c++) {
            class2Index.put(Integer.valueOf(classes[c]), Integer.valueOf(c));
            numObjectsPerClass[c] = d.getMemberCount(classes[c]);
        }
        assert class2Index.size() == numClasses;

        // Start from the identity permutation so every class index is present before
        // Sampler.sample reorders it.
        // NOTE(review): assumes Sampler.sample(a, a) permutes/samples the entries of a — confirm.
        this.classAccess = new int[classes.length];
        for (int c = 0; c < classAccess.length; c++) {
            classAccess[c] = c;
        }

        this.maxCertainties = computeMaxCertainties(k, numObjectsPerClass);
    }

    /** Returns the distinct class labels of all objects in {@code database}, sorted ascending. */
    private Integer[] collectSortedLabels(Database<T> database) {
        Set<Integer> labels = new HashSet<Integer>();
        for (Iterator<? extends T> it = database.objectIterator(); it.hasNext(); ) {
            labels.add(Integer.valueOf(it.next().getClassNr()));
        }
        Integer[] sorted = labels.toArray(new Integer[labels.size()]);
        Arrays.sort(sorted);
        return sorted;
    }

    /**
     * Computes, per class, the certainty used when the winner is the only class with votes.
     *
     * <p>With size-normalised counts the winner's score is at most {@code k / n_c}, and the
     * smallest non-zero runner-up score is {@code 1 / n_o}, where {@code n_o} is the size of the
     * largest other class. The value is their ratio {@code k * n_o / n_c}. A database with a
     * single class has no runner-up at all, which yields positive infinity.
     */
    private static double[] computeMaxCertainties(int k, int[] classSizes) {
        double largest = Integer.MIN_VALUE;
        double secondLargest = Integer.MIN_VALUE;
        for (int size : classSizes) {
            if (size > largest) {
                secondLargest = largest;
                largest = size;
            } else if (size >= secondLargest) {
                secondLargest = size;
            }
        }
        double[] result = new double[classSizes.length];
        for (int c = 0; c < classSizes.length; c++) {
            double largestOther = classSizes[c] == largest ? secondLargest : largest;
            if (largestOther == Integer.MIN_VALUE) {
                result[c] = Double.POSITIVE_INFINITY;
            } else {
                // Floating-point division: integer k / size would truncate to 0.
                result[c] = (double) k / classSizes[c] * largestOther;
            }
        }
        return result;
    }

    /** Disables the division of neighbour counts by class size. */
    public void ignoreClassDistribution() {
        this.adjustClassDistribution = false;
    }

    /**
     * Classifies {@code t} using its k nearest neighbours in the whole database.
     *
     * @return the winning class label, or -1 if no neighbour was found
     */
    @Override
    public int classify(T t) {
        double[] distribution = getDistribution(t);
        adjustForClassSizes(distribution);
        Ranking ranking = rank(distribution);
        this.certainty = singleCertainty(ranking);
        return ranking.label;
    }

    /**
     * Classifies {@code t} using its k nearest neighbours, skipping objects whose primary key is
     * in {@code set} (e.g. the current cross-validation test fold).
     *
     * @return the winning class label
     * @throws IllegalArgumentException if fewer than k non-excluded objects exist
     */
    @Override
    public int classify(T t, Set<String> set) {
        double[] distribution = getDistribution(t, set);
        adjustForClassSizes(distribution);
        Ranking ranking = rank(distribution);
        this.certainty = singleCertainty(ranking);
        log.fine("b: " + ranking.best + ", sb: " + ranking.second);
        return ranking.label;
    }

    /**
     * Classifies a multi-instance object. Each class is tried in turn by weighting the
     * object's instances with the corresponding column of {@code dArr} and running a
     * single-object classification. The class-weighting whose classification is most certain
     * wins.
     *
     * <p>If the two most certain weightings tie, {@code k} is grown and the whole procedure is
     * repeated. This stops once the tie is broken or {@code k} reaches the number of
     * non-excluded objects. {@code k} is restored before returning, also on failure.
     *
     * @param multiInstanceObject object whose instance weights are overwritten during the call
     * @param dArr weights per instance (rows) and class (columns)
     * @param set primary keys of database objects to ignore
     * @return the class label predicted under the most certain weighting
     */
    public <S extends DataObject> int classify(MultiInstanceObject<S> multiInstanceObject, double[][] dArr, Set<String> set) {
        assert dArr.length == 0 || dArr[0].length == classes.length;
        final int originalK = this.k;
        final int maxK = db.getCount() - set.size();
        int winner = -1;
        try {
            while (true) {
                log.fine("k=" + this.k);
                double best = 0.0d;
                double secondBest = 0.0d;
                for (int c = 0; c < classes.length; c++) {
                    log.fine("class " + c);
                    for (int inst = 0; inst < dArr.length; inst++) {
                        multiInstanceObject.instances().get(inst).setWeight(dArr[inst][c]);
                    }
                    @SuppressWarnings("unchecked")
                    double[] distribution = getDistribution((T) multiInstanceObject, set);
                    adjustForClassSizes(distribution);
                    Ranking ranking = rank(distribution);

                    double classCertainty;
                    if (this.k == 1) {
                        classCertainty = 1.0d;
                    } else if (ranking.second == 0.0d) {
                        classCertainty = Double.POSITIVE_INFINITY;
                    } else {
                        classCertainty = ranking.best / ranking.second;
                    }

                    // Track the two highest certainties across weightings.
                    if (classCertainty > best) {
                        secondBest = best;
                        best = classCertainty;
                        winner = ranking.label;
                    } else if (classCertainty >= secondBest) {
                        secondBest = classCertainty;
                    }
                }
                this.certainty = best;
                if (best != secondBest || this.k >= maxK) {
                    break;
                }
                this.k = grownK(this.k, maxK);
            }
        } finally {
            if (this.k != originalK) {
                log.info("had to reset k from " + this.k);
                this.k = originalK;
            }
        }
        return winner;
    }

    /** Returns the next k to try after a tie: k squared, at least k + 1, at most {@code maxK}. */
    private static int grownK(int currentK, int maxK) {
        // k * k alone never grows when k == 1, which would loop forever.
        long next = Math.max((long) currentK * currentK, currentK + 1L);
        return (int) Math.min(next, (long) maxK);
    }

    /** Divides each non-zero count by its class size, in place, if adjustment is enabled. */
    private void adjustForClassSizes(double[] distribution) {
        if (!adjustClassDistribution) {
            return;
        }
        boolean trace = log.isLoggable(Level.FINE);
        for (int c = 0; c < distribution.length; c++) {
            if (distribution[c] != 0.0d) {
                distribution[c] /= numObjectsPerClass[c];
            }
            if (trace) {
                log.fine(String.format(Locale.ENGLISH, " %d: %.4f", Integer.valueOf(classes[c]), Double.valueOf(distribution[c])));
            }
        }
    }

    /** Winner and the two highest scores of one distribution. */
    private static final class Ranking {
        int label = -1;
        double best = 0.0d;
        double second = 0.0d;
    }

    /**
     * Finds the highest and second-highest score. Classes are visited in the order given by
     * {@link #classAccess} after it has been passed through {@link Sampler#sample}, so among
     * equal top scores the first one visited wins.
     */
    private Ranking rank(double[] scores) {
        Sampler.sample(this.classAccess, this.classAccess);
        if (log.isLoggable(Level.FINE)) {
            log.fine(" ca: " + Arrays.toString(this.classAccess));
        }
        Ranking ranking = new Ranking();
        for (int n = 0; n < scores.length; n++) {
            int c = this.classAccess[n];
            double score = scores[c];
            if (score > ranking.best) {
                ranking.second = ranking.best;
                ranking.best = score;
                ranking.label = classes[c];
            } else if (score >= ranking.second) {
                ranking.second = score;
            }
        }
        return ranking;
    }

    /** Certainty of a single-object classification described by {@code ranking}. */
    private double singleCertainty(Ranking ranking) {
        if (this.k == 1) {
            return 1.0d;
        }
        if (ranking.label == -1) {
            return 0.0d; // no neighbour voted at all
        }
        if (ranking.second == 0.0d) {
            return maxCertainties[class2Index.get(Integer.valueOf(ranking.label)).intValue()];
        }
        return ranking.best / ranking.second;
    }

    /**
     * Counts, per class index, the labels of the k nearest neighbours of {@code t} in the
     * database (raw counts, not size-adjusted).
     */
    @Override
    public double[] getDistribution(T t) {
        double[] counts = new double[db.getNumClasses()];
        DistanceMeasure<T> previous = installDistanceMeasure();
        List<T> neighbours;
        try {
            neighbours = db.kNNQuery(t, this.k);
        } finally {
            restoreDistanceMeasure(previous);
        }
        for (T neighbour : neighbours) {
            counts[class2Index.get(Integer.valueOf(neighbour.getClassNr())).intValue()] += 1.0d;
        }
        return counts;
    }

    /**
     * Like {@link #getDistribution(DataObject, Set)}, but keeps consuming neighbours beyond the
     * k-th for as long as every neighbour seen so far has the same class label.
     *
     * @throws IllegalArgumentException if fewer than k non-excluded objects exist
     */
    public double[] getDistributionAsLongAsPossible(T t, Set<String> set) {
        return countNeighbours(t, set, true);
    }

    /**
     * Counts, per class index, the labels of the k nearest neighbours of {@code t}, skipping
     * objects whose primary key is in {@code set}.
     *
     * @throws IllegalArgumentException if fewer than k non-excluded objects exist
     */
    @Override
    public double[] getDistribution(T t, Set<String> set) {
        return countNeighbours(t, set, false);
    }

    /**
     * Walks the database in order of increasing distance to {@code t} and counts class labels
     * of non-excluded objects until k have been counted. If {@code extendWhileUnanimous} is set,
     * it continues while all counted objects share a single label.
     */
    private double[] countNeighbours(T t, Set<String> excluded, boolean extendWhileUnanimous) {
        double[] counts = new double[db.getNumClasses()];
        db.reset();
        int found = 0;
        boolean unanimous = extendWhileUnanimous;
        boolean havePrevious = false;
        int previousLabel = 0;
        DistanceMeasure<T> previous = installDistanceMeasure();
        try {
            T next;
            while ((found < this.k || unanimous)
                    && (next = db.getNext(t, Double.MAX_VALUE, null)) != null) {
                if (excluded.contains(next.getPrimaryKey())) {
                    continue;
                }
                int label = next.getClassNr();
                counts[class2Index.get(Integer.valueOf(label)).intValue()] += 1.0d;
                found++;
                if (havePrevious && label != previousLabel) {
                    unanimous = false;
                }
                previousLabel = label;
                havePrevious = true;
            }
        } finally {
            restoreDistanceMeasure(previous);
        }
        if (found < this.k) {
            throw new IllegalArgumentException("Only " + found + " objects included in DB of size " + db.getCount() + " (k=" + this.k + ")");
        }
        return counts;
    }

    /**
     * Installs {@link #dm} on the database if one is configured.
     *
     * @return the measure that was active before, or {@code null} if nothing was swapped
     */
    private DistanceMeasure<T> installDistanceMeasure() {
        if (this.dm == null) {
            return null;
        }
        DistanceMeasure<T> previous = db.getDistanceMeasure();
        db.setDistanceMeasure(this.dm);
        return previous;
    }

    /** Puts back a measure returned by {@link #installDistanceMeasure()}, if any. */
    private void restoreDistanceMeasure(DistanceMeasure<T> previous) {
        if (previous != null) {
            db.setDistanceMeasure(previous);
        }
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    @Override
    public void saveClassifier(String str) {
        throw new UnsupportedOperationException();
    }

    /** Returns the certainty computed by the most recent {@code classify} call. */
    @Override
    public double getCertainty() {
        return this.certainty;
    }

    /** Returns the number of neighbours consulted per query. */
    public final int getK() {
        return this.k;
    }

    /**
     * Sets the number of neighbours consulted per query.
     *
     * @throws IllegalArgumentException if {@code k <= 0}, matching the constructors
     */
    public final void setK(int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("k must be > 0");
        }
        this.k = k;
    }

    /** Sets a distance measure to use for queries instead of the database's own. */
    @Override
    public void setDistanceMeasure(DistanceMeasure<T> distanceMeasure) {
        this.dm = distanceMeasure;
    }

    /** Returns the sorted class labels (the internal array, not a copy). */
    @Override
    public final int[] getClasses() {
        return this.classes;
    }

    /** Returns the label-to-index map (the internal map, not a copy). */
    @Override
    public Map<Integer, Integer> getClass2Index() {
        return this.class2Index;
    }
}
