package de.lmu.ifi.dbs.dm.distance.mi;

import de.lmu.ifi.dbs.dm.DistanceMeasure;
import de.lmu.ifi.dbs.dm.Kernel;
import de.lmu.ifi.dbs.dm.data.DataObject;
import de.lmu.ifi.dbs.dm.data.MultiInstanceObject;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:de/lmu/ifi/dbs/dm/distance/mi/MMD.class */
/**
 * Maximum Mean Discrepancy (MMD) between two multi-instance objects ("bags").
 *
 * <p>For bags {@code X} (size {@code n}) and {@code Y} (size {@code m}) and an instance kernel
 * {@code k}, the distance is
 * {@code sqrt( sum k(x,x') / n^2 - 2 * sum k(x,y) / (n*m) + sum k(y,y') / m^2 )}.
 * A radicand that is slightly negative because of floating-point round-off is clamped to zero,
 * so the result is not turned into {@code NaN} by {@link Math#sqrt}.
 *
 * <p>{@link #distance} and {@link #distancePreNormalized} require a kernel.
 * {@link #distanceAdaptive} instead requires an instance distance measure and derives its
 * kernel values as {@code exp(-d / median)} from it.
 *
 * @param <S> instance type
 * @param <T> multi-instance object type whose instances are of type {@code S}
 */
public class MMD<S extends DataObject, T extends MultiInstanceObject<S>> implements MIDistanceMeasure<T> {
    private static final long serialVersionUID = 6632319465714157147L;

    /** Instance-level distance used by {@link #distanceAdaptive}; may be {@code null}. */
    private final DistanceMeasure<S> dm;

    /** Instance-level kernel used by {@link #distance}; may be {@code null} until set. */
    private Kernel<S> k;

    /** When {@code true}, {@link #distance} delegates to {@link #distancePreNormalized}. */
    private boolean isNormalized;

    // NOTE(review): FROM and TO are not referenced anywhere in this class; their meaning is not
    // visible here. They stay public and non-final because external code may read or assign them.
    public static int FROM = -3;
    public static int TO = 3;

    /**
     * Creates an MMD measure with both an instance distance and an instance kernel.
     *
     * @param distanceMeasure instance distance, used by {@link #distanceAdaptive}
     * @param kernel instance kernel, used by {@link #distance}
     */
    public MMD(DistanceMeasure<S> distanceMeasure, Kernel<S> kernel) {
        this.dm = distanceMeasure;
        this.k = kernel;
    }

    /**
     * Creates an MMD measure with only an instance distance. A kernel must be assigned via
     * {@link #setKernel} before {@link #distance} can be used.
     *
     * @param distanceMeasure instance distance, used by {@link #distanceAdaptive}
     */
    public MMD(DistanceMeasure<S> distanceMeasure) {
        this(distanceMeasure, null);
    }

    /**
     * Creates an MMD measure with only an instance kernel. {@link #distanceAdaptive} is not
     * available for instances created this way.
     *
     * @param kernel instance kernel, used by {@link #distance}
     */
    public MMD(Kernel<S> kernel) {
        this(null, kernel);
    }

    /**
     * Computes the MMD using precomputed within-bag kernel sums.
     *
     * <p>Only the cross term {@code sum k(x, y)} is evaluated here. The within-bag sums are read
     * from {@code getWeight()} of each bag. NOTE(review): this assumes {@code getWeight()} holds
     * the unnormalized within-bag kernel sum {@code sum k(x, x')}; confirm against the code that
     * sets it.
     *
     * @param t first bag
     * @param t2 second bag
     * @return the MMD between the two bags
     * @throws IllegalArgumentException if no kernel is assigned or either bag is empty
     */
    public double distancePreNormalized(T t, T t2) {
        requireKernel();
        int n = t.size();
        int m = t2.size();
        requireNonEmpty(n, m);
        double cross = kernelSum(t.instances(), t2.instances());
        // The cross term is the only one weighted by -2/(n*m); each self term uses its own size.
        return combine(t.getWeight(), cross, t2.getWeight(), n, m);
    }

    /**
     * Computes the MMD between two bags with the assigned kernel.
     *
     * <p>If {@link #isNormalized()} is {@code true}, this delegates to
     * {@link #distancePreNormalized}.
     *
     * @param t first bag
     * @param t2 second bag
     * @return the MMD between the two bags
     * @throws IllegalArgumentException if no kernel is assigned or either bag is empty
     */
    @Override
    public double distance(T t, T t2) {
        requireKernel();
        if (this.isNormalized) {
            return distancePreNormalized(t, t2);
        }
        int n = t.size();
        int m = t2.size();
        requireNonEmpty(n, m);
        List<S> xs = t.instances();
        List<S> ys = t2.instances();
        return combine(kernelSum(xs, xs), kernelSum(xs, ys), kernelSum(ys, ys), n, m);
    }

    /**
     * Computes the MMD with a data-adaptive kernel {@code exp(-d(a, b) / median)}.
     *
     * <p>{@code d} is the instance distance measure. {@code median} is the median of all pairwise
     * distances within {@code t}, between {@code t} and {@code t2}, and within {@code t2}: the full
     * n×n, n×m and m×m matrices, in that order. The median is printed to standard output.
     *
     * @param t first bag
     * @param t2 second bag
     * @return the adaptive MMD, or {@code 0.0} if the median distance is zero
     * @throws IllegalArgumentException if no distance measure is assigned or either bag is empty
     * @throws RuntimeException if the median distance is negative
     */
    public double distanceAdaptive(T t, T t2) {
        if (this.dm == null) {
            throw new IllegalArgumentException("Must first assign distance measure for calculating adaptive MMD.");
        }
        int n = t.size();
        int m = t2.size();
        requireNonEmpty(n, m);
        List<S> xs = t.instances();
        List<S> ys = t2.instances();

        double[][] dxx = selfDistances(xs);
        double[][] dxy = crossDistances(xs, ys);
        double[][] dyy = selfDistances(ys);

        double median = median(flatten(dxx, dxy, dyy));
        // NOTE(review): debug output kept for backward compatibility; consider a logger instead.
        System.out.println("Median distance for " + t.getPrimaryKey() + " and " + t2.getPrimaryKey() + ": " + median);
        if (median < 0.0d) {
            throw new RuntimeException("Distances must be positive (median is " + median + ")");
        }
        if (median == 0.0d) {
            return 0.0d;
        }
        double sxx = expSumSymmetric(dxx, median);
        double sxy = expSum(dxy, median);
        double syy = expSumSymmetric(dyy, median);
        return combine(sxx, sxy, syy, n, m);
    }

    /** Throws if no kernel has been assigned. */
    private void requireKernel() {
        if (this.k == null) {
            throw new IllegalArgumentException("Must first assign kernel function for calculating MMD.");
        }
    }

    /** Throws if either bag size is not positive, since the MMD normalizers would divide by zero. */
    private static void requireNonEmpty(int n, int m) {
        if (n <= 0 || m <= 0) {
            throw new IllegalArgumentException("Cannot calculate MMD for an empty multi-instance object.");
        }
    }

    /** Sums the kernel over every ordered pair (a, b) with a from {@code as} and b from {@code bs}. */
    private double kernelSum(List<S> as, List<S> bs) {
        double sum = 0.0d;
        for (S a : as) {
            for (S b : bs) {
                sum += this.k.kernel(a, b);
            }
        }
        return sum;
    }

    /**
     * Combines the three kernel sums into the MMD. Sizes are widened to {@code double} so that
     * neither integer division nor int overflow of {@code n*n} can occur.
     */
    private static double combine(double sxx, double sxy, double syy, double n, double m) {
        double squared = sxx / (n * n) - 2.0d * sxy / (n * m) + syy / (m * m);
        return Math.sqrt(Math.max(0.0d, squared));
    }

    /**
     * Returns the distance matrix within one bag. Only the upper triangle is computed and then
     * mirrored, which assumes the instance distance is symmetric.
     */
    private double[][] selfDistances(List<S> xs) {
        int size = xs.size();
        double[][] dist = new double[size][size];
        for (int i = 0; i < size; i++) {
            S x = xs.get(i);
            dist[i][i] = this.dm.distance(x, x);
            for (int j = i + 1; j < size; j++) {
                dist[i][j] = this.dm.distance(x, xs.get(j));
                dist[j][i] = dist[i][j];
            }
        }
        return dist;
    }

    /** Returns the distance matrix between the instances of two bags (rows: {@code xs}). */
    private double[][] crossDistances(List<S> xs, List<S> ys) {
        double[][] dist = new double[xs.size()][ys.size()];
        for (int i = 0; i < xs.size(); i++) {
            S x = xs.get(i);
            for (int j = 0; j < ys.size(); j++) {
                dist[i][j] = this.dm.distance(x, ys.get(j));
            }
        }
        return dist;
    }

    /** Concatenates all rows of all given matrices into a single new array. */
    private static double[] flatten(double[][]... blocks) {
        int total = 0;
        for (double[][] block : blocks) {
            for (double[] row : block) {
                total += row.length;
            }
        }
        double[] out = new double[total];
        int pos = 0;
        for (double[][] block : blocks) {
            for (double[] row : block) {
                System.arraycopy(row, 0, out, pos, row.length);
                pos += row.length;
            }
        }
        return out;
    }

    /**
     * Returns the median of a non-empty array, sorting it in place. For even lengths this is the
     * mean of the two middle elements.
     */
    private static double median(double[] values) {
        Arrays.sort(values);
        int mid = values.length / 2;
        return values.length % 2 == 0 ? (values[mid - 1] + values[mid]) / 2.0d : values[mid];
    }

    /**
     * Sums {@code exp(-d / scale)} over a symmetric square matrix. Each off-diagonal pair is
     * evaluated once and counted twice.
     */
    private static double expSumSymmetric(double[][] dist, double scale) {
        double sum = 0.0d;
        for (int i = 0; i < dist.length; i++) {
            sum += Math.exp(-dist[i][i] / scale);
            for (int j = i + 1; j < dist.length; j++) {
                sum += 2.0d * Math.exp(-dist[i][j] / scale);
            }
        }
        return sum;
    }

    /** Sums {@code exp(-d / scale)} over every entry of the matrix. */
    private static double expSum(double[][] dist, double scale) {
        double sum = 0.0d;
        for (double[] row : dist) {
            for (double value : row) {
                sum += Math.exp(-value / scale);
            }
        }
        return sum;
    }

    /**
     * Returns the instance distance measure.
     *
     * @return the instance distance measure, possibly {@code null}
     */
    @Override
    public DistanceMeasure<S> getInstanceDistance() {
        return this.dm;
    }

    /**
     * Returns the name of this measure.
     *
     * @return {@code "MMD"}
     */
    @Override
    public String getName() {
        return "MMD";
    }

    /**
     * Returns whether {@link #distance} uses the pre-normalized variant.
     *
     * @return whether {@link #distance} uses the pre-normalized variant
     */
    public final boolean isNormalized() {
        return this.isNormalized;
    }

    /**
     * Selects whether {@link #distance} delegates to {@link #distancePreNormalized}.
     *
     * @param z {@code true} to use the pre-normalized variant
     */
    public final void setNormalized(boolean z) {
        this.isNormalized = z;
    }

    /**
     * Returns the instance kernel.
     *
     * @return the instance kernel, possibly {@code null}
     */
    public final Kernel<S> getKernel() {
        return this.k;
    }

    /**
     * Assigns the instance kernel used by {@link #distance}.
     *
     * @param kernel the instance kernel
     */
    public final void setKernel(Kernel<S> kernel) {
        this.k = kernel;
    }
}
