package cc.mallet.fst.confidence;

import cc.mallet.fst.MaxLatticeDefault;
import cc.mallet.fst.Segment;
import cc.mallet.fst.Transducer;
import cc.mallet.types.ArraySequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import java.util.ArrayList;
import java.util.Objects;
import java.util.Vector;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/fst/confidence/ConstrainedViterbiTransducerCorrector.class */
public class ConstrainedViterbiTransducerCorrector implements TransducerCorrector {
    private static Logger logger = MalletLogger.getLogger(ConstrainedViterbiTransducerCorrector.class.getName());
    TransducerConfidenceEstimator confidenceEstimator;
    Transducer model;
    ArrayList leastConfidentSegments;

    public ConstrainedViterbiTransducerCorrector(TransducerConfidenceEstimator transducerConfidenceEstimator, Transducer transducer) {
        this.confidenceEstimator = transducerConfidenceEstimator;
        this.model = transducer;
    }

    public ConstrainedViterbiTransducerCorrector(Transducer transducer) {
        this(new ConstrainedForwardBackwardConfidenceEstimator(transducer), transducer);
    }

    public Vector getSegmentConfidences() {
        return this.confidenceEstimator.getSegmentConfidences();
    }

    public ArrayList getLeastConfidentSegments() {
        return this.leastConfidentSegments;
    }

    public ArrayList getLeastConfidentSegments(InstanceList instanceList, Object[] objArr, Object[] objArr2) {
        ArrayList arrayList = new ArrayList();
        for (int i = 0; i < instanceList.size(); i++) {
            arrayList.add(this.confidenceEstimator.rankSegmentsByConfidence(instanceList.get(i), objArr, objArr2)[0]);
        }
        return arrayList;
    }

    @Override // cc.mallet.fst.confidence.TransducerCorrector
    public ArrayList correctLeastConfidentSegments(InstanceList instanceList, Object[] objArr, Object[] objArr2) {
        return correctLeastConfidentSegments(instanceList, objArr, objArr2, false);
    }

    public ArrayList correctLeastConfidentSegments(InstanceList instanceList, Object[] objArr, Object[] objArr2, boolean z) {
        ArrayList arrayList = new ArrayList();
        this.leastConfidentSegments = new ArrayList();
        logger.info(getClass().getName() + " ranking confidence using " + this.confidenceEstimator.getClass().getName());
        for (int i = 0; i < instanceList.size(); i++) {
            logger.fine("correcting instance# " + i + " / " + instanceList.size());
            Instance instance = instanceList.get(i);
            Segment[] segmentArr = new Segment[1];
            Sequence sequence = (Sequence) instance.getData();
            Sequence sequence2 = (Sequence) instance.getTarget();
            Sequence<Object> bestOutputSequence = new MaxLatticeDefault(this.model, sequence).bestOutputSequence();
            int i2 = 0;
            for (int i3 = 0; i3 < bestOutputSequence.size(); i3++) {
                i2 += !bestOutputSequence.get(i3).equals(sequence2.get(i3)) ? 1 : 0;
            }
            if (i2 == 0) {
                this.leastConfidentSegments.add(null);
                arrayList.add(bestOutputSequence);
            } else {
                Segment[] rankSegmentsByConfidence = this.confidenceEstimator.rankSegmentsByConfidence(instance, objArr, objArr2);
                logger.fine("Ordered Segments:\n");
                for (Segment segment : rankSegmentsByConfidence) {
                    logger.fine(segment.toString());
                }
                logger.fine("Correcting Segment: True Sequence:");
                for (int i4 = 0; i4 < sequence2.size(); i4++) {
                    logger.fine(((String) sequence2.get(i4)) + "\t");
                }
                logger.fine("");
                logger.fine("Ordered Segments:\n");
                for (Segment segment2 : rankSegmentsByConfidence) {
                    logger.fine(segment2.toString());
                }
                Segment segment3 = rankSegmentsByConfidence[0];
                if (z) {
                    int i5 = 0;
                    while (true) {
                        if (i5 >= rankSegmentsByConfidence.length) {
                            break;
                        }
                        if (!rankSegmentsByConfidence[i5].correct()) {
                            segment3 = rankSegmentsByConfidence[i5];
                            break;
                        }
                        i5++;
                    }
                }
                if (z && segment3.correct()) {
                    logger.warning("cannot find incorrect segment, probably because error is in background state\n");
                    this.leastConfidentSegments.add(null);
                    arrayList.add(bestOutputSequence);
                } else {
                    this.leastConfidentSegments.add(segment3);
                    if (segment3 == null) {
                        arrayList.add(bestOutputSequence);
                    } else {
                        String[] strArr = new String[sequence2.size()];
                        int i6 = 0;
                        for (int i7 = 0; i7 < strArr.length; i7++) {
                            strArr[i7] = null;
                        }
                        for (int i8 = 0; i8 < sequence2.size(); i8++) {
                            if (segment3.indexInSegment(i8)) {
                                strArr[i8] = (String) sequence2.get(i8);
                                i6++;
                            }
                        }
                        if (segment3.endsPrematurely()) {
                            strArr[segment3.getEnd() + 1] = (String) sequence2.get(segment3.getEnd() + 1);
                            i6++;
                        }
                        logger.fine("Constrained Segment Sequence\n");
                        for (String str : strArr) {
                            logger.fine(str);
                        }
                        Sequence<Object> bestOutputSequence2 = new MaxLatticeDefault(this.model, rankSegmentsByConfidence[0].getInput(), new ArraySequence(strArr)).bestOutputSequence();
                        int i9 = 0;
                        for (int i10 = 0; i10 < sequence2.size(); i10++) {
                            i9 += !bestOutputSequence2.get(i10).equals(sequence2.get(i10)) ? 1 : 0;
                        }
                        logger.fine("Num incorrect tokens in original prediction: " + i2);
                        logger.fine("Num corrected tokens: " + i6);
                        logger.fine("Num incorrect tokens after correction-propagation: " + i9);
                        logger.fine("Correcting Segment: True Sequence:");
                        for (int i11 = 0; i11 < sequence2.size(); i11++) {
                            logger.fine(((String) sequence2.get(i11)) + "\t");
                        }
                        logger.fine("\nOriginal prediction: ");
                        for (int i12 = 0; i12 < bestOutputSequence.size(); i12++) {
                            logger.fine(((String) bestOutputSequence.get(i12)) + "\t");
                        }
                        logger.fine("\nCorrected prediction: ");
                        for (int i13 = 0; i13 < bestOutputSequence2.size(); i13++) {
                            logger.fine(((String) bestOutputSequence2.get(i13)) + "\t");
                        }
                        logger.fine("");
                        arrayList.add(bestOutputSequence2);
                    }
                }
            }
        }
        return arrayList;
    }
}
