package dm.algorithms;

import dm.data.DataObject;
import dm.data.database.Database;
import dm.util.MathUtil;
import dm.util.PriorityQueue;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Iterator;

/* loaded from: input_file:dm/algorithms/MBKNN.class */
/**
 * k-nearest-neighbour classifier over several databases, where the position of
 * each database in {@link #data} is the class label it stands for: the
 * distribution returned by {@link #getDistribution(DataObject)} has one entry
 * per database, and {@link #classify(DataObject)} returns the index of the
 * largest entry.
 *
 * <p>Retrieved neighbours are weighted by the inverse square of their distance
 * to the query (see {@link #confidence(DataObject)}).
 */
public class MBKNN implements Classifier {

    /** Used instead of a distance of exactly zero so that {@code 1 / d^2} stays finite. */
    private static final double ZERO_DISTANCE_SUBSTITUTE = 1.0E-5d;

    /** Training data, one database per class; the array index is the class label. */
    public Database[] data;

    /** Number of nearest neighbours requested per query. */
    public int k;

    /**
     * Creates an untrained classifier. {@link #data} is {@code null} until
     * {@link #train} is called (or it is assigned directly), so classifying
     * before that fails with a {@code NullPointerException}.
     *
     * @param i the number of nearest neighbours to use
     */
    public MBKNN(int i) {
        this.k = i;
    }

    /**
     * Restores a classifier previously written by {@link #saveClassifier(String)}.
     *
     * <p>Reads the file {@code str + ".clf"} in the order {@code saveClassifier}
     * writes it: {@code k}, the number of databases, then each {@link Database}
     * as a serialized object.
     *
     * <p>NOTE(review): on any failure this prints the stack trace and terminates
     * the JVM via {@code System.exit(1)} instead of throwing. It also uses Java
     * native deserialization, which must not be applied to untrusted files.
     *
     * @param str path of the saved classifier, without the ".clf" suffix
     */
    public MBKNN(String str) {
        try {
            String fileName = String.valueOf(str) + ".clf";
            System.out.println("Read " + fileName);
            FileInputStream fileIn = new FileInputStream(fileName);
            try {
                ObjectInputStream in = new ObjectInputStream(fileIn);
                try {
                    this.k = in.readInt();
                    int databaseCount = in.readInt();
                    this.data = new Database[databaseCount];
                    for (int c = 0; c < databaseCount; c++) {
                        this.data[c] = (Database) in.readObject();
                    }
                } finally {
                    in.close();
                }
            } finally {
                // Closes the file even if the ObjectInputStream header could not be read;
                // a second close after in.close() is harmless.
                fileIn.close();
            }
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Builds the per-class databases with the given {@link ModelMaker} and
     * stores them in {@link #data}.
     *
     * @param databaseArr input databases handed to {@code modelMaker}
     * @param modelMaker  produces the databases stored in {@link #data}
     * @param d           passed through unchanged to {@code modelMaker.make}
     */
    public void train(Database[] databaseArr, ModelMaker modelMaker, double d) {
        this.data = modelMaker.make(databaseArr, d);
    }

    /**
     * Predicts the class of {@code dataObject}.
     *
     * @param dataObject the query object
     * @return the index (in {@link #data}) of the highest-weighted class
     */
    @Override // dm.algorithms.Classifier
    public int classify(DataObject dataObject) {
        return MathUtil.argmax(getDistribution(dataObject));
    }

    /**
     * Computes the class weights for {@code dataObject}.
     *
     * <p>A k-NN query is run against every class database, and the k closest
     * results across all classes determine a distance cutoff. Every retrieved
     * neighbour whose distance is at most that cutoff (so ties at the cutoff are
     * all counted) adds {@code 1 / d^2} to the weight of its class. A distance of
     * exactly 0 is replaced by {@code 1.0E-5}. The weights are finally passed
     * through {@code MathUtil.normalize}.
     *
     * @param dataObject the query object
     * @return one weight per database in {@link #data}
     */
    public double[] confidence(DataObject dataObject) {
        double[] weights = new double[this.data.length];

        // Capacity of the candidate queue: the number of objects in all databases.
        int totalCount = 0;
        for (int c = 0; c < this.data.length; c++) {
            totalCount += this.data[c].getCount();
        }
        // Every retrieved neighbour. Presumably ordered smallest distance first;
        // the drain loop below relies on that. TODO confirm against dm.util.PriorityQueue.
        PriorityQueue candidates = new PriorityQueue(true, totalCount);
        // Holds at most k neighbours. Its first element is the one evicted when a
        // closer neighbour arrives, i.e. presumably the largest retained distance.
        PriorityQueue nearest = new PriorityQueue(false, this.k);

        for (int c = 0; c < this.data.length; c++) {
            Iterator<?> it = this.data[c].savekNNQuery(dataObject, this.k).iterator();
            while (it.hasNext()) {
                DataObject neighbour = (DataObject) it.next();
                double distance = this.data[c].getDistanceMeasure().distance(neighbour, dataObject);
                Integer classIndex = Integer.valueOf(c);
                candidates.add(distance, classIndex);
                if (nearest.size() < this.k) {
                    nearest.add(distance, classIndex);
                } else if (nearest.firstPriority() > distance) {
                    nearest.removeFirst();
                    nearest.add(distance, classIndex);
                }
            }
        }

        // 'nearest' is not modified below, so this cutoff is loop-invariant.
        // NOTE(review): if no neighbour was retrieved at all, 'nearest' is empty here;
        // what firstPriority() does then depends on dm.util.PriorityQueue.
        double cutoff = nearest.firstPriority();
        // '<=' keeps candidates tied with the cutoff distance.
        while (!candidates.isEmpty() && candidates.firstPriority() <= cutoff) {
            double distance = candidates.firstPriority();
            if (distance == 0.0d) {
                distance = ZERO_DISTANCE_SUBSTITUTE;
            }
            int classIndex = ((Integer) candidates.removeFirst()).intValue();
            weights[classIndex] += 1.0d / (distance * distance);
        }
        MathUtil.normalize(weights);
        return weights;
    }

    /**
     * Returns the class weights for {@code dataObject}; delegates to
     * {@link #confidence(DataObject)}.
     *
     * @param dataObject the query object
     * @return one weight per database in {@link #data}
     */
    @Override // dm.algorithms.Classifier
    public double[] getDistribution(DataObject dataObject) {
        return confidence(dataObject);
    }

    /**
     * Writes this classifier to {@code str + ".clf"} in the format read by
     * {@link #MBKNN(String)}: {@code k}, the number of databases, then each
     * {@link Database} as a serialized object. I/O errors are printed and
     * otherwise ignored; the streams are closed in every case.
     *
     * @param str destination path, without the ".clf" suffix
     */
    @Override // dm.algorithms.Classifier
    public void saveClassifier(String str) {
        try {
            System.out.println("Saving " + str + ".clf");
            FileOutputStream fileOut = new FileOutputStream(String.valueOf(str) + ".clf");
            try {
                ObjectOutputStream out = new ObjectOutputStream(fileOut);
                try {
                    out.writeInt(this.k);
                    out.writeInt(this.data.length);
                    for (int c = 0; c < this.data.length; c++) {
                        out.writeObject(this.data[c]);
                    }
                    out.flush();
                } finally {
                    out.close();
                }
            } finally {
                // Closes the file even if the ObjectOutputStream could not be created;
                // a second close after out.close() is harmless.
                fileOut.close();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
