package cc.mallet.extract;

import cc.mallet.fst.CRF;
import cc.mallet.fst.MaxLattice;
import cc.mallet.fst.MaxLatticeDefault;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.Transducer;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.Sequence;
import cc.mallet.types.Token;
import cc.mallet.types.TokenSequence;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Writer;
import java.text.DecimalFormat;
import java.util.List;

/* loaded from: input_file:cc/mallet/extract/LatticeViewer.class */
/**
 * Renders CRF extraction results as HTML pages for inspecting labeling errors.
 *
 * <p>Each document's label sequence is shown in windows of {@link #LENGTH} positions, where
 * consecutive windows share one position. A window is only written when it contains a
 * disagreement: prediction vs. target in the single-extractor view, or between the two
 * extractors' predictions in the dual view. When lattice output is requested, each written
 * window also shows the top-ranked state sequences from a {@link MaxLattice}, transition
 * descriptions, and the features at positions where prediction and target differ.
 *
 * <p>All text taken from documents, labels, states and features is HTML-escaped before it is
 * written.
 */
public class LatticeViewer {
    /** Divisor applied to |transition weight| to form the argument passed to describeTransition. */
    private static final int FEATURE_CUTOFF_PCT = 25;
    /** Number of positions per rendered window; consecutive windows overlap by one position. */
    private static final int LENGTH = 10;
    /** Maximum number of best state sequences rendered per window (also capped by the state count). */
    public static int numMaxViterbi = 5;
    /** Set/read through the accessors below; not consulted by the rendering code in this class. */
    private static int numFeaturesToDisplay = 5;
    /** Number of documents written to each {@code lattice-N.html} page by {@link #viewDualResults}. */
    private static final int EXTRACTIONS_PER_FILE = 25;

    /**
     * Everything needed to render one document.
     *
     * <p>{@code fvs}, {@code lattice} and {@code bestStates} are only filled in when lattice
     * output is requested (see {@code infoForDoc}). When {@code link} is non-null, each rendered
     * window links to the same anchor on that page.
     */
    public static class ExtorInfo {
        TokenSequence input;
        Sequence predicted;
        LabelSequence target;
        FeatureVectorSequence fvs;
        MaxLattice lattice;
        Sequence bestStates;
        String link;
        String desc;
        String idx;

        /**
         * @param input     the document's tokens
         * @param predicted the predicted label sequence
         * @param target    the gold label sequence
         * @param desc      human-readable description shown in the window header
         * @param idx       identifier used to build the window anchors
         */
        public ExtorInfo(TokenSequence input, Sequence predicted, LabelSequence target, String desc, String idx) {
            this.input = input;
            this.predicted = predicted;
            this.target = target;
            this.desc = desc;
            this.idx = idx;
        }
    }

    /**
     * Writes the error windows of one document to a {@link PrintStream}.
     * The stream itself is flushed but not closed.
     */
    static void lattice2html(PrintStream out, ExtorInfo info) {
        PrintWriter writer = new PrintWriter(new OutputStreamWriter(out), true);
        lattice2html(writer, info);
        // Autoflush only fires on println, so make sure trailing print() output is delivered.
        writer.flush();
    }

    /**
     * Writes every window of the document in which the predicted and target labels disagree.
     *
     * @param out  destination writer
     * @param info document to render; its input, predicted and target sequences must have equal length
     */
    static void lattice2html(PrintWriter out, ExtorInfo info) {
        assert info.target.size() == info.predicted.size();
        assert info.input.size() == info.predicted.size();
        int n = info.target.size();
        for (int start = 0; start < n; start += LENGTH - 1) {
            int end = Math.min(n, start + LENGTH);
            if (!allSeqMatches(info.predicted, info.target, start, end)) {
                error2html(out, info, start, end);
            }
            if (end == n) {
                // This window already reached the end; another one would only repeat its last cell.
                break;
            }
        }
    }

    private static void writeHeader(PrintWriter out) {
        out.println("<html><head><title>ERROR OUTPUT</title>\n<link rel=\"stylesheet\" href=\"errors.css\" type=\"text/css\" />\n</head><body>");
    }

    private static void writeFooter(PrintWriter out) {
        out.println("</body></html>");
    }

    /** Writes one anchored window [start, end) of a document as an HTML table. */
    private static void error2html(PrintWriter out, ExtorInfo info, int start, int end) {
        String anchor = escapeHtml(info.idx + ":" + start + ":" + end);
        out.println("<a name=\"" + anchor + "\"></a>");
        out.println("<p>Instance " + escapeHtml(info.desc) + " Position " + start + "..." + end);
        if (info.link != null) {
            out.println("<a href=\"" + escapeHtml(info.link) + "#" + anchor + "\">[Lattice]</a>");
        }
        out.println("</p>");
        out.println("<table>");
        outputIndices(out, start, end);
        outputInputRow(out, info.input, start, end);
        outputTableRow(out, InstanceList.TARGET_PROPERTY, info.target, info.predicted, start, end);
        outputTableRow(out, "predicted", info.predicted, info.target, start, end);
        if (info.lattice != null) {
            outputLatticeRows(out, info.lattice, start, end);
            outputTransitionCosts(out, info, start, end);
            outputFeatures(out, info.fvs, info.predicted, info.target, start, end);
        }
        out.println("</table>");
    }

    /**
     * Writes one row per top-ranked state sequence. Each cell shows the state name and its
     * negated delta. A cell is marked "viterbi" when its state name equals the lattice's best
     * output at that position.
     */
    private static void outputLatticeRows(PrintWriter out, MaxLattice lattice, int start, int end) {
        DecimalFormat fmt = new DecimalFormat("0.##");
        int requested = Math.min(numMaxViterbi, lattice.getTransducer().numStates());
        List<Sequence<Transducer.State>> paths = lattice.bestStateSequences(requested);
        // Loop-invariant: the best output sequence does not depend on the row or column.
        Sequence<?> viterbi = lattice.bestOutputSequence();
        // Guard in case fewer sequences are returned than were requested.
        int rows = Math.min(requested, paths.size());
        for (int rank = 0; rank < rows; rank++) {
            Sequence<Transducer.State> path = paths.get(rank);
            out.println("  <tr class=\"delta\">");
            out.println("    <td class=\"label\">&delta; rank " + rank + "</td>");
            for (int ip = start; ip < end; ip++) {
                // State sequences and deltas are indexed one past the input position
                // (presumably slot 0 holds the start state).
                Transducer.State state = path.get(ip + 1);
                out.print(state.getName().equals(viterbi.get(ip)) ? "<td class=\"viterbi\">" : "<td>");
                out.print(escapeHtml(state.getName()) + "<br />"
                        + fmt.format(-lattice.getDelta(ip + 1, state.getIndex())) + "</td>");
            }
            out.println("</tr>");
        }
    }

    public static int getNumFeaturesToDisplay() {
        return numFeaturesToDisplay;
    }

    public static void setNumFeaturesToDisplay(int n) {
        numFeaturesToDisplay = n;
    }

    /**
     * Writes three rows of transition descriptions:
     * <ol>
     *   <li>from the best state to the predicted label,</li>
     *   <li>from the target state to the target label,</li>
     *   <li>from the best state to the target label.</li>
     * </ol>
     * The last two rows leave a cell empty when prediction and target agree at both that
     * position and the previous one.
     */
    private static void outputTransitionCosts(PrintWriter out, ExtorInfo info, int start, int end) {
        CRF crf = (CRF) info.lattice.getTransducer();
        outputTransitionRow(out, crf, info, "predtrans", "Cost(pred. trans)",
                info.bestStates, info.predicted, false, start, end);
        outputTransitionRow(out, crf, info, "targettrans", "Cost(target trans)",
                info.target, info.target, true, start, end);
        outputTransitionRow(out, crf, info, "predtargettrans", "Cost (pred->target trans)",
                info.bestStates, info.target, true, start, end);
    }

    /**
     * Writes one transition row. Position 0 has no incoming transition and is always left empty.
     *
     * @param fromStates    sequence whose element at ip-1 names the source state
     * @param output        output sequence passed to the transition iterator
     * @param skipAgreeing  leave the cell empty when predicted and target agree at ip and ip-1
     */
    private static void outputTransitionRow(PrintWriter out, CRF crf, ExtorInfo info, String rowClass, String label,
            Sequence<?> fromStates, Sequence<?> output, boolean skipAgreeing, int start, int end) {
        out.println("<tr class=\"" + rowClass + "\">");
        out.println("<td class=\"label\">" + label + "</td>");
        for (int ip = start; ip < end; ip++) {
            boolean agrees = skipAgreeing
                    && seqMatches(info.predicted, info.target, ip)
                    && seqMatches(info.predicted, info.target, ip - 1);
            if (ip == 0 || agrees) {
                out.println("<td></td>");
            } else {
                outputTransitionCell(out, crf, fromStates.get(ip - 1), info.fvs, output, ip);
            }
        }
        out.println("</tr>");
    }

    /**
     * Writes a single cell describing the first matching transition out of the named state.
     * If the state is unknown or has no matching transition, a message is written instead.
     */
    private static void outputTransitionCell(PrintWriter out, CRF crf, Object fromName,
            FeatureVectorSequence fvs, Sequence<?> output, int position) {
        CRF.State from = crf.getState(fromName.toString());
        if (from == null) {
            out.print("<td>Could not find state for " + escapeHtml(fromName) + "</td>");
            return;
        }
        Transducer.TransitionIterator iter = from.transitionIterator(fvs, position, output, position);
        if (!iter.hasNext()) {
            out.print("<td>No matching transition</td>");
            return;
        }
        iter.next();
        int cutoff = (int) (Math.abs(iter.getWeight()) / FEATURE_CUTOFF_PCT);
        // describeTransition output is inserted as-is (it may itself contain markup).
        out.print("<td>" + iter.describeTransition(cutoff) + "</td>");
    }

    /**
     * Writes alpha, beta and gamma rows for every state of a sum lattice.
     * Not called from within this class at present.
     */
    private static void outputLatticeRows(PrintWriter out, SumLatticeDefault lattice, int start, int end) {
        DecimalFormat fmt = new DecimalFormat("0.##");
        Transducer transducer = lattice.getTransducer();
        for (int s = 0; s < transducer.numStates(); s++) {
            Transducer.State state = transducer.getState(s);
            out.println("  <tr class=\"alpha\">");
            out.println("    <td class=\"label\">&alpha;(" + escapeHtml(state.getName()) + ")</td>");
            for (int ip = start; ip < end; ip++) {
                out.print("<td>" + fmt.format(lattice.getAlpha(ip + 1, state)) + "</td>");
            }
            out.println("</tr>");
        }
        for (int s = 0; s < transducer.numStates(); s++) {
            Transducer.State state = transducer.getState(s);
            out.println("  <tr class=\"beta\">");
            out.println("    <td class=\"label\">&beta;(" + escapeHtml(state.getName()) + ")</td>");
            for (int ip = start; ip < end; ip++) {
                out.print("<td>" + fmt.format(lattice.getBeta(ip + 1, state)) + "</td>");
            }
            out.println("</tr>");
        }
        for (int s = 0; s < transducer.numStates(); s++) {
            Transducer.State state = transducer.getState(s);
            out.println("  <tr class=\"gamma\">");
            out.println("    <td class=\"label\">&gamma;(" + escapeHtml(state.getName()) + ")</td>");
            for (int ip = start; ip < end; ip++) {
                out.print("<td>" + fmt.format(lattice.getGammaWeight(ip + 1, state)) + "</td>");
            }
            out.println("</tr>");
        }
    }

    /** Writes the token text for positions [start, end). */
    private static void outputInputRow(PrintWriter out, TokenSequence tokens, int start, int end) {
        out.println("  <tr class=\"input\">");
        out.println("    <td class=\"label\"></td>");
        for (int ip = start; ip < end; ip++) {
            out.print("<td>" + escapeHtml(((Token) tokens.get(ip)).getText()) + "</td>");
        }
        out.println("  </tr>");
    }

    /** Writes the position numbers for [start, end). */
    private static void outputIndices(PrintWriter out, int start, int end) {
        out.println("  <tr class=\"indices\">");
        out.println("    <td class=\"label\"></td>");
        for (int ip = start; ip < end; ip++) {
            out.print("<td>" + ip + "</td>");
        }
        out.println("  </tr>");
    }

    /**
     * Writes {@code shown} over [start, end). Cells whose value differs from {@code other}
     * are marked with the "error" class.
     */
    private static void outputTableRow(PrintWriter out, String name, Sequence shown, Sequence other, int start, int end) {
        String safeName = escapeHtml(name);
        out.println("  <tr class=\"" + safeName + "\">");
        out.println("    <td class=\"label\">" + safeName + "</td>");
        for (int ip = start; ip < end; ip++) {
            out.print(seqMatches(shown, other, ip) ? "<td>" : "<td class=\"error\">");
            out.print(escapeHtml(shown.get(ip)));
            out.print("</td>");
        }
        out.println("  </tr>");
    }

    /**
     * Writes every feature, with its value when it is not 1.0, at positions where prediction
     * and target disagree. Positions that agree get an empty cell.
     */
    private static void outputFeatures(PrintWriter out, FeatureVectorSequence fvs, Sequence predicted, Sequence target, int start, int end) {
        out.println("  <tr class=\"features\">\n<td class=\"label\">Features</td>");
        for (int ip = start; ip < end; ip++) {
            if (seqMatches(predicted, target, ip)) {
                out.println("<td></td>");
                continue;
            }
            out.print("<td>");
            FeatureVector fv = fvs.getFeatureVector(ip);
            for (int loc = 0; loc < fv.numLocations(); loc++) {
                out.print(escapeHtml(fv.getAlphabet().lookupObject(fv.indexAtLocation(loc))));
                double value = fv.valueAtLocation(loc);
                if (value != 1.0d) {
                    out.print(" " + value);
                }
                out.println("<br />");
            }
            out.println("</td>");
        }
        out.println("  </tr>");
    }

    /** True when both sequences have the same string form at position {@code i}. */
    private static boolean seqMatches(Sequence a, Sequence b, int i) {
        return a.get(i).toString().equals(b.get(i).toString());
    }

    /** True when {@link #seqMatches} holds at every position in [start, end). */
    private static boolean allSeqMatches(Sequence a, Sequence b, int start, int end) {
        for (int i = start; i < end; i++) {
            if (!seqMatches(a, b, i)) {
                return false;
            }
        }
        return true;
    }

    /** Same as the {@code boolean} overload with lattice output disabled. */
    public static void extraction2html(Extraction extraction, CRFExtractor extractor, PrintStream out) {
        extraction2html(extraction, extractor, out, false);
    }

    /** Same as the {@code boolean} overload with lattice output disabled. */
    public static void extraction2html(Extraction extraction, CRFExtractor extractor, PrintWriter out) {
        extraction2html(extraction, extractor, out, false);
    }

    /** Writes to a {@link PrintStream}, which is flushed but not closed. */
    public static void extraction2html(Extraction extraction, CRFExtractor extractor, PrintStream out, boolean showLattice) {
        PrintWriter writer = new PrintWriter(new OutputStreamWriter(out), true);
        extraction2html(extraction, extractor, writer, showLattice);
        writer.flush();
    }

    /**
     * Writes a complete HTML page containing the error windows of every document.
     *
     * @param showLattice when true, lattice rows are included. When false, each window links to
     *                    the matching anchor in {@code lattice.html}.
     */
    public static void extraction2html(Extraction extraction, CRFExtractor extractor, PrintWriter out, boolean showLattice) {
        writeHeader(out);
        for (int i = 0; i < extraction.getNumDocuments(); i++) {
            DocumentExtraction docExtr = extraction.getDocumentExtraction(i);
            ExtorInfo info = infoForDoc(docExtr.getName(), "N" + i, docExtr, extractor, showLattice);
            if (!showLattice) {
                info.link = "lattice.html";
            }
            lattice2html(out, info);
        }
        writeFooter(out);
    }

    /**
     * Builds the {@link ExtorInfo} for one document.
     *
     * <p>When {@code includeLattice} is set, the tokens are also run through the extractor's
     * feature pipe, and a {@link MaxLatticeDefault} is built over the resulting features.
     */
    private static ExtorInfo infoForDoc(String desc, String idx, DocumentExtraction docExtr, CRFExtractor extractor, boolean includeLattice) {
        TokenSequence input = (TokenSequence) docExtr.getInput();
        ExtorInfo info = new ExtorInfo(input, docExtr.getPredictedLabels(), docExtr.getTarget(), desc, idx);
        if (includeLattice) {
            CRF crf = extractor.getCrf();
            Instance carrier = extractor.getFeaturePipe().pipe(new Instance(input, null, null, null));
            info.fvs = (FeatureVectorSequence) carrier.getData();
            info.lattice = new MaxLatticeDefault(crf, (Sequence) carrier.getData(), null);
            info.bestStates = info.lattice.bestOutputSequence();
        }
        return info;
    }

    /**
     * Compares two extractions of the same documents and writes HTML pages into {@code dir}:
     * <ul>
     *   <li>{@code errors.html}, covering all documents;</li>
     *   <li>{@code lattice-N.html} pages with lattice details, {@link #EXTRACTIONS_PER_FILE}
     *       documents per page.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if the extractions contain different numbers of documents
     * @throws IOException              if a page cannot be created or written
     */
    public static void viewDualResults(File dir, Extraction extraction1, CRFExtractor extractor1, Extraction extraction2, CRFExtractor extractor2) throws IOException {
        int numDocs = extraction1.getNumDocuments();
        if (numDocs != extraction2.getNumDocuments()) {
            throw new IllegalArgumentException("Extractions don't match: different number of docs.");
        }
        writeDualFile(new File(dir, "errors.html"), extraction1, extractor1, extraction2, extractor2, 0, numDocs, false);
        for (int start = 0; start < numDocs; start += EXTRACTIONS_PER_FILE) {
            int end = Math.min(start + EXTRACTIONS_PER_FILE, numDocs);
            writeDualFile(new File(dir, computeLatticeFname(start)), extraction1, extractor1, extraction2, extractor2, start, end, true);
        }
    }

    /** Writes one dual-extraction page. The file is always closed, and write failures are reported. */
    private static void writeDualFile(File file, Extraction extraction1, CRFExtractor extractor1, Extraction extraction2, CRFExtractor extractor2,
            int start, int end, boolean showLattice) throws IOException {
        PrintWriter out = new PrintWriter(new FileWriter(file));
        try {
            writeDualExtractions(out, extraction1, extractor1, extraction2, extractor2, start, end, showLattice);
        } finally {
            out.close();
        }
        // PrintWriter records I/O failures internally instead of throwing.
        if (out.checkError()) {
            throw new IOException("Error writing " + file);
        }
    }

    /** Name of the lattice page that holds document {@code docIndex}. */
    private static String computeLatticeFname(int docIndex) {
        return "lattice-" + (EXTRACTIONS_PER_FILE * (docIndex / EXTRACTIONS_PER_FILE)) + ".html";
    }

    /**
     * Writes documents [start, end) for which the two extractors' predictions differ.
     * Documents whose text differs between the two extractions are skipped, and a message is
     * printed to stderr.
     */
    private static void writeDualExtractions(PrintWriter out, Extraction extraction1, CRFExtractor extractor1, Extraction extraction2, CRFExtractor extractor2,
            int start, int end, boolean showLattice) {
        writeHeader(out);
        for (int d = start; d < end; d++) {
            DocumentExtraction doc1 = extraction1.getDocumentExtraction(d);
            DocumentExtraction doc2 = extraction2.getDocumentExtraction(d);
            String name = doc1.getName();
            String text = ((CharSequence) doc1.getDocument()).toString();
            if (!text.equals(((CharSequence) doc2.getDocument()).toString())) {
                System.err.println("Skipping document " + d + ": Extractions don't match");
                continue;
            }
            if (predictionsMatch(doc1.getPredictedLabels(), doc2.getPredictedLabels())) {
                continue;
            }
            ExtorInfo info1 = infoForDoc("CRF1::" + name, "C1I" + d, doc1, extractor1, showLattice);
            ExtorInfo info2 = infoForDoc("CRF2::" + name, "C2I" + d, doc2, extractor2, showLattice);
            if (!showLattice) {
                String latticePage = computeLatticeFname(d);
                info1.link = latticePage;
                info2.link = latticePage;
            }
            dualLattice2html(out, name, info1, info2);
        }
        writeFooter(out);
    }

    /**
     * Writes, for both extractors, every window in which their predictions disagree.
     * Only positions present in both predictions are compared.
     *
     * @param docName not used by this method
     */
    public static void dualLattice2html(PrintWriter out, String docName, ExtorInfo info1, ExtorInfo info2) {
        assert info1.predicted.size() == info1.target.size();
        assert info1.input.size() == info1.predicted.size();
        assert info2.input.size() == info2.predicted.size();
        assert info2.predicted.size() == info2.target.size();
        // Bound by the shorter prediction so that comparisons never index past either sequence.
        int n = Math.min(info1.predicted.size(), info2.predicted.size());
        for (int start = 0; start < n; start += LENGTH - 1) {
            int end = Math.min(n, start + LENGTH);
            if (!allSeqMatches(info1.predicted, info2.predicted, start, end)) {
                error2html(out, info1, start, end);
                error2html(out, info2, start, end);
            }
            if (end == n) {
                break;
            }
        }
    }

    /** True when both sequences have the same length and the same string form at every position. */
    private static boolean predictionsMatch(Sequence a, Sequence b) {
        return a.size() == b.size() && allSeqMatches(a, b, 0, a.size());
    }

    /** Escapes HTML special characters in {@code String.valueOf(value)}. */
    private static String escapeHtml(Object value) {
        String s = String.valueOf(value);
        StringBuilder sb = new StringBuilder(s.length());
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            switch (c) {
                case '<':
                    sb.append("&lt;");
                    break;
                case '>':
                    sb.append("&gt;");
                    break;
                case '&':
                    sb.append("&amp;");
                    break;
                case '"':
                    sb.append("&quot;");
                    break;
                case '\'':
                    sb.append("&#39;");
                    break;
                default:
                    sb.append(c);
            }
        }
        return sb.toString();
    }
}
