package dm.algorithms;

import dm.data.DataObject;
import dm.data.DistanceMeasure;
import dm.data.MIObjects.ConvolutionDist;
import dm.data.MIObjects.MultiInstanceObject;
import dm.data.MIObjects.PWilcoxMIDM;
import dm.data.database.Database;
import dm.util.PriorityQueue;
import dm.util.math.Sampler;
import ir.utils.DataConverter;
import ir.utils.RankingObject;
import ir.utils.UpdatablePriorityQueue;
import ir.utils.statistics.StatisticalQueryResult;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:dm/algorithms/PrecRecClassifier.class */
/**
 * Classifier that scores every class of a {@link Database} for a query object and returns the
 * class with the highest score.
 *
 * <p>For each class the database objects are ranked by their distance to the query object, the
 * ranking is turned into a {@link StatisticalQueryResult}, and one decision value is read from
 * that result: either its {@code meanAveragePrecision} ({@link #MAP}) or the precision stored in
 * one of its {@code pRFixBins} ({@link #FIRST_X_RECALL}).
 *
 * <p>Instances keep mutable per-call state (the certainty of the last classification and an
 * internal rescue flag), so a single instance should not be shared between threads without
 * external synchronisation.
 *
 * @param <T> type of the objects stored in the database
 */
public class PrecRecClassifier<T extends DataObject> implements CVClassifier<T> {
    /** Decision method: use {@code StatisticalQueryResult.meanAveragePrecision}. */
    public static final int MAP = 0;
    /** Decision method: use the precision of the recall bin selected via the recall selection. */
    public static final int FIRST_X_RECALL = 1;

    private final Database<T> db;
    /** Member count per class, indexed like {@link #classes}. */
    private final int[] numObjectsPerClass;
    /** Distinct class numbers found in the database, in ascending order. */
    private final int[] classes;
    /**
     * Index array handed to {@code Sampler.sample} before each arg-max pass.
     * NOTE(review): it is never filled with indices in this class; classify() relies entirely on
     * Sampler.sample to populate it - confirm against the Sampler implementation.
     */
    private final int[] classAccess;
    /** Ratio best/second-best decision value of the most recent {@link #classify} call. */
    private double certainty;

    /** Optional override of the database's own distance measure; {@code null} means "use the database's". */
    private DistanceMeasure<T> distanceMeasure = null;
    /** Maps a class number to its position in {@link #classes}. */
    private final Map<Integer, Integer> class2Index = new HashMap<Integer, Integer>();
    /** Optional instance weights, indexed {@code [instance][class]}; {@code null} disables weighting. */
    private double[][] weights = null;
    private double[] recallBins = {0.25d, 0.5d, 0.75d, 1.0d};
    private int decisionMethod = MAP;
    private double recallSelection = 0.25d;
    /** Index into {@link #recallBins} that corresponds to {@link #recallSelection}. */
    private int firstXPosition = 0;
    /** When {@code true}, the per-class query statistics are printed to standard output. */
    public boolean VERBOSE = false;
    /** Set only while classify() retries with substitute decision values. */
    private boolean getSubstituteDVs = false;

    /**
     * Collects the distinct class numbers of {@code database} and their member counts.
     *
     * @param database the database whose objects serve as the reference set
     */
    public PrecRecClassifier(Database<T> database) {
        this.db = database;
        Iterator<T> objects = database.objectIterator();
        this.numObjectsPerClass = new int[database.getNumClasses()];
        this.classes = new int[this.numObjectsPerClass.length];

        Set<Integer> distinctClasses = new HashSet<Integer>();
        while (objects.hasNext()) {
            distinctClasses.add(Integer.valueOf(objects.next().getClassNr()));
        }
        Integer[] sortedClasses = distinctClasses.toArray(new Integer[distinctClasses.size()]);
        Arrays.sort(sortedClasses);

        // Checked before copying so that a mismatch is reported as such instead of surfacing as
        // an index error inside the copy loop.
        assert sortedClasses.length == database.getNumClasses();
        for (int i = 0; i < sortedClasses.length; i++) {
            this.classes[i] = sortedClasses[i].intValue();
        }

        for (int i = 0; i < this.classes.length; i++) {
            this.class2Index.put(Integer.valueOf(this.classes[i]), Integer.valueOf(i));
            this.numObjectsPerClass[i] = database.getMemberCount(this.classes[i]);
        }
        this.classAccess = new int[this.classes.length];
        assert this.class2Index.size() == database.getNumClasses();

        // Lists every class instead of hard-coding indices 0 and 1, which failed for databases
        // with fewer than two classes; the output is unchanged for exactly two classes.
        StringBuilder message = new StringBuilder("from Constructor: classes: {");
        for (int i = 0; i < this.classes.length; i++) {
            if (i > 0) {
                message.append(", ");
            }
            message.append(this.classes[i]);
        }
        message.append('}');
        System.out.println(message);
    }

    /**
     * Extracts the decision value for the configured decision method from a query result.
     *
     * @param statisticalQueryResult statistics of one class-specific ranking
     * @return the decision value for that class
     * @throws IllegalArgumentException if the decision method is unknown, or if the selected
     *         recall value has no exactly matching recall bin
     */
    private double getDecisionValue(StatisticalQueryResult statisticalQueryResult) {
        if (this.decisionMethod == MAP) {
            return statisticalQueryResult.meanAveragePrecision;
        }
        if (this.decisionMethod != FIRST_X_RECALL) {
            throw new IllegalArgumentException("Option \"" + this.decisionMethod + "\" not yet known.");
        }
        if (this.getSubstituteDVs) {
            return getSubstituteDecisionValue(statisticalQueryResult);
        }
        if (this.recallBins[this.firstXPosition] != this.recallSelection) {
            throw new IllegalArgumentException("cannot retrieve the precision for " + this.recallSelection + " recall if corresponding bin is missing");
        }
        return statisticalQueryResult.pRFixBins.get(this.firstXPosition).precision;
    }

    /**
     * Returns the precision of the first bin at or after the selected position whose precision is
     * positive; prints a warning and returns {@code 0} if there is none.
     */
    private double getSubstituteDecisionValue(StatisticalQueryResult statisticalQueryResult) {
        for (int bin = this.firstXPosition; bin < statisticalQueryResult.pRFixBins.size(); bin++) {
            double precision = statisticalQueryResult.pRFixBins.get(bin).precision;
            if (precision > 0.0d) {
                return precision;
            }
        }
        System.err.println("Warning: no bin with precision > 0 starting from " + this.firstXPosition);
        return 0.0d;
    }

    /**
     * Classifies {@code t} against all database objects whose primary key is not in {@code set}.
     *
     * <p>Without weights the per-class decision values of {@link #getDistribution(DataObject, Set)}
     * are used directly. With weights, one distribution is computed per class after applying that
     * class's instance weights to {@code t}; a class only keeps its score if it wins its own
     * distribution, and the index-wise sum of all distributions serves as fallback when no class
     * does. If still no class has a positive score and the decision method is
     * {@link #FIRST_X_RECALL}, the classification is retried exactly once with substitute decision
     * values.
     *
     * <p>As a side effect the certainty returned by {@link #getCertainty()} is updated.
     *
     * @param t   the object to classify
     * @param set primary keys to exclude from the reference set, or {@code null} for none
     * @return the winning class number, or {@code -1} if no class obtained a positive score and
     *         assertions are disabled
     * @throws IllegalArgumentException if weights are set and {@code t} is not a
     *         {@link MultiInstanceObject}
     */
    @Override // dm.algorithms.CVClassifier
    public int classify(T t, Set<String> set) {
        double[] scores;
        double[] summedScores = null;
        if (this.weights == null) {
            scores = getDistribution(t, set);
        } else {
            scores = new double[this.classes.length];
            summedScores = new double[this.classes.length];
            for (int c = 0; c < this.classes.length; c++) {
                System.out.println("doing class " + this.classes[c]);
                if (!(t instanceof MultiInstanceObject)) {
                    throw new IllegalArgumentException("Can currently only calculate weighted classification on MultiInstanceObjects");
                }
                MultiInstanceObject multiInstance = (MultiInstanceObject) t;
                System.out.print("weights: ");
                for (int w = 0; w < this.weights.length; w++) {
                    multiInstance.instances().get(w).setWeight(this.weights[w][c]);
                    System.out.print(String.format(Locale.ENGLISH, " %.4f", Double.valueOf(multiInstance.instances().get(w).getWeight())));
                }
                if ((this.db.getDistanceMeasure() instanceof PWilcoxMIDM) && !(this.distanceMeasure instanceof ConvolutionDist)) {
                    ((PWilcoxMIDM) this.db.getDistanceMeasure()).assignObjectWeight(multiInstance);
                }
                if (this.distanceMeasure instanceof ConvolutionDist) {
                    ConvolutionDist.calcNormalization(this.db, set, ((ConvolutionDist) this.distanceMeasure).getKernel());
                }

                double[] distribution = getDistribution(t, set);
                double bestForWeights = 0.0d;
                int bestClassForWeights = -1;
                Sampler.sample(this.classAccess, this.classAccess);
                for (int k = 0; k < distribution.length; k++) {
                    int candidate = this.classAccess[k];
                    if (distribution[candidate] > bestForWeights) {
                        bestForWeights = distribution[candidate];
                        bestClassForWeights = this.classes[candidate];
                    }
                    summedScores[k] += distribution[k];
                }
                if (bestClassForWeights != this.classes[c]) {
                    System.out.println("ERR: no match for class " + this.classes[c] + " MAP=" + distribution[c] + "; class " + bestClassForWeights + " was better with MAP=" + bestForWeights);
                } else {
                    scores[c] = bestForWeights;
                }
            }
        }

        int winner = -1;
        double best = 0.0d;
        double secondBest = 0.0d;
        Sampler.sample(this.classAccess, this.classAccess);
        for (int k = 0; k < scores.length; k++) {
            System.out.print(" d" + scores[k]);
            double score = scores[this.classAccess[k]];
            if (score > best) {
                secondBest = best;
                best = score;
                winner = this.classes[this.classAccess[k]];
            } else if (score >= secondBest) {
                secondBest = score;
            }
        }

        // The summed scores only exist in the weighted case; without this null check the
        // unweighted case failed with a NullPointerException whenever no class scored above zero.
        if (winner == -1 && summedScores != null) {
            best = 0.0d;
            secondBest = 0.0d;
            for (int k = 0; k < summedScores.length; k++) {
                System.out.print(" d" + summedScores[k]);
                double score = summedScores[this.classAccess[k]];
                if (score > best) {
                    secondBest = best;
                    best = score;
                    winner = this.classes[this.classAccess[k]];
                } else if (score >= secondBest) {
                    secondBest = score;
                }
            }
        }

        // Retry once with substitute decision values. The flag check prevents the retry from
        // recursing again when it also finds no winner, and the finally block guarantees the flag
        // is cleared even if the nested call throws.
        if (winner == -1 && this.decisionMethod == FIRST_X_RECALL && best == 0.0d && !this.getSubstituteDVs) {
            this.getSubstituteDVs = true;
            int rescued;
            try {
                rescued = classify(t, set);
            } finally {
                this.getSubstituteDVs = false;
            }
            System.out.println("RESCUE CALL");
            return rescued;
        }

        assert winner != -1 : "no class obtained a positive decision value";

        if (secondBest == 0.0d) {
            this.certainty = (best == 0.0d) ? 0.0d : Double.POSITIVE_INFINITY;
        } else {
            this.certainty = best / secondBest;
        }
        System.out.println("b: " + best + ", sb: " + secondBest + "wc: " + winner + "ct: " + this.certainty);
        return winner;
    }

    /**
     * @return best decision value divided by the second best one from the most recent
     *         {@link #classify} call; {@code 0} if both were zero and positive infinity if only
     *         the second best was zero
     */
    @Override // dm.algorithms.CVClassifier
    public double getCertainty() {
        return this.certainty;
    }

    /**
     * Variant of {@link #getDistribution(DataObject, Set)} that ranks the database objects through
     * an {@link UpdatablePriorityQueue} of {@link RankingObject}s and derives the ranks with the
     * query object's primary key passed to {@code DataConverter.getRanks}.
     *
     * @param t   the query object
     * @param set primary keys to exclude from the reference set, or {@code null} for none
     * @return one decision value per class, indexed in ascending class-number order
     */
    public double[] getDistributionWITHOUTTIES(T t, Set<String> set) {
        double[] decisionValues = new double[this.classes.length];
        UpdatablePriorityQueue ranking = new UpdatablePriorityQueue(true);
        DistanceMeasure<T> measure = resolveDistanceMeasure();

        Iterator<T> objects = this.db.objectIterator();
        while (objects.hasNext()) {
            T candidate = objects.next();
            if (set == null || !set.contains(candidate.getPrimaryKey())) {
                ranking.insertIfBetter(new RankingObject(candidate.getPrimaryKey(), measure.distance(t, candidate), candidate.getClassNr()));
            }
        }

        for (int c = 0; c < this.classes.length; c++) {
            // Each class works on its own duplicate of the ranking.
            int[] ranks = DataConverter.getRanks(ranking.duplicate(), t.getPrimaryKey(), this.classes[c]);
            StatisticalQueryResult result = new StatisticalQueryResult(t.getPrimaryKey(), this.recallBins).calcHitArray(ranks);
            decisionValues[c] = getDecisionValue(result);
            if (this.VERBOSE) {
                System.out.print("\t for class " + this.classes[c] + ": " + result.toString() + "\n\tha:");
                for (int rank : ranks) {
                    System.out.print(" " + rank);
                }
                System.out.println("\n" + decisionValues[c]);
            }
        }
        return decisionValues;
    }

    /**
     * Computes one decision value per class for {@code t}.
     *
     * <p>All database objects whose primary key is not contained in {@code set} are queued by
     * their distance to {@code t}; for every class the ranks obtained from a copy of that queue
     * are converted into a {@link StatisticalQueryResult} from which the decision value is read.
     *
     * @param t   the query object
     * @param set primary keys to exclude from the reference set, or {@code null} for none
     * @return one decision value per class, indexed in ascending class-number order
     */
    @Override // dm.algorithms.CVClassifier
    public double[] getDistribution(T t, Set<String> set) {
        double[] decisionValues = new double[this.classes.length];
        int excluded = (set == null) ? 0 : set.size();
        PriorityQueue ranking = new PriorityQueue(true, this.db.getCount() - excluded);
        DistanceMeasure<T> measure = resolveDistanceMeasure();

        Iterator<T> objects = this.db.objectIterator();
        while (objects.hasNext()) {
            T candidate = objects.next();
            if (set == null || !set.contains(candidate.getPrimaryKey())) {
                ranking.add(measure.distance(t, candidate), candidate);
            }
        }

        for (int c = 0; c < this.classes.length; c++) {
            // Each class works on its own copy of the ranking.
            StatisticalQueryResult result = new StatisticalQueryResult(t.getPrimaryKey(), this.recallBins).calcHitArray(DataConverter.getRanks(ranking.copy(), this.classes[c]));
            decisionValues[c] = getDecisionValue(result);
            if (this.VERBOSE) {
                System.out.println("\t for class " + this.classes[c] + ": " + result.toString());
                System.out.println(decisionValues[c]);
            }
        }
        return decisionValues;
    }

    /** Classifies {@code t} against the complete database (no excluded keys). */
    @Override // dm.algorithms.Classifier
    public int classify(T t) {
        return classify(t, null);
    }

    /** Computes the per-class decision values for {@code t} against the complete database. */
    @Override // dm.algorithms.Classifier
    public double[] getDistribution(T t) {
        return getDistribution(t, null);
    }

    /**
     * Not supported by this classifier.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // dm.algorithms.Classifier
    public void saveClassifier(String str) {
        throw new UnsupportedOperationException();
    }

    /**
     * Sets the instance weights used by {@link #classify(DataObject, Set)}; the array is stored
     * by reference and read as {@code dArr[instance][class]}. Pass {@code null} to disable
     * weighted classification.
     */
    public void setWeights(double[][] dArr) {
        this.weights = dArr;
    }

    /**
     * Replaces the recall bins and re-derives the bin position for the current recall selection.
     * The array is stored by reference.
     *
     * @param dArr the new recall bins
     * @throws NullPointerException if {@code dArr} is {@code null}; the classifier state is left
     *         untouched in that case
     */
    public void setRecallBins(double[] dArr) {
        // Computed first so that a null argument fails before any field has been modified.
        int position = locateRecallBin(dArr, this.recallSelection);
        this.recallBins = dArr;
        this.firstXPosition = position;
    }

    /**
     * Selects the recall value whose bin precision is used by {@link #FIRST_X_RECALL}.
     *
     * @param d the recall value; it must match one of the recall bins exactly for the
     *          non-substitute decision value to be retrievable
     */
    public void setRecallSelection(double d) {
        if (this.recallSelection != d) {
            this.firstXPosition = locateRecallBin(this.recallBins, d);
        }
        this.recallSelection = d;
    }

    /** @return the configured decision method ({@link #MAP} or {@link #FIRST_X_RECALL}) */
    public final int getDecisionMethod() {
        return this.decisionMethod;
    }

    /**
     * Sets the decision method. The value is not validated here; an unknown value makes the next
     * distribution computation throw an {@link IllegalArgumentException}.
     */
    public final void setDecisionMethod(int i) {
        this.decisionMethod = i;
    }

    /**
     * Overrides the distance measure used for ranking; {@code null} restores the database's own
     * distance measure.
     */
    @Override // dm.algorithms.CVClassifier
    public void setDistanceMeasure(DistanceMeasure<T> distanceMeasure) {
        this.distanceMeasure = distanceMeasure;
    }

    /** Returns the explicitly set distance measure, or the database's one if none was set. */
    private DistanceMeasure<T> resolveDistanceMeasure() {
        if (this.distanceMeasure != null) {
            return this.distanceMeasure;
        }
        return this.db.getDistanceMeasure();
    }

    /**
     * Returns the index of the first bin that is not smaller than {@code recall}, capped at the
     * last index; returns {@code 0} for an empty array.
     */
    private static int locateRecallBin(double[] bins, double recall) {
        int position = 0;
        while (position < bins.length - 1 && bins[position] < recall) {
            position++;
        }
        return position;
    }
}
