package dm.algorithms;

import dm.data.DataObject;
import dm.data.MIObjects.MultiInstanceObject;
import dm.data.database.Database;
import dm.data.featureVector.EuclidianDistance;
import dm.data.featureVector.dot;
import dm.util.LM;
import dm.util.PriorityQueue;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;

/* loaded from: input_file:dm/algorithms/MiInstanceDistLearner.class */
/**
 * Learns the two parameters {@code [a, b]} of a sigmoid {@code p(x) = 1 / (1 + exp(a * x + b))}
 * that maps a distance {@code x} between multi-instance objects to a value in (0, 1).
 *
 * <p>The sigmoid is fitted with {@code LM.fitSigmoid} to a cumulative curve. The curve is built
 * from samples of a "similar" distance Gaussian, each weighted by the product of the "similar"
 * and "dissimilar" Gaussian densities. The two Gaussians are estimated either:
 * <ul>
 *   <li>from several databases, where pairs of objects from the same database count as similar
 *       and pairs from different databases as dissimilar ({@code MiInstanceDistLearner(Database[])}
 *       and {@link #initParam(Database[])}), or</li>
 *   <li>from a single database, by alternately refining several per-representative models
 *       ({@code MiInstanceDistLearner(Database, int)}).</li>
 * </ul>
 *
 * <p>Progress is reported on {@code System.out}.
 */
public class MiInstanceDistLearner {
    /** Sampling reduction factor: only every COMPRESSION-th sorted sample is used for the fit. */
    public static final int COMPRESSION = 1;
    /** Lower bound returned by {@link #NormVert}, so sample weights never drop to zero. */
    public static double SMOOTHER = 1.0E-5d;
    private static final double SQRT_2PI = Math.sqrt(2.0d * Math.PI);

    /** Number of Gaussian samples drawn for every sigmoid fit. */
    private static final int SAMPLE_COUNT = 2000;
    /** Fraction of each representative's distances (taken first from the queue) used as "similar". */
    private static final double SIM_FRACTION = 0.1d;
    /** Upper bound on refinement iterations in the single-database constructor. */
    private static final int MAX_ITERATIONS = 100;
    /** Refinement stops once the mean variance decreases by at most this amount. */
    private static final double CONVERGENCE_THRESHOLD = 0.001d;

    /** Fitted sigmoid parameters (index 0 = a, index 1 = b); overwritten by every fit. */
    public double[] param;
    /** Mean (index 0) and variance (index 1) of the "similar" distance Gaussian. */
    double[] simParam;
    /** Mean (index 0) and variance (index 1) of the "dissimilar" distance Gaussian. */
    double[] dissimParam;
    Random rand = new Random();

    /**
     * Trains on several databases.
     *
     * <p>Object pairs drawn from the same database are treated as similar; pairs drawn from
     * different databases are treated as dissimilar.
     *
     * @param databaseArr databases whose objects are {@code MultiInstanceObject}s
     */
    public MiInstanceDistLearner(Database[] databaseArr) {
        double total = 0.0d;
        for (Database database : databaseArr) {
            total += database.getCount();
        }
        System.out.println("Kernel-Tunining: Training on " + total + " instances!");
        initParam2(databaseArr);
    }

    /**
     * Trains on a single database without class information.
     *
     * <p>For every eligible object pair, {@code i} per-instance nearest distances are recorded,
     * one per "representative", and each representative gets its own sigmoid. The procedure then
     * alternates two steps:
     * <ol>
     *   <li>average the representatives' predictions per pair;</li>
     *   <li>refit every model from those averages.</li>
     * </ol>
     * It stops when the mean prediction variance decreases by at most {@code 0.001}, or after
     * {@code 100} iterations. The first representative's model is stored in {@link #param}.
     *
     * @param database objects to train on; every element must be a {@code MultiInstanceObject}
     * @param i number of representatives (distances kept per object pair)
     * @throws IllegalArgumentException if {@code i < 1}
     * @throws IllegalStateException if too few distances were collected to initialise the models
     */
    public MiInstanceDistLearner(Database database, int i) {
        if (i < 1) {
            throw new IllegalArgumentException("number of representatives must be >= 1, was " + i);
        }
        System.out.println("Kernel-Tunining: Training on " + database.getCount() + " instances!");
        @SuppressWarnings("unchecked") // generic array creation; every slot is filled right below
        Map<String, Double>[] distances = new Map[i];
        for (int rep = 0; rep < i; rep++) {
            distances[rep] = new HashMap<String, Double>();
        }
        fillDistances(distances, database);
        System.out.println("Distances selected " + i + " x " + distances[0].size());
        double[][] model = InitModel(distances);
        System.out.println("Parallel Models initialized !");
        double variance = Double.MAX_VALUE;
        for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++) {
            System.out.println("Iteration " + iteration);
            double previousVariance = variance;
            Map<String, Double> consensus = new HashMap<String, Double>();
            variance = calcVariance(model, distances, consensus);
            System.out.println("OldVariance " + previousVariance + " new Var " + variance);
            System.out.println("Object Similarities estimated");
            model = updateModell(consensus, distances);
            System.out.println("Modell updated");
            if (previousVariance - variance <= CONVERGENCE_THRESHOLD) {
                break;
            }
        }
        this.param = model[0];
    }

    /**
     * Refits every representative's sigmoid from soft labels.
     *
     * <p>Each pair's consensus value {@code p} weights its distance as "similar" with {@code p}
     * and as "dissimilar" with {@code 1 - p}.
     *
     * @param consensus averaged prediction per pair key (as filled by {@link #calcVariance})
     * @param distances per-representative distance per pair key
     * @return one {@code [a, b]} model per representative
     */
    private double[][] updateModell(Map<String, Double> consensus, Map<String, Double>[] distances) {
        double[][] models = new double[distances.length][];
        for (int rep = 0; rep < distances.length; rep++) {
            double simSum = 0.0d;
            double simSq = 0.0d;
            double simWeight = 0.0d;
            double dissimSum = 0.0d;
            double dissimSq = 0.0d;
            double dissimWeight = 0.0d;
            for (Map.Entry<String, Double> entry : distances[rep].entrySet()) {
                double x = entry.getValue().doubleValue();
                double pSim = consensus.get(entry.getKey()).doubleValue();
                double pDissim = 1.0d - pSim;
                simSum += x * pSim;
                simSq += pSim * x * x;
                simWeight += pSim;
                dissimSum += pDissim * x;
                dissimSq += pDissim * x * x;
                dissimWeight += pDissim;
            }
            double simMean = simSum / simWeight;
            double simVar = (simSq / simWeight) - (simMean * simMean);
            double dissimMean = dissimSum / dissimWeight;
            double dissimVar = (dissimSq / dissimWeight) - (dissimMean * dissimMean);
            models[rep] = fitSigmoid(simMean, simVar, dissimMean, dissimVar);
            System.out.println("Model in Rep " + rep + ": a " + models[rep][0] + " b " + models[rep][1]);
        }
        return models;
    }

    /**
     * Averages, for every pair key, the sigmoid outputs of all representatives.
     *
     * @param models one {@code [a, b]} model per representative
     * @param distances per-representative distance per pair key; all maps share the same keys
     * @param consensus output map that receives the averaged prediction per pair key
     * @return the mean, across pairs, of the per-pair variance of the representatives' outputs
     */
    private double calcVariance(double[][] models, Map<String, Double>[] distances,
            Map<String, Double> consensus) {
        double varianceSum = 0.0d;
        double pairCount = 0.0d;
        double reps = models.length;
        for (String key : distances[0].keySet()) {
            double sum = 0.0d;
            double sumSq = 0.0d;
            for (int rep = 0; rep < models.length; rep++) {
                double p = sigmoid(distances[rep].get(key).doubleValue(), models[rep]);
                sum += p;
                sumSq += p * p;
            }
            double mean = sum / reps;
            consensus.put(key, Double.valueOf(mean));
            varianceSum += (sumSq / reps) - (mean * mean);
            pairCount += 1.0d;
        }
        return varianceSum / pairCount;
    }

    /** Evaluates {@code 1 / (1 + exp(a * x + b))} with {@code model = [a, b]}. */
    private double sigmoid(double x, double[] model) {
        return 1.0d / (1.0d + Math.exp((model[0] * x) + model[1]));
    }

    /**
     * Builds the initial model of every representative.
     *
     * <p>The first {@code 10%} (at least one) of the distances dequeued from a priority queue
     * initialise the "similar" Gaussian; the rest initialise the "dissimilar" Gaussian.
     * NOTE(review): this assumes the single-argument {@code dm.util.PriorityQueue} dequeues
     * ascending priorities, i.e. the smallest distances come first. Confirm.
     *
     * @param distances per-representative distance per pair key
     * @return one {@code [a, b]} model per representative
     * @throws IllegalStateException if a representative has fewer than two distances
     */
    private double[][] InitModel(Map<String, Double>[] distances) {
        double[][] models = new double[distances.length][];
        for (int rep = 0; rep < distances.length; rep++) {
            PriorityQueue sorted = new PriorityQueue(distances[rep].size());
            for (Double value : distances[rep].values()) {
                sorted.add(value.doubleValue(), null);
            }
            int total = sorted.size();
            // At least one "similar" sample; the original truncation gave 0 (-> NaN) below 10.
            int simCount = Math.max(1, (int) (total * SIM_FRACTION));
            if (total <= simCount) {
                throw new IllegalStateException("Representative " + rep + " has only " + total
                        + " distances; at least 2 are required to initialise a model");
            }
            double simSum = 0.0d;
            double simSq = 0.0d;
            double dissimSum = 0.0d;
            double dissimSq = 0.0d;
            int rank = 0;
            while (!sorted.isEmpty()) {
                double x = sorted.firstPriority();
                sorted.removeFirst();
                if (rank < simCount) {
                    simSum += x;
                    simSq += x * x;
                } else {
                    dissimSum += x;
                    dissimSq += x * x;
                }
                rank++;
            }
            int dissimCount = rank - simCount;
            double simMean = simSum / simCount;
            double simVar = (simSq / simCount) - (simMean * simMean);
            double dissimMean = dissimSum / dissimCount;
            // Bug fix: the dissimilar variance must be normalised by the dissimilar count.
            double dissimVar = (dissimSq / dissimCount) - (dissimMean * dissimMean);
            System.out.println("Sim mu" + simMean + " sVar " + simVar + " DissimMu " + dissimMean + " dVar " + dissimVar);
            models[rep] = fitSigmoid(simMean, simVar, dissimMean, dissimVar);
            System.out.println("Model in " + rep + " a " + models[rep][0] + " b " + models[rep][1]);
        }
        return models;
    }

    /**
     * Fits the sigmoid to a cumulative curve derived from the two Gaussians.
     *
     * <p>The method draws {@code 2000} points from the "similar" Gaussian and weights each point
     * by {@code NormVert(x, sim) * NormVert(x, dissim)}. It then walks the points in the order
     * dequeued from {@code new PriorityQueue(false, ...)}, accumulating the normalised weights.
     * Finally it passes the {@code (x, cumulative weight)} pairs to {@code LM.fitSigmoid}.
     *
     * @param simMean mean of the "similar" Gaussian
     * @param simVar variance of the "similar" Gaussian
     * @param dissimMean mean of the "dissimilar" Gaussian
     * @param dissimVar variance of the "dissimilar" Gaussian
     * @return the fitted parameters, which are also stored in {@link #param}
     */
    private double[] fitSigmoid(double simMean, double simVar, double dissimMean, double dissimVar) {
        PriorityQueue samples = new PriorityQueue(false, SAMPLE_COUNT);
        double totalWeight = 0.0d;
        for (int s = 0; s < SAMPLE_COUNT; s++) {
            double x = generateGaussPunkt(simMean, simVar);
            double weight = NormVert(x, simMean, simVar) * NormVert(x, dissimMean, dissimVar);
            samples.add(x, Double.valueOf(weight));
            totalWeight += weight;
        }
        int size = samples.size() / COMPRESSION;
        double[] xs = new double[size];
        double[] ys = new double[size];
        double cumulative = 0.0d;
        int seen = 0;
        int kept = 0;
        while (!samples.isEmpty()) {
            double x = samples.firstPriority();
            cumulative += ((Double) samples.removeFirst()).doubleValue();
            if (seen % COMPRESSION == 0 && kept < size) {
                xs[kept] = x;
                ys[kept] = cumulative / totalWeight;
                kept++;
            }
            seen++;
        }
        this.param = LM.fitSigmoid(xs, ys);
        return this.param;
    }

    /** Joins two keys with {@code ';'}. */
    private String generateKey(String str, String str2) {
        return String.valueOf(str) + ";" + str2;
    }

    /**
     * Collects, for every eligible object pair, {@code mapArr.length} per-instance nearest
     * distances.
     *
     * <p>A pair (bag1, bag2) is eligible when:
     * <ul>
     *   <li>both bags have at least {@code mapArr.length} instances, and</li>
     *   <li>{@code key(bag1) <= key(bag2)}.</li>
     * </ul>
     * For each instance of bag1, the distance to its closest instance in bag2 is computed. A
     * bounded queue keeps {@code mapArr.length} of those: its first element is evicted when a
     * smaller distance arrives. The kept distances are stored under the key
     * {@code "key1;key2"}, one per map, in dequeue order.
     *
     * @param mapArr output maps, one per representative
     * @param database objects to compare; every element must be a {@code MultiInstanceObject}
     * @throws IllegalStateException if a pair did not yield exactly {@code mapArr.length} distances
     */
    private void fillDistances(Map<String, Double>[] mapArr, Database database) {
        EuclidianDistance metric = new EuclidianDistance();
        int k = mapArr.length;
        // Computed in double: the int product of the counts overflows for large databases.
        double pairsTotal = (double) database.getCount() * database.getCount();
        int processed = 0;
        Iterator outer = database.objectIterator();
        while (outer.hasNext()) {
            MultiInstanceObject bag1 = (MultiInstanceObject) outer.next();
            if (bag1.instances().size() < k) {
                continue;
            }
            Iterator inner = database.objectIterator();
            while (inner.hasNext()) {
                MultiInstanceObject bag2 = (MultiInstanceObject) inner.next();
                if (bag2.instances().size() < k
                        || bag1.getPrimaryKey().compareTo(bag2.getPrimaryKey()) > 0) {
                    continue;
                }
                PriorityQueue nearest = new PriorityQueue(false, k);
                for (DataObject a : bag1.instances()) {
                    double best = Double.MAX_VALUE;
                    String bestKey = "";
                    for (DataObject b : bag2.instances()) {
                        double dist = metric.distance(a, b);
                        if (dist < best) {
                            best = dist;
                            bestKey = b.getPrimaryKey();
                        }
                    }
                    String instanceKey = generateKey(a.getPrimaryKey(), bestKey);
                    if (nearest.size() < k) {
                        nearest.add(best, instanceKey);
                    } else if (nearest.firstPriority() > best) {
                        nearest.removeFirst();
                        nearest.add(best, instanceKey);
                    }
                }
                String pairKey = generateKey(bag1.getPrimaryKey(), bag2.getPrimaryKey());
                int rep = 0;
                while (!nearest.isEmpty()) {
                    mapArr[rep].put(pairKey, Double.valueOf(nearest.firstPriority()));
                    rep++;
                    nearest.removeFirst();
                }
                if (rep != k) {
                    // Previously System.exit(0): killed the JVM while reporting success.
                    throw new IllegalStateException("Expected " + k + " distances for pair "
                            + pairKey + " but collected " + rep);
                }
                processed++;
                if (processed % 100 == 0) {
                    System.out.println(String.valueOf((200.0d * processed) / pairsTotal) + " % ");
                }
            }
        }
        System.out.println();
        for (Map<String, Double> map : mapArr) {
            System.out.println("Rep : " + map.size());
        }
    }

    /**
     * Evaluates the normal probability density with the given mean and variance at {@code x},
     * clamped from below to {@link #SMOOTHER}.
     *
     * <p>A non-positive variance is treated as a degenerate distribution: the method returns
     * {@code 1} when {@code x == mean} and {@link #SMOOTHER} otherwise. NaN inputs propagate.
     *
     * @param x point at which the density is evaluated
     * @param mean mean of the distribution
     * @param variance variance (not standard deviation) of the distribution
     * @return {@code max(pdf(x), SMOOTHER)}
     */
    public double NormVert(double x, double mean, double variance) {
        if (variance <= 0.0d) {
            return x == mean ? 1.0d : SMOOTHER;
        }
        double diff = x - mean;
        // Normaliser is 1/sqrt(2*pi*variance); previously 1/(sqrt(2*pi)*variance).
        double density = Math.exp(-(diff * diff) / (2.0d * variance)) / (SQRT_2PI * Math.sqrt(variance));
        return Math.max(density, SMOOTHER);
    }

    /**
     * Estimates weighted similar/dissimilar Gaussians from the databases and fits the sigmoid.
     *
     * @param databaseArr databases whose objects are {@code MultiInstanceObject}s
     * @throws IllegalStateException if either class receives no positive weight
     */
    public void initParam(Database[] databaseArr) {
        determineWeightedGaussians(databaseArr);
        fitSigmoid(this.simParam[0], this.simParam[1], this.dissimParam[0], this.dissimParam[1]);
        System.out.println("Param alpha = " + this.param[0] + " beta " + this.param[1]);
    }

    /**
     * Estimates unweighted similar/dissimilar Gaussians from the databases and fits the sigmoid.
     *
     * @throws IllegalStateException if either class has no object pairs
     */
    private void initParam2(Database[] databaseArr) {
        determineGaussians(databaseArr);
        fitSigmoid(this.simParam[0], this.simParam[1], this.dissimParam[0], this.dissimParam[1]);
        System.out.println("Param alpha = " + this.param[0] + " beta " + this.param[1]);
    }

    /**
     * Draws one sample from a normal distribution.
     *
     * @param d mean
     * @param d2 variance; slightly negative values (round-off) are clamped to 0
     */
    private double generateGaussPunkt(double d, double d2) {
        return (this.rand.nextGaussian() * Math.sqrt(Math.max(0.0d, d2))) + d;
    }

    /**
     * Computes weighted means and variances of the smallest instance distance per object pair.
     *
     * <p>Pairs with {@code key(bag1) <= key(bag2)} are considered, so self-pairs are included.
     * Each pair is weighted by the product of the weights of its closest instance pair. Pairs
     * for which no instance distance is available (e.g. an empty bag) are skipped.
     *
     * @throws IllegalStateException if either class receives no positive weight
     */
    private void determineWeightedGaussians(Database[] databaseArr) {
        this.simParam = new double[2];
        this.dissimParam = new double[2];
        double simWeight = 0.0d;
        double dissimWeight = 0.0d;
        EuclidianDistance metric = new EuclidianDistance();
        for (int a = 0; a < databaseArr.length; a++) {
            Iterator outer = databaseArr[a].objectIterator();
            while (outer.hasNext()) {
                MultiInstanceObject bag1 = (MultiInstanceObject) outer.next();
                for (int b = 0; b < databaseArr.length; b++) {
                    Iterator inner = databaseArr[b].objectIterator();
                    while (inner.hasNext()) {
                        MultiInstanceObject bag2 = (MultiInstanceObject) inner.next();
                        if (bag1.getPrimaryKey().compareTo(bag2.getPrimaryKey()) > 0) {
                            continue;
                        }
                        double minDist = Double.MAX_VALUE;
                        double minWeight = 0.0d;
                        boolean found = false;
                        for (DataObject x : bag1.instances()) {
                            for (DataObject y : bag2.instances()) {
                                double dist = metric.distance(x, y);
                                if (dist < minDist) {
                                    minDist = dist;
                                    minWeight = x.getWeight() * y.getWeight();
                                    found = true;
                                }
                            }
                        }
                        if (!found) {
                            // MAX_VALUE^2 * 0 would be NaN and poison the moments.
                            continue;
                        }
                        double[] target = (a == b) ? this.simParam : this.dissimParam;
                        target[0] += minDist * minWeight;
                        target[1] += minDist * minDist * minWeight;
                        if (a == b) {
                            simWeight += minWeight;
                        } else {
                            dissimWeight += minWeight;
                        }
                    }
                }
            }
        }
        toMeanAndVariance(this.simParam, simWeight, "similar");
        toMeanAndVariance(this.dissimParam, dissimWeight, "dissimilar");
        printGaussians();
    }

    /**
     * Computes means and variances of the smallest instance distance per object pair.
     *
     * <p>Unlike {@link #determineWeightedGaussians}, only pairs with {@code key(bag1) > key(bag2)}
     * are counted, so self-pairs are excluded. Pairs for which no instance distance is available
     * (e.g. an empty bag) are skipped.
     *
     * @throws IllegalStateException if either class has no object pairs
     */
    private void determineGaussians(Database[] databaseArr) {
        this.simParam = new double[2];
        this.dissimParam = new double[2];
        double simCount = 0.0d;
        double dissimCount = 0.0d;
        EuclidianDistance metric = new EuclidianDistance();
        for (int a = 0; a < databaseArr.length; a++) {
            Iterator outer = databaseArr[a].objectIterator();
            while (outer.hasNext()) {
                MultiInstanceObject bag1 = (MultiInstanceObject) outer.next();
                for (int b = 0; b < databaseArr.length; b++) {
                    boolean sameDatabase = a == b;
                    Iterator inner = databaseArr[b].objectIterator();
                    while (inner.hasNext()) {
                        MultiInstanceObject bag2 = (MultiInstanceObject) inner.next();
                        if (bag1.getPrimaryKey().compareTo(bag2.getPrimaryKey()) <= 0) {
                            continue;
                        }
                        double minDist = Double.MAX_VALUE;
                        boolean found = false;
                        for (DataObject x : bag1.instances()) {
                            for (DataObject y : bag2.instances()) {
                                double dist = metric.distance(x, y);
                                if (dist < minDist) {
                                    minDist = dist;
                                    found = true;
                                }
                            }
                        }
                        if (!found) {
                            // Adding Double.MAX_VALUE would swamp the moments.
                            continue;
                        }
                        double[] target = sameDatabase ? this.simParam : this.dissimParam;
                        target[0] += minDist;
                        target[1] += minDist * minDist;
                        if (sameDatabase) {
                            simCount += 1.0d;
                        } else {
                            dissimCount += 1.0d;
                        }
                    }
                }
            }
        }
        toMeanAndVariance(this.simParam, simCount, "similar");
        toMeanAndVariance(this.dissimParam, dissimCount, "dissimilar");
        printGaussians();
    }

    /**
     * Turns accumulated {@code [sum, sum of squares]} into {@code [mean, variance]} in place.
     *
     * @throws IllegalStateException if {@code totalWeight} is not positive (would yield NaN)
     */
    private static void toMeanAndVariance(double[] moments, double totalWeight, String label) {
        if (!(totalWeight > 0.0d)) {
            throw new IllegalStateException("No " + label
                    + " object pairs with positive weight; cannot estimate the " + label
                    + " distance distribution");
        }
        moments[0] = moments[0] / totalWeight;
        moments[1] = (moments[1] / totalWeight) - (moments[0] * moments[0]);
    }

    /**
     * Prints the estimated Gaussians. The final line is a constant string kept for output
     * compatibility.
     */
    private void printGaussians() {
        System.out.println("Sim Mean " + this.simParam[0] + " Sim StDev " + Math.sqrt(this.simParam[1]));
        System.out.println("Dissim Mean " + this.dissimParam[0] + " Dissim StDev " + Math.sqrt(this.dissimParam[1]));
        System.out.println("min 1.7976931348623157E308 max 0.0");
    }
}
