/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.trees.ht;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import weka.classifiers.trees.ht.ConditionalSufficientStats;
import weka.classifiers.trees.ht.SplitCandidate;
import weka.classifiers.trees.ht.SplitMetric;
import weka.classifiers.trees.ht.UnivariateNumericBinarySplit;
import weka.classifiers.trees.ht.WeightMass;
import weka.core.Statistics;
import weka.core.Utils;
import weka.estimators.UnivariateNormalEstimator;

/**
 * Conditional sufficient statistics for a numeric attribute, used to evaluate
 * candidate binary splits on that attribute.
 *
 * <p>For every class value it keeps a {@link GaussianEstimator} fitted to the
 * attribute values seen with that class, together with the minimum and maximum
 * attribute value observed for the class. Candidate split points are spaced at
 * equal widths strictly inside the overall observed range, and the class weight
 * falling on each side of a split point is estimated from the per-class
 * Gaussians.
 */
public class GaussianConditionalSufficientStats
extends ConditionalSufficientStats
implements Serializable {
    private static final long serialVersionUID = -1527915607201784762L;

    /**
     * Normalising factor of the normal density, sqrt(2*pi). Kept here (not in
     * the inner class) because a non-static inner class cannot declare a
     * non-constant static field.
     */
    private static final double SQRT_2PI = Math.sqrt(2.0 * Math.PI);

    /** Smallest non-missing attribute value observed so far, keyed by class value. */
    protected Map<String, Double> m_minValObservedPerClass = new HashMap<String, Double>();

    /** Largest non-missing attribute value observed so far, keyed by class value. */
    protected Map<String, Double> m_maxValObservedPerClass = new HashMap<String, Double>();

    /** Number of equally spaced candidate split points to generate. */
    protected int m_numBins = 10;

    /**
     * Sets the number of candidate split points generated by
     * {@link #getSplitPointCandidates()}. Values below 1 yield no candidates.
     *
     * @param b the number of candidate split points
     */
    public void setNumBins(int b) {
        m_numBins = b;
    }

    /**
     * @return the number of candidate split points generated per evaluation
     */
    public int getNumBins() {
        return m_numBins;
    }

    /**
     * Incorporates one weighted observation of this attribute for the given
     * class. Missing attribute values are ignored.
     *
     * @param attVal the attribute value
     * @param classVal the class value the observation belongs to
     * @param weight the weight of the observation
     */
    @Override
    public void update(double attVal, String classVal, double weight) {
        if (Utils.isMissingValue(attVal)) {
            return;
        }

        GaussianEstimator norm = (GaussianEstimator) m_classLookup.get(classVal);
        if (norm == null) {
            // First observation for this class: it is both the minimum and maximum.
            norm = new GaussianEstimator();
            m_classLookup.put(classVal, norm);
            m_minValObservedPerClass.put(classVal, attVal);
            m_maxValObservedPerClass.put(classVal, attVal);
        } else {
            if (attVal < m_minValObservedPerClass.get(classVal)) {
                m_minValObservedPerClass.put(classVal, attVal);
            }
            if (attVal > m_maxValObservedPerClass.get(classVal)) {
                m_maxValObservedPerClass.put(classVal, attVal);
            }
        }
        norm.addValue(attVal, weight);
    }

    /**
     * Returns the Gaussian density of {@code attVal} under the estimator kept
     * for {@code classVal}.
     *
     * @param attVal the attribute value
     * @param classVal the class value
     * @return the estimated density, or 0 if nothing has been seen for the class
     */
    @Override
    public double probabilityOfAttValConditionedOnClass(double attVal, String classVal) {
        GaussianEstimator norm = (GaussianEstimator) m_classLookup.get(classVal);
        return norm == null ? 0.0 : norm.probabilityDensity(attVal);
    }

    /**
     * Generates {@code m_numBins} equally spaced split points between the
     * smallest and largest attribute value observed over all classes. Only
     * points strictly inside that range are kept.
     *
     * @return the candidate split points in ascending order (possibly empty)
     */
    protected TreeSet<Double> getSplitPointCandidates() {
        TreeSet<Double> splits = new TreeSet<Double>();

        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (String classVal : m_classLookup.keySet()) {
            Double classMin = m_minValObservedPerClass.get(classVal);
            if (classMin == null) {
                continue;
            }
            if (classMin < min) {
                min = classMin;
            }
            double classMax = m_maxValObservedPerClass.get(classVal);
            if (classMax > max) {
                max = classMax;
            }
        }

        if (min == Double.POSITIVE_INFINITY) {
            // No usable observations yet.
            return splits;
        }

        // m_numBins interior points divide [min, max] into m_numBins + 1 equal intervals.
        double binWidth = (max - min) / (double) (m_numBins + 1);
        for (int i = 1; i <= m_numBins; i++) {
            double split = min + binWidth * (double) i;
            if (split > min && split < max) {
                splits.add(split);
            }
        }
        return splits;
    }

    /**
     * Estimates the per-class weight on each side of a binary split at
     * {@code splitVal}. Values {@code <= splitVal} go to the left branch.
     *
     * <p>A class whose observed range lies entirely on one side of the split
     * contributes its whole weight to that side. Otherwise its weight is
     * divided using its Gaussian estimator.
     *
     * @param splitVal the split point
     * @return a two-element list: the left-branch and right-branch class distributions
     */
    protected List<Map<String, WeightMass>> classDistsAfterSplit(double splitVal) {
        Map<String, WeightMass> lhsDist = new HashMap<String, WeightMass>();
        Map<String, WeightMass> rhsDist = new HashMap<String, WeightMass>();

        for (Map.Entry<String, ?> entry : m_classLookup.entrySet()) {
            String classVal = entry.getKey();
            GaussianEstimator attEst = (GaussianEstimator) entry.getValue();
            if (attEst == null) {
                continue;
            }

            if (splitVal < m_minValObservedPerClass.get(classVal)) {
                addWeight(rhsDist, classVal, attEst.getSumOfWeights());
            } else if (splitVal > m_maxValObservedPerClass.get(classVal)) {
                addWeight(lhsDist, classVal, attEst.getSumOfWeights());
            } else {
                double[] weights = attEst.weightLessThanEqualAndGreaterThan(splitVal);
                // "Equal" weight goes left, matching the <= semantics of the split.
                addWeight(lhsDist, classVal, weights[0] + weights[1]);
                addWeight(rhsDist, classVal, weights[2]);
            }
        }

        List<Map<String, WeightMass>> dists = new ArrayList<Map<String, WeightMass>>();
        dists.add(lhsDist);
        dists.add(rhsDist);
        return dists;
    }

    /**
     * Adds {@code weight} to the mass recorded for {@code classVal} in
     * {@code dist}, creating the entry first if it is absent.
     */
    private static void addWeight(Map<String, WeightMass> dist, String classVal, double weight) {
        WeightMass mass = dist.get(classVal);
        if (mass == null) {
            mass = new WeightMass();
            dist.put(classVal, mass);
        }
        mass.m_weight += weight;
    }

    /**
     * Evaluates every candidate split point with {@code splitMetric} and
     * returns the best one.
     *
     * @param splitMetric the metric used to score a split
     * @param preSplitDist the class distribution before splitting
     * @param attName the name of this attribute
     * @return the highest-scoring candidate, or {@code null} if there are no
     *         candidate split points
     */
    @Override
    public SplitCandidate bestSplit(SplitMetric splitMetric, Map<String, WeightMass> preSplitDist, String attName) {
        SplitCandidate best = null;
        for (Double splitVal : getSplitPointCandidates()) {
            List<Map<String, WeightMass>> postSplitDists = classDistsAfterSplit(splitVal);
            double splitMerit = splitMetric.evaluateSplit(preSplitDist, postSplitDists);
            // Strictly greater: among equal merits the smallest split point
            // (first in TreeSet order) is kept.
            if (best == null || splitMerit > best.m_splitMerit) {
                UnivariateNumericBinarySplit split = new UnivariateNumericBinarySplit(attName, splitVal);
                best = new SplitCandidate(split, postSplitDists, splitMerit);
            }
        }
        return best;
    }

    /**
     * Normal-distribution estimator for one class's attribute values. It adds
     * the density and the cumulative-weight queries used for split evaluation.
     */
    protected class GaussianEstimator
    extends UnivariateNormalEstimator
    implements Serializable {
        private static final long serialVersionUID = 4756032800685001315L;

        protected GaussianEstimator() {
        }

        /**
         * @return the total weight of the values added so far
         */
        public double getSumOfWeights() {
            return m_SumOfWeights;
        }

        /**
         * Returns the normal density at {@code value} using the current mean
         * and variance.
         *
         * <p>The normalisation is 1 / (sqrt(2*pi) * sigma). It deliberately does
         * not use the inherited {@code CONST}, which is the constant of the
         * parent's log-density code. NOTE(review): believed to be log(2*pi); confirm.
         *
         * @param value the point at which to evaluate the density
         * @return the density, 0 if no weight has been added, or 1/0 for a
         *         zero-variance estimator depending on whether {@code value}
         *         equals the mean
         */
        public double probabilityDensity(double value) {
            updateMeanAndVariance();
            if (!(m_SumOfWeights > 0.0)) {
                return 0.0;
            }
            double stdDev = Math.sqrt(m_Variance);
            if (stdDev > 0.0) {
                double diff = value - m_Mean;
                return 1.0 / (SQRT_2PI * stdDev) * Math.exp(-(diff * diff / (2.0 * m_Variance)));
            }
            return value == m_Mean ? 1.0 : 0.0;
        }

        /**
         * Splits the accumulated weight into the parts estimated to lie below,
         * at, and above {@code value}.
         *
         * @param value the split point
         * @return {@code {lessThan, equalTo, greaterThan}} weights
         */
        public double[] weightLessThanEqualAndGreaterThan(double value) {
            // probabilityDensity() calls updateMeanAndVariance(). It must run
            // before m_Variance is read, or stdDev would come from a stale
            // cached variance while m_Mean is fresh.
            double equalW = probabilityDensity(value) * m_SumOfWeights;
            double stdDev = Math.sqrt(m_Variance);

            double lessW;
            if (stdDev > 0.0) {
                lessW = Statistics.normalProbability((value - m_Mean) / stdDev) * m_SumOfWeights - equalW;
            } else if (value > m_Mean) {
                // Point mass at m_Mean lies entirely below value. This agrees
                // with the sigma -> 0 limit of the branch above.
                lessW = m_SumOfWeights - equalW;
            } else {
                lessW = 0.0;
            }
            double greaterW = m_SumOfWeights - equalW - lessW;
            return new double[]{lessW, equalW, greaterW};
        }
    }
}

