package cc.mallet.fst.semi_supervised.constraints;

import cc.mallet.fst.SumLattice;
import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import gnu.trove.TIntArrayList;
import gnu.trove.TIntObjectHashMap;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Iterator;

/* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/fst/semi_supervised/constraints/OneLabelL2RangeGEConstraints.class */
/**
 * Generalized-expectation constraints that each involve the label of a single state
 * ({@link #isOneStateConstraint()} returns {@code true}).
 *
 * <p>Constraints are keyed by input feature index. For each constrained label of a feature, the
 * expected count of that label at positions where the feature fires is normalized by the number
 * of such positions ({@code count}). It is then compared with a target interval
 * {@code [lower, upper]}. Values outside the interval are penalized by {@code weight} times the
 * squared distance to the violated bound (an L2 range penalty). {@link #getValue()} returns the
 * negated sum of these penalties.
 */
public class OneLabelL2RangeGEConstraints implements GEConstraint {
    /** Per-feature constraints, keyed by feature index. */
    protected TIntObjectHashMap<OneLabelL2IndGEConstraint> constraints;
    /** Maps lattice states to label indices. */
    protected StateLabelMap map;
    /**
     * Constrained feature indices active in the feature vector most recently passed to
     * {@link #preProcess(FeatureVector)}.
     */
    protected TIntArrayList cache;

    /**
     * Range constraints attached to one input feature. For each constrained label the constraint
     * stores a target interval {@code [lower, upper]} on the normalized expectation and a penalty
     * weight. All per-label data lives in parallel lists indexed by the label's "slot", and
     * {@link #labelMap} gives the slot of each label.
     */
    protected class OneLabelL2IndGEConstraint {
        /** Expected counts per slot; allocated by {@code zeroExpectations()}, null before that. */
        protected double[] expectation;
        /** Lower bound of the target interval, per slot. */
        protected ArrayList<Double> lower = new ArrayList<>();
        /** Upper bound of the target interval, per slot. */
        protected ArrayList<Double> upper = new ArrayList<>();
        /** Penalty weight, per slot. */
        protected ArrayList<Double> weights = new ArrayList<>();
        /** Label index -> slot in the per-slot lists and in {@link #expectation}. */
        protected HashMap<Integer, Integer> labelMap = new HashMap<>();
        /** Number of slots allocated so far (also the next free slot). */
        protected int index = 0;
        /** Number of positions where this feature fired, accumulated by {@code preProcess(InstanceList)}. */
        protected double count = 0.0d;

        public OneLabelL2IndGEConstraint() {
        }

        /**
         * Registers a range constraint for a label.
         *
         * <p>Adding the same label twice re-points it to a new slot; the old slot stays allocated
         * and still counts towards {@link #getNumConstrainedLabels()}.
         *
         * @param label label index being constrained
         * @param lowerBound lower bound on the normalized expectation
         * @param upperBound upper bound on the normalized expectation
         * @param weight penalty weight applied to the squared violation
         */
        public void add(int label, double lowerBound, double upperBound, double weight) {
            this.lower.add(lowerBound);
            this.upper.add(upperBound);
            this.weights.add(weight);
            this.labelMap.put(label, this.index);
            this.index++;
        }

        /** Returns the slot assigned to {@code label}, or {@code -1} if the label is not constrained. */
        private int slotOf(int label) {
            Integer slot = this.labelMap.get(label);
            return slot == null ? -1 : slot;
        }

        /**
         * Adds {@code amount} to the expectation of {@code label}; does nothing for labels this
         * feature does not constrain. Requires {@link #expectation} to have been allocated.
         */
        public void incrementExpectation(int label, double amount) {
            int slot = slotOf(label);
            if (slot >= 0) {
                this.expectation[slot] += amount;
            }
        }

        /**
         * Returns the (non-negated) penalty for {@code label}. This is {@code weight * (bound - e)^2},
         * where {@code e = expectation / count} and {@code bound} is whichever interval end is
         * violated. The result is 0 when {@code e} lies inside the interval or the label is not
         * constrained.
         */
        public double getValueContribution(int label) {
            int slot = slotOf(label);
            if (slot < 0) {
                return 0.0d;
            }
            assert this.count != 0.0d;
            double normalized = this.expectation[slot] / this.count;
            double lo = this.lower.get(slot);
            if (normalized < lo) {
                return this.weights.get(slot) * Math.pow(lo - normalized, 2.0d);
            }
            double hi = this.upper.get(slot);
            if (normalized > hi) {
                return this.weights.get(slot) * Math.pow(hi - normalized, 2.0d);
            }
            return 0.0d;
        }

        /** Returns the number of allocated slots (constrained-label entries). */
        public int getNumConstrainedLabels() {
            return this.index;
        }

        /**
         * Returns {@code 2 * weight * (bound / count - expectation / count^2)} for the violated
         * bound, or 0 if the label is not constrained or its normalized expectation is in range.
         * This is the derivative of {@code -weight * (bound - expectation / count)^2}, the term
         * {@code getValue()} subtracts, with respect to {@code expectation[slot]}.
         */
        public double getGradientContribution(int label) {
            int slot = slotOf(label);
            if (slot < 0) {
                return 0.0d;
            }
            assert this.count != 0.0d;
            double normalized = this.expectation[slot] / this.count;
            double squaredCount = this.count * this.count;
            double lo = this.lower.get(slot);
            if (normalized < lo) {
                return 2.0d * this.weights.get(slot)
                        * ((lo / this.count) - (this.expectation[slot] / squaredCount));
            }
            double hi = this.upper.get(slot);
            if (normalized > hi) {
                return 2.0d * this.weights.get(slot)
                        * ((hi / this.count) - (this.expectation[slot] / squaredCount));
            }
            return 0.0d;
        }
    }

    public OneLabelL2RangeGEConstraints() {
        this.constraints = new TIntObjectHashMap<>();
        this.cache = new TIntArrayList();
    }

    /**
     * Creates an instance over existing constraints and state map (used by {@link #copy()}).
     * The given map is shared, not copied; only the feature cache is new.
     */
    protected OneLabelL2RangeGEConstraints(TIntObjectHashMap<OneLabelL2IndGEConstraint> tIntObjectHashMap, StateLabelMap stateLabelMap) {
        this.constraints = tIntObjectHashMap;
        this.map = stateLabelMap;
        this.cache = new TIntArrayList();
    }

    /**
     * Adds a range constraint on {@code labelIndex} for positions where {@code featureIndex} fires.
     *
     * @param featureIndex input feature index the constraint is attached to
     * @param labelIndex label being constrained
     * @param lower lower bound on the normalized expectation
     * @param upper upper bound on the normalized expectation
     * @param weight penalty weight for violations
     */
    public void addConstraint(int featureIndex, int labelIndex, double lower, double upper, double weight) {
        OneLabelL2IndGEConstraint constraint = this.constraints.get(featureIndex);
        if (constraint == null) {
            constraint = new OneLabelL2IndGEConstraint();
            this.constraints.put(featureIndex, constraint);
        }
        constraint.add(labelIndex, lower, upper, weight);
    }

    @Override // cc.mallet.fst.semi_supervised.constraints.GEConstraint
    public boolean isOneStateConstraint() {
        return true;
    }

    @Override // cc.mallet.fst.semi_supervised.constraints.GEConstraint
    public void setStateLabelMap(StateLabelMap stateLabelMap) {
        this.map = stateLabelMap;
    }

    /**
     * Clears {@code out} and fills it with the constrained feature indices active in
     * {@code featureVector}, in location order. It then appends
     * {@code featureVector.getAlphabet().size()} if a constraint is registered under that index.
     */
    private void collectConstrainedFeatures(FeatureVector featureVector, TIntArrayList out) {
        out.resetQuick();
        for (int loc = 0; loc < featureVector.numLocations(); loc++) {
            int featureIndex = featureVector.indexAtLocation(loc);
            if (this.constraints.containsKey(featureIndex)) {
                out.add(featureIndex);
            }
        }
        // The index one past the alphabet is checked for every vector regardless of its
        // locations, so it acts as an always-on feature.
        // NOTE(review): presumably a default/bias feature — confirm with the constraint builder.
        int extraIndex = featureVector.getAlphabet().size();
        if (this.constraints.containsKey(extraIndex)) {
            out.add(extraIndex);
        }
    }

    /** Caches the constrained features of {@code featureVector} for later calls to getCompositeConstraintFeatureValue. */
    @Override // cc.mallet.fst.semi_supervised.constraints.GEConstraint
    public void preProcess(FeatureVector featureVector) {
        collectConstrainedFeatures(featureVector, this.cache);
    }

    /**
     * Counts, per constrained feature, the positions in {@code instanceList} where it fires.
     * Counts accumulate across calls and are not reset.
     *
     * @return bit set with bit {@code i} set when instance {@code i} has at least one constrained feature
     */
    @Override // cc.mallet.fst.semi_supervised.constraints.GEConstraint
    public BitSet preProcess(InstanceList instanceList) {
        BitSet constrainedInstances = new BitSet(instanceList.size());
        TIntArrayList active = new TIntArrayList();
        int instanceIndex = 0;
        for (Instance instance : instanceList) {
            FeatureVectorSequence sequence = (FeatureVectorSequence) instance.getData();
            for (int position = 0; position < sequence.size(); position++) {
                collectConstrainedFeatures(sequence.get(position), active);
                if (active.size() > 0) {
                    constrainedInstances.set(instanceIndex);
                }
                for (int k = 0; k < active.size(); k++) {
                    this.constraints.get(active.getQuick(k)).count += 1.0d;
                }
            }
            instanceIndex++;
        }
        return constrainedInstances;
    }

    /**
     * Sums the gradient contributions for the label of state {@code destIndex}. The sum runs over
     * the constrained features cached by the most recent {@link #preProcess(FeatureVector)} call.
     * The feature vector and the other two indices (presumably input position and source state)
     * are not used by this one-state constraint.
     */
    @Override // cc.mallet.fst.semi_supervised.constraints.GEConstraint
    public double getCompositeConstraintFeatureValue(FeatureVector input, int inputPosition, int srcIndex, int destIndex) {
        int labelIndex = this.map.getLabelIndex(destIndex);
        double total = 0.0d;
        for (int k = 0; k < this.cache.size(); k++) {
            total += this.constraints.get(this.cache.getQuick(k)).getGradientContribution(labelIndex);
        }
        return total;
    }

    /**
     * Returns the negated sum of the penalties of all constraints whose feature was observed
     * ({@code count > 0}). The result is 0 when every normalized expectation lies within its range.
     */
    @Override // cc.mallet.fst.semi_supervised.constraints.GEConstraint
    public double getValue() {
        double value = 0.0d;
        for (int featureIndex : this.constraints.keys()) {
            OneLabelL2IndGEConstraint constraint = this.constraints.get(featureIndex);
            // Unobserved features would divide by a zero count, so they are skipped.
            if (constraint.count > 0.0d) {
                for (int label = 0; label < this.map.getNumLabels(); label++) {
                    value -= constraint.getValueContribution(label);
                }
            }
        }
        assert !(Double.isNaN(value) || Double.isInfinite(value));
        return value;
    }

    /** Allocates fresh, zeroed expectation arrays for every constraint. */
    @Override // cc.mallet.fst.semi_supervised.constraints.GEConstraint
    public void zeroExpectations() {
        for (int featureIndex : this.constraints.keys()) {
            OneLabelL2IndGEConstraint constraint = this.constraints.get(featureIndex);
            constraint.expectation = new double[constraint.getNumConstrainedLabels()];
        }
    }

    /**
     * Adds, for every non-null lattice and every input position, {@code exp(gamma)} of each state
     * to the expectation of that state's label. The update goes to every constrained feature
     * active at that position. Null entries in {@code arrayList} are skipped.
     */
    @Override // cc.mallet.fst.semi_supervised.constraints.GEConstraint
    public void computeExpectations(ArrayList<SumLattice> arrayList) {
        TIntArrayList active = new TIntArrayList();
        for (SumLattice lattice : arrayList) {
            if (lattice == null) {
                continue;
            }
            FeatureVectorSequence sequence = (FeatureVectorSequence) lattice.getInput();
            double[][] gammas = lattice.getGammas();
            for (int position = 0; position < sequence.size(); position++) {
                collectConstrainedFeatures(sequence.getFeatureVector(position), active);
                for (int state = 0; state < this.map.getNumStates(); state++) {
                    int labelIndex = this.map.getLabelIndex(state);
                    // NOTE(review): -2 marks states with no label to constrain (inlined constant,
                    // presumably StateLabelMap's start-state marker) — confirm.
                    if (labelIndex == -2) {
                        continue;
                    }
                    // Gamma row position + 1 is used for input position `position`
                    // (presumably row 0 is the lattice's start) — confirm.
                    double stateProb = Math.exp(gammas[position + 1][state]);
                    for (int k = 0; k < active.size(); k++) {
                        this.constraints.get(active.getQuick(k)).incrementExpectation(labelIndex, stateProb);
                    }
                }
            }
        }
    }

    /**
     * Returns a copy that shares this instance's constraints (including their counts and
     * expectations) and state map, but has its own feature cache. This allows independent
     * {@link #preProcess(FeatureVector)} / getCompositeConstraintFeatureValue sequences.
     */
    @Override // cc.mallet.fst.semi_supervised.constraints.GEConstraint
    public GEConstraint copy() {
        return new OneLabelL2RangeGEConstraints(this.constraints, this.map);
    }
}
