package cc.mallet.fst.semi_supervised.pr.constraints;

import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import gnu.trove.TIntArrayList;
import gnu.trove.TIntIntHashMap;
import gnu.trove.TIntObjectHashMap;
import java.util.BitSet;
import java.util.Iterator;

/* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/fst/semi_supervised/pr/constraints/OneLabelL2PRConstraints.class */
public class OneLabelL2PRConstraints implements PRConstraint {
    /** One constraint per constrained input feature, keyed by feature index. */
    protected TIntObjectHashMap<OneLabelPRConstraint> constraints;
    /**
     * Maps a constrained feature index to its dense constraint index (0..numConstraints-1).
     * Flat parameter/expectation arrays are laid out as {@code constraintIndex + labelIndex * numConstraints}.
     */
    protected TIntIntHashMap constraintIndices;
    /** Maps model state indices to label indices; must be set before scoring or sizing. */
    protected StateLabelMap map;
    /** When true, expectations are divided by the feature's occurrence count before use. */
    protected boolean normalized;
    /** Constrained feature indices present in the vector given to the last {@link #preProcess(FeatureVector)}. */
    protected TIntArrayList cache;

    /**
     * Target label distribution, weight and accumulated statistics for a single constrained feature.
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public class OneLabelPRConstraint {
        protected double[] target;
        protected double[] expectation = null;
        protected double count = 0.0d;
        protected double weight;

        /**
         * @param target target value for each label index
         * @param weight weight of this constraint (used as the L2 strength)
         */
        public OneLabelPRConstraint(double[] target, double weight) {
            this.target = target;
            this.weight = weight;
        }

        /**
         * Returns a copy that shares the {@code target} array and keeps {@code weight} and
         * {@code count}, but starts with a fresh, zeroed expectation array of length
         * {@code target.length} (accumulated expectations are intentionally not copied).
         */
        public OneLabelPRConstraint copy() {
            OneLabelPRConstraint duplicate = new OneLabelPRConstraint(target, weight);
            duplicate.count = count;
            duplicate.expectation = new double[target.length];
            return duplicate;
        }
    }

    /**
     * Creates an empty constraint set.
     *
     * @param normalized whether expectations are divided by each feature's occurrence count
     */
    public OneLabelL2PRConstraints(boolean normalized) {
        this.constraints = new TIntObjectHashMap<>();
        this.constraintIndices = new TIntIntHashMap();
        this.cache = new TIntArrayList();
        this.normalized = normalized;
    }

    /**
     * Copy constructor used by {@link #copy()}. Each per-feature constraint is copied (see
     * {@link OneLabelPRConstraint#copy()}), while the {@code constraintIndices} map and the
     * {@code map} are shared with the source instance, not duplicated.
     */
    protected OneLabelL2PRConstraints(TIntObjectHashMap<OneLabelPRConstraint> constraints,
            TIntIntHashMap constraintIndices, StateLabelMap map, boolean normalized) {
        this.constraints = new TIntObjectHashMap<>();
        for (int featureIndex : constraints.keys()) {
            this.constraints.put(featureIndex, constraints.get(featureIndex).copy());
        }
        this.constraintIndices = constraintIndices;
        this.map = map;
        this.cache = new TIntArrayList();
        this.normalized = normalized;
    }

    @Override // cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint
    public PRConstraint copy() {
        return new OneLabelL2PRConstraints(this.constraints, this.constraintIndices, this.map, this.normalized);
    }

    /**
     * Adds (or replaces) the constraint for one input feature.
     *
     * @param featureIndex index of the constrained input feature
     * @param target target value for each label index
     * @param weight constraint weight
     */
    public void addConstraint(int featureIndex, double[] target, double weight) {
        constraints.put(featureIndex, new OneLabelPRConstraint(target, weight));
        // Only brand-new features get a new dense index. Re-assigning size() to an existing
        // key would collide with another constraint's slot and can exceed numDimensions().
        if (!constraintIndices.containsKey(featureIndex)) {
            constraintIndices.put(featureIndex, constraintIndices.size());
        }
    }

    /** Returns {@code numLabels * numConstraints}; requires the state-label map to be set. */
    @Override // cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint
    public int numDimensions() {
        assert map != null;
        return map.getNumLabels() * constraints.size();
    }

    public boolean isOneStateConstraint() {
        return true;
    }

    @Override // cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint
    public void setStateLabelMap(StateLabelMap stateLabelMap) {
        this.map = stateLabelMap;
    }

    /** Records which constrained features occur in {@code input}, for use by the scoring methods. */
    @Override // cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint
    public void preProcess(FeatureVector input) {
        cache.resetQuick();
        for (int loc = 0; loc < input.numLocations(); loc++) {
            int featureIndex = input.indexAtLocation(loc);
            if (constraints.containsKey(featureIndex)) {
                cache.add(featureIndex);
            }
        }
    }

    /**
     * Increments each constraint's {@code count} once per occurrence of its feature in the data
     * (not reset first), and returns the positions of instances containing any constrained feature.
     */
    @Override // cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint
    public BitSet preProcess(InstanceList instanceList) {
        BitSet constrained = new BitSet(instanceList.size());
        int instanceIndex = 0;
        for (Instance instance : instanceList) {
            FeatureVectorSequence sequence = (FeatureVectorSequence) instance.getData();
            for (int pos = 0; pos < sequence.size(); pos++) {
                FeatureVector fv = sequence.get(pos);
                for (int loc = 0; loc < fv.numLocations(); loc++) {
                    int featureIndex = fv.indexAtLocation(loc);
                    if (constraints.containsKey(featureIndex)) {
                        constraints.get(featureIndex).count += 1.0d;
                        constrained.set(instanceIndex);
                    }
                }
            }
            instanceIndex++;
        }
        return constrained;
    }

    /**
     * Sums the parameters of all cached constrained features for the label of {@code destIndex}
     * (divided by each feature's count when normalized). Only {@code destIndex} and
     * {@code parameters} are read.
     */
    @Override // cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint
    public double getScore(FeatureVector input, int inputPosition, int srcIndex, int destIndex, double[] parameters) {
        int labelIndex = map.getLabelIndex(destIndex);
        int numConstraints = constraints.size();
        double score = 0.0d;
        for (int i = 0; i < cache.size(); i++) {
            int featureIndex = cache.getQuick(i);
            double param = parameters[constraintIndices.get(featureIndex) + numConstraints * labelIndex];
            score += normalized ? param / constraints.get(featureIndex).count : param;
        }
        return score;
    }

    /** Adds {@code prob} to the expectation of the label of {@code destIndex} for every cached feature. */
    @Override // cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint
    public void incrementExpectations(FeatureVector input, int inputPosition, int srcIndex, int destIndex, double prob) {
        int labelIndex = map.getLabelIndex(destIndex);
        for (int i = 0; i < cache.size(); i++) {
            constraints.get(cache.getQuick(i)).expectation[labelIndex] += prob;
        }
    }

    /** Writes all expectations into {@code expectations} using the flat layout (overwrites). */
    @Override // cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint
    public void getExpectations(double[] expectations) {
        assert expectations.length == numDimensions();
        int numConstraints = constraints.size();
        for (int featureIndex : constraintIndices.keys()) {
            int constraintIndex = constraintIndices.get(featureIndex);
            double[] expectation = constraints.get(featureIndex).expectation;
            for (int li = 0; li < expectation.length; li++) {
                expectations[constraintIndex + li * numConstraints] = expectation[li];
            }
        }
    }

    /** Adds the flat-layout values in {@code expectations} into this object's expectations. */
    @Override // cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint
    public void addExpectations(double[] expectations) {
        assert expectations.length == numDimensions();
        int numConstraints = constraints.size();
        for (int featureIndex : constraintIndices.keys()) {
            int constraintIndex = constraintIndices.get(featureIndex);
            double[] expectation = constraints.get(featureIndex).expectation;
            for (int li = 0; li < expectation.length; li++) {
                expectation[li] += expectations[constraintIndex + li * numConstraints];
            }
        }
    }

    /** Replaces every constraint's expectation with a zeroed array of length {@code numLabels}. */
    @Override // cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint
    public void zeroExpectations() {
        for (int featureIndex : constraints.keys()) {
            constraints.get(featureIndex).expectation = new double[map.getNumLabels()];
        }
    }

    /** Returns sum over constraints and labels of {@code target * p - p^2 / (2 * weight)}. */
    @Override // cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint
    public double getAuxiliaryValueContribution(double[] parameters) {
        int numConstraints = constraints.size();
        double value = 0.0d;
        for (int featureIndex : constraints.keys()) {
            int constraintIndex = constraintIndices.get(featureIndex);
            OneLabelPRConstraint constraint = constraints.get(featureIndex);
            for (int li = 0; li < map.getNumLabels(); li++) {
                double param = parameters[constraintIndex + li * numConstraints];
                value += (constraint.target[li] * param) - ((param * param) / (2.0d * constraint.weight));
            }
        }
        return value;
    }

    /** Returns sum over constraints and labels of {@code weight * (target - expectation)^2 / 2}. */
    @Override // cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint
    public double getCompleteValueContribution(double[] parameters) {
        double value = 0.0d;
        for (int featureIndex : constraints.keys()) {
            OneLabelPRConstraint constraint = constraints.get(featureIndex);
            for (int li = 0; li < map.getNumLabels(); li++) {
                double diff = constraint.target[li] - modelExpectation(constraint, li);
                value += (constraint.weight * Math.pow(diff, 2.0d)) / 2.0d;
            }
        }
        return value;
    }

    /** Writes {@code target - expectation - parameter / weight} into {@code gradient} (flat layout). */
    @Override // cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint
    public void getGradient(double[] parameters, double[] gradient) {
        int numConstraints = constraints.size();
        for (int featureIndex : constraints.keys()) {
            int constraintIndex = constraintIndices.get(featureIndex);
            OneLabelPRConstraint constraint = constraints.get(featureIndex);
            for (int li = 0; li < map.getNumLabels(); li++) {
                int idx = constraintIndex + li * numConstraints;
                gradient[idx] = (constraint.target[li] - modelExpectation(constraint, li))
                        - (parameters[idx] / constraint.weight);
            }
        }
    }

    /**
     * Returns the constraint's expectation for a label, divided by its count when normalized.
     * NOTE(review): with normalization and {@code count == 0} this yields NaN/Infinity.
     */
    private double modelExpectation(OneLabelPRConstraint constraint, int labelIndex) {
        return normalized
                ? constraint.expectation[labelIndex] / constraint.count
                : constraint.expectation[labelIndex];
    }
}
