package de.lmu.ifi.dbs.dm.distance.mi;

import de.lmu.ifi.dbs.dm.DistanceMeasure;
import de.lmu.ifi.dbs.dm.Kernel;
import de.lmu.ifi.dbs.dm.data.DataObject;
import de.lmu.ifi.dbs.dm.data.MultiInstanceObject;
import de.lmu.ifi.dbs.dm.database.Database;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:de/lmu/ifi/dbs/dm/distance/mi/ConvolutionDist.class */
/**
 * Multi-instance distance derived from a convolution (set) kernel.
 *
 * <p>For two multi-instance objects {@code A} and {@code B} the cross term is
 * {@code K(A, B) = sum_{a in A} sum_{b in B} k(b, a) * w(a) * w(b)}, where {@code k} is the
 * instance-level {@link Kernel} and {@code w} is each instance's weight. {@link #distance} returns
 * {@code 1 - K(A, B)^2 / (N(A) * N(B))}, where {@code N(X)} is a per-object normalization value.
 *
 * <p>{@code N(X)} is read from one of two sources:
 * <ul>
 *   <li>the object's own weight, which the static {@code calcNormalization*} /
 *       {@link #normalizeMiObject} helpers overwrite with the weighted self-convolution of the
 *       object's instances; or</li>
 *   <li>after {@link #calcInternalNormalization} has been called, an internal table keyed by
 *       primary key (object weights are then ignored).</li>
 * </ul>
 */
public class ConvolutionDist<T extends DataObject> implements MIDistanceMeasure<MultiInstanceObject<T>> {
    private static final long serialVersionUID = 1084893141895086084L;

    /** Instance-level kernel used for the cross term in {@link #distance}. */
    private Kernel<T> kernel;

    /** Optional instance-level distance, only exposed via {@link #getInstanceDistance()}; may be null. */
    private final DistanceMeasure<T> dm;

    /**
     * Normalization values by primary key, or {@code null} to use the objects' own weights.
     * Populated by {@link #calcInternalNormalization}.
     */
    private Map<String, Double> normalizationTable;

    /**
     * Creates a convolution distance without an instance-level distance measure.
     *
     * @param kernel instance-level kernel
     */
    public ConvolutionDist(Kernel<T> kernel) {
        this(kernel, null);
    }

    /**
     * Creates a convolution distance.
     *
     * @param kernel instance-level kernel
     * @param distanceMeasure instance-level distance returned by {@link #getInstanceDistance()};
     *     may be {@code null}
     */
    public ConvolutionDist(Kernel<T> kernel, DistanceMeasure<T> distanceMeasure) {
        this.kernel = kernel;
        this.dm = distanceMeasure;
    }

    /**
     * Returns {@code 1 - K(A, B)^2 / (N(A) * N(B))}. See the class documentation.
     *
     * <p>There is no guard against a zero normalization value (e.g. objects with no instances),
     * which yields a NaN or infinite result.
     *
     * @throws IllegalArgumentException if an internal normalization table is in use and one of the
     *     objects has no entry in it
     */
    @Override // de.lmu.ifi.dbs.dm.DistanceMeasure
    public double distance(MultiInstanceObject<T> multiInstanceObject, MultiInstanceObject<T> multiInstanceObject2) {
        double cross = 0.0d;
        for (T t : multiInstanceObject.instances()) {
            for (T t2 : multiInstanceObject2.instances()) {
                cross += this.kernel.kernel(t2, t) * t.getWeight() * t2.getWeight();
            }
        }
        double normalization;
        if (this.normalizationTable == null) {
            normalization = multiInstanceObject.getWeight() * multiInstanceObject2.getWeight();
        } else {
            Double n1 = this.normalizationTable.get(multiInstanceObject.getPrimaryKey());
            Double n2 = this.normalizationTable.get(multiInstanceObject2.getPrimaryKey());
            if (n1 == null || n2 == null) {
                throw new IllegalArgumentException("No pre-calculated weights for object " + (n1 == null ? 1 : 2));
            }
            normalization = n1.doubleValue() * n2.doubleValue();
        }
        return 1.0d - ((cross * cross) / normalization);
    }

    /**
     * Sets each object's weight to its weighted self-convolution.
     *
     * <p>Evaluates {@code k(t2, t)} on every ordered pair of instances, so it does not rely on the
     * kernel being symmetric (roughly twice the kernel calls of {@link #calcNormalizationFaster}).
     *
     * @param database objects to normalize; their weights are overwritten
     * @param kernel instance-level kernel
     */
    public static <T extends DataObject> void calcNormalization(Database<MultiInstanceObject<T>> database, Kernel<T> kernel) {
        Iterator<MultiInstanceObject<T>> objectIterator = database.objectIterator();
        while (objectIterator.hasNext()) {
            MultiInstanceObject<T> next = objectIterator.next();
            double d = 0.0d;
            for (T t : next.instances()) {
                for (T t2 : next.instances()) {
                    d += kernel.kernel(t2, t) * t2.getWeight() * t.getWeight();
                }
            }
            next.setWeight(d);
        }
    }

    /**
     * Like {@link #calcNormalization(Database, Kernel)}, but visits each unordered instance pair
     * once and doubles off-diagonal terms. This assumes the kernel is symmetric.
     *
     * @param database objects to normalize; their weights are overwritten
     * @param kernel instance-level kernel (assumed symmetric)
     */
    public static <T extends DataObject> void calcNormalizationFaster(Database<MultiInstanceObject<T>> database, Kernel<T> kernel) {
        Iterator<MultiInstanceObject<T>> objectIterator = database.objectIterator();
        while (objectIterator.hasNext()) {
            normalizeMiObject(objectIterator.next(), kernel);
        }
    }

    /**
     * Sets one object's weight to its weighted self-convolution, assuming a symmetric kernel.
     *
     * @param multiInstanceObject object whose weight is overwritten
     * @param kernel instance-level kernel (assumed symmetric)
     */
    public static <T extends DataObject> void normalizeMiObject(MultiInstanceObject<T> multiInstanceObject, Kernel<T> kernel) {
        multiInstanceObject.setWeight(selfConvolution(multiInstanceObject, kernel));
    }

    /**
     * Sets the weight of every object whose primary key is not in {@code set} to its weighted
     * self-convolution (symmetric-kernel variant).
     *
     * @param database objects to normalize
     * @param set primary keys to skip; if {@code null}, every object is normalized via
     *     {@link #calcNormalizationFaster}
     * @param kernel instance-level kernel (assumed symmetric)
     */
    public static <T extends DataObject> void calcNormalization(Database<MultiInstanceObject<T>> database, Set<String> set, Kernel<T> kernel) {
        if (set == null) {
            calcNormalizationFaster(database, kernel);
            return;
        }
        Iterator<MultiInstanceObject<T>> objectIterator = database.objectIterator();
        while (objectIterator.hasNext()) {
            MultiInstanceObject<T> next = objectIterator.next();
            if (!set.contains(next.getPrimaryKey())) {
                next.setWeight(selfConvolution(next, kernel));
            }
        }
    }

    /**
     * Builds the internal normalization table. After this call, {@link #distance} reads
     * normalization values from the table instead of from object weights. Object weights are not
     * modified.
     *
     * <p>NOTE(review): the table is replaced on every call. Objects whose primary key is in
     * {@code set} therefore get no entry, and {@link #distance} throws
     * {@link IllegalArgumentException} for them. Confirm this is the intended meaning of
     * {@code set}.
     *
     * @param database objects to compute normalization values for
     * @param set primary keys to skip, or {@code null} to include every object
     * @param kernel instance-level kernel (assumed symmetric); it need not be the same instance as
     *     {@link #getKernel()}
     */
    public void calcInternalNormalization(Database<MultiInstanceObject<T>> database, Set<String> set, Kernel<T> kernel) {
        this.normalizationTable = new HashMap<String, Double>();
        Iterator<MultiInstanceObject<T>> objectIterator = database.objectIterator();
        while (objectIterator.hasNext()) {
            MultiInstanceObject<T> next = objectIterator.next();
            if (set == null || !set.contains(next.getPrimaryKey())) {
                this.normalizationTable.put(next.getPrimaryKey(), Double.valueOf(selfConvolution(next, kernel)));
            }
        }
    }

    /**
     * Computes {@code sum_i k(x_i, x_i) w_i^2 + sum_{i<j} 2 k(x_i, x_j) w_i w_j} over the object's
     * instances. This equals the full weighted self-convolution when {@code kernel} is symmetric.
     * It is shared by all symmetric-kernel normalization routines so they produce identical values.
     */
    private static <T extends DataObject> double selfConvolution(MultiInstanceObject<T> multiInstanceObject, Kernel<T> kernel) {
        double d = 0.0d;
        List<T> instances = multiInstanceObject.instances();
        for (int i = 0; i < instances.size(); i++) {
            T t = instances.get(i);
            d += kernel.kernel(t, t) * t.getWeight() * t.getWeight();
            for (int i2 = i + 1; i2 < instances.size(); i2++) {
                T t2 = instances.get(i2);
                // Off-diagonal pairs are visited once and counted twice (symmetric kernel).
                d += 2.0d * kernel.kernel(t, t2) * t.getWeight() * t2.getWeight();
            }
        }
        return d;
    }

    @Override // de.lmu.ifi.dbs.dm.DistanceMeasure
    public String getName() {
        return "ConvolutionDist";
    }

    /** Returns the instance-level distance passed at construction; may be {@code null}. */
    @Override // de.lmu.ifi.dbs.dm.distance.mi.MIDistanceMeasure
    public DistanceMeasure<T> getInstanceDistance() {
        return this.dm;
    }

    /** Returns the instance-level kernel used by {@link #distance}. */
    public final Kernel<T> getKernel() {
        return this.kernel;
    }

    /** Replaces the instance-level kernel used by {@link #distance}. */
    public final void setKernel(Kernel<T> kernel) {
        this.kernel = kernel;
    }
}
