/*
 * Decompiled with CFR 0.152.
 */
package weka.knowledgeflow.steps;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import weka.classifiers.AggregateableEvaluation;
import weka.classifiers.Classifier;
import weka.classifiers.CostMatrix;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.ThresholdCurve;
import weka.classifiers.misc.InputMappedClassifier;
import weka.core.BatchPredictor;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionHandler;
import weka.core.OptionMetadata;
import weka.core.Utils;
import weka.core.WekaException;
import weka.gui.ProgrammaticProperty;
import weka.gui.explorer.ClassifierErrorsPlotInstances;
import weka.gui.explorer.ExplorerDefaults;
import weka.gui.visualize.PlotData2D;
import weka.knowledgeflow.Data;
import weka.knowledgeflow.ExecutionResult;
import weka.knowledgeflow.StepTask;
import weka.knowledgeflow.StepTaskCallback;
import weka.knowledgeflow.steps.BaseStep;
import weka.knowledgeflow.steps.KFStep;
import weka.knowledgeflow.steps.Step;

/**
 * Knowledge Flow step that evaluates batch classifiers delivered over a
 * "batchClassifier" connection.
 *
 * <p>Each incoming fold/set is evaluated on its test set by an
 * {@link EvaluationTask} submitted to the execution environment. When a training
 * set accompanies the payload, it is used to initialise the {@link Evaluation}
 * object. Otherwise the test set is used and priors are disabled.
 *
 * <p>Finished tasks are merged into one {@link AggregateableEvaluation}. Once the
 * number of merged folds/sets equals the "aux_max_set_num" value carried by the
 * first payload after {@link #stepInit()}, the step emits:
 * <ul>
 * <li>a textual summary on "text";</li>
 * <li>error-plot data on "visualizableError" and a threshold curve on
 * "thresholdData", when prediction collection is enabled and those connections
 * exist.</li>
 * </ul>
 */
@KFStep(name="ClassifierPerformanceEvaluator", category="Evaluation", toolTipText="Evaluates batch classifiers", iconPath="weka/gui/knowledgeflow/icons/ClassifierPerformanceEvaluator.gif")
public class ClassifierPerformanceEvaluator
extends BaseStep {
    private static final long serialVersionUID = -2679292079974676672L;
    /** Evaluation that accumulates the results of all completed folds/sets. */
    private transient AggregateableEvaluation m_eval;
    /** Error-plot instances accumulated across folds/sets; null until the first set is merged. */
    private transient Instances m_aggregatedPlotInstances;
    /** Point sizes, index-aligned with {@link #m_aggregatedPlotInstances}. */
    private transient ArrayList<Object> m_aggregatedPlotSizes;
    /** Point shapes, index-aligned with {@link #m_aggregatedPlotInstances}. */
    private transient ArrayList<Integer> m_aggregatedPlotShapes;
    protected boolean m_outputPerClassStats = true;
    protected boolean m_outputConfusionMatrix = true;
    protected boolean m_outputEntropyMetrics;
    protected boolean m_collectDataForVisAndAUC = true;
    protected boolean m_errorPlotPointSizeProportionalToMargin;
    protected boolean m_costSensitiveEval;
    /** Cost matrix text, parsed with {@link CostMatrix#parseMatlab(String)} in {@link #stepInit()}. */
    protected String m_costString = "";
    /** Matrix parsed by {@link #stepInit()}; null when cost-sensitive evaluation is off or no matrix text is set. */
    protected CostMatrix m_matrix;
    /** Comma-separated names of the metrics to display. */
    protected String m_selectedEvalMetrics = "";
    /** Metric names to display, derived from {@link #m_selectedEvalMetrics}. */
    protected List<String> m_metricsList = Evaluation.getAllEvaluationMetricNames();
    /** True until the first payload after {@link #stepInit()} has been processed. */
    protected boolean m_isReset;
    /** Number of folds/sets merged so far (counts up despite its name). */
    protected AtomicInteger m_setsToGo;
    /** Total number of folds/sets expected, read from "aux_max_set_num". */
    protected int m_maxSetNum;
    /** Number of evaluation tasks submitted whose callback has not yet completed. */
    protected AtomicInteger m_taskCount;
    private transient ClassifierErrorsPlotInstances m_PlotInstances = null;

    /**
     * Replaces the metric list with the trimmed entries of a comma-separated string.
     * A null or empty string leaves the list unchanged.
     *
     * @param l comma-separated metric names
     */
    protected void stringToList(String l) {
        if (l != null && l.length() > 0) {
            String[] parts = l.split(",");
            this.m_metricsList.clear();
            for (String s : parts) {
                this.m_metricsList.add(s.trim());
            }
        }
    }

    /** @param perClassStats true to output per-class precision/recall statistics */
    public void setOutputPerClassStats(boolean perClassStats) {
        this.m_outputPerClassStats = perClassStats;
    }

    @OptionMetadata(displayName="Output per-class stats", description="Output precision/recall and true/false positives for each class", displayOrder=1)
    public boolean getOutputPerClassStats() {
        return this.m_outputPerClassStats;
    }

    @OptionMetadata(displayName="Output confusion matrix", description="Output the matrix containing class confusions", displayOrder=2)
    public void setOutputConfusionMatrix(boolean outputConfusionMatrix) {
        this.m_outputConfusionMatrix = outputConfusionMatrix;
    }

    /** @return true if the confusion matrix is included in the text output */
    public boolean getOutputConfusionMatrix() {
        return this.m_outputConfusionMatrix;
    }

    @OptionMetadata(displayName="Output entropy evaluation measures", description="Output entropy-based evaluation measures", displayOrder=3)
    public void setOutputEntropyMetrics(boolean outputEntropyMetrics) {
        this.m_outputEntropyMetrics = outputEntropyMetrics;
    }

    /** @return true if entropy-based measures are included in the summary */
    public boolean getOutputEntropyMetrics() {
        return this.m_outputEntropyMetrics;
    }

    @OptionMetadata(displayName="Collect test data and predictions for visualization", description="Collect data and predictions in order to output visualizableError and thresholdData data", displayOrder=4)
    public void setCollectPredictionsForVisAndAUC(boolean collectPredictionsForVisAndAUC) {
        this.m_collectDataForVisAndAUC = collectPredictionsForVisAndAUC;
    }

    /** @return true if predictions are collected for error plots and threshold curves */
    public boolean getCollectPredictionsForVisAndAUC() {
        return this.m_collectDataForVisAndAUC;
    }

    @OptionMetadata(displayName="Error plot point size proportional to margin", description="Set the point size proportional to the prediction margin for classification error plots")
    public boolean getErrorPlotPointSizeProportionalToMargin() {
        return this.m_errorPlotPointSizeProportionalToMargin;
    }

    /** @param e2 true to scale error-plot point sizes by the prediction margin */
    public void setErrorPlotPointSizeProportionalToMargin(boolean e2) {
        this.m_errorPlotPointSizeProportionalToMargin = e2;
    }

    /** @return the comma-separated list of metrics to display */
    @ProgrammaticProperty
    public String getEvaluationMetricsToOutput() {
        return this.m_selectedEvalMetrics;
    }

    /**
     * Sets the metrics to display and updates the metric list accordingly.
     *
     * @param m comma-separated metric names
     */
    public void setEvaluationMetricsToOutput(String m) {
        this.m_selectedEvalMetrics = m;
        this.stringToList(m);
    }

    @ProgrammaticProperty
    public void setEvaluateWithRespectToCosts(boolean useCosts) {
        this.m_costSensitiveEval = useCosts;
    }

    /** @return true if evaluation uses the configured cost matrix */
    public boolean getEvaluateWithRespectToCosts() {
        return this.m_costSensitiveEval;
    }

    @ProgrammaticProperty
    public void setCostMatrixString(String cms) {
        this.m_costString = cms;
    }

    /** @return the cost matrix text */
    public String getCostMatrixString() {
        return this.m_costString;
    }

    /** Accepts a single "batchClassifier" connection. */
    @Override
    public List<String> getIncomingConnectionTypes() {
        ArrayList<String> result = new ArrayList<String>();
        if (this.getStepManager().numIncomingConnectionsOfType("batchClassifier") == 0) {
            result.add("batchClassifier");
        }
        return result;
    }

    /**
     * Once an input is connected, offers "text", plus "thresholdData" and
     * "visualizableError" when prediction collection is enabled.
     */
    @Override
    public List<String> getOutgoingConnectionTypes() {
        ArrayList<String> result = new ArrayList<String>();
        if (this.getStepManager().numIncomingConnections() > 0) {
            result.add("text");
            if (this.m_collectDataForVisAndAUC) {
                result.add("thresholdData");
                result.add("visualizableError");
            }
        }
        return result;
    }

    /**
     * Selects every built-in metric except "Coverage" and "Region size" for display.
     */
    public ClassifierPerformanceEvaluator() {
        this.m_metricsList.remove("Coverage");
        this.m_metricsList.remove("Region size");
        StringBuilder b = new StringBuilder();
        for (String s : this.m_metricsList) {
            b.append(s).append(",");
        }
        this.m_selectedEvalMetrics = b.substring(0, b.length() - 1);
    }

    /**
     * Resets per-run state and parses the cost matrix when cost-sensitive
     * evaluation is enabled and matrix text is present.
     *
     * @throws WekaException if the cost matrix text cannot be parsed
     */
    @Override
    public void stepInit() throws WekaException {
        this.m_isReset = true;
        this.m_PlotInstances = null;
        this.m_aggregatedPlotInstances = null;
        this.m_taskCount = new AtomicInteger(0);
        // Always re-derive the matrix so one parsed for an earlier run is not reused
        // after cost-sensitive evaluation is switched off or the text is cleared.
        this.m_matrix = null;
        if (this.m_costSensitiveEval && this.m_costString != null && this.m_costString.length() > 0) {
            try {
                this.m_matrix = CostMatrix.parseMatlab(this.getCostMatrixString());
            }
            catch (Exception e2) {
                throw new WekaException(e2);
            }
        }
    }

    /**
     * Requests a stop. If no evaluation tasks are outstanding, the step is reported
     * as interrupted immediately; otherwise the task callbacks report it.
     */
    @Override
    public void stop() {
        super.stop();
        if ((this.m_taskCount == null || this.m_taskCount.get() == 0) && this.isStopRequested()) {
            this.getStepManager().interrupted();
        }
    }

    /**
     * Schedules evaluation of one fold/set.
     *
     * <p>The first payload after {@link #stepInit()} also creates the aggregate
     * evaluation. That evaluation is initialised from the training set when one is
     * present, and otherwise from the test set with priors disabled. Payloads with
     * no test instances are logged and ignored.
     *
     * @param data payload carrying the classifier, data sets, set number and label
     * @throws WekaException wrapping any failure while scheduling
     */
    @Override
    public synchronized void processIncoming(Data data) throws WekaException {
        try {
            int setNum = (Integer)data.getPayloadElement("aux_set_num");
            Instances trainingData = (Instances)data.getPayloadElement("aux_trainingSet");
            Instances testData = (Instances)data.getPayloadElement("aux_testsSet");
            if (testData == null || testData.numInstances() == 0) {
                this.getStepManager().logDetailed("No test set available - unable to evaluate");
                return;
            }
            Classifier classifier = (Classifier)data.getPayloadElement("batchClassifier");
            // The label is optional; the result text is built without it when absent.
            Object label = data.getPayloadElement("aux_label");
            String evalLabel = label != null ? label.toString() : null;
            if (classifier == null) {
                throw new WekaException("Classifier is null!!");
            }
            CostMatrix costMatrix = this.m_costSensitiveEval ? this.m_matrix : null;
            if (this.m_isReset) {
                this.m_isReset = false;
                this.getStepManager().processing();
                this.m_maxSetNum = (Integer)data.getPayloadElement("aux_max_set_num");
                this.m_setsToGo = new AtomicInteger(0);
                Instances evalData = trainingData != null ? trainingData : testData;
                Evaluation eval = new Evaluation(evalData, costMatrix);
                this.m_PlotInstances = ExplorerDefaults.getClassifierErrorsPlotInstances();
                this.m_PlotInstances.setInstances(evalData);
                this.m_PlotInstances.setClassifier(classifier);
                this.m_PlotInstances.setClassIndex(evalData.classIndex());
                this.m_PlotInstances.setEvaluation(eval);
                eval = ClassifierPerformanceEvaluator.adjustForInputMappedClassifier(eval, classifier, evalData, this.m_PlotInstances, costMatrix);
                if (trainingData == null) {
                    eval.useNoPriors();
                }
                this.m_eval = new AggregateableEvaluation(eval);
                this.m_eval.setMetricsToDisplay(this.m_metricsList);
                this.m_PlotInstances.setUp();
                this.m_aggregatedPlotInstances = null;
            }
            if (!this.isStopRequested()) {
                this.getStepManager().logBasic("Scheduling evaluation of fold/set " + setNum + " for execution");
                EvaluationTask evalTask = new EvaluationTask(this, classifier, trainingData, testData, setNum, this.m_metricsList, this.getErrorPlotPointSizeProportionalToMargin(), evalLabel, new EvaluationCallback(), costMatrix, this.getCollectPredictionsForVisAndAUC());
                // Count the task before handing it over: its callbacks decrement the
                // counter, and the unsynchronized failure callback could otherwise run
                // before the increment.
                this.m_taskCount.incrementAndGet();
                try {
                    this.getStepManager().getExecutionEnvironment().submitTask(evalTask);
                }
                catch (Exception submitFailure) {
                    this.m_taskCount.decrementAndGet();
                    throw submitFailure;
                }
            } else {
                this.getStepManager().interrupted();
            }
        }
        catch (Exception ex) {
            throw new WekaException(ex);
        }
    }

    /**
     * Merges one finished fold/set into the aggregate. When all expected sets have
     * been merged, emits the results and marks the step finished.
     *
     * @param eval2 evaluation of the finished fold/set
     * @param classifier classifier that was evaluated (used for titles/options)
     * @param testData test set of the finished fold/set
     * @param plotInstances plot data for the fold/set; only read when collecting predictions
     * @param setNum number of the finished fold/set
     * @param evalLabel optional label prefixed to the result title; may be null
     * @throws Exception if aggregation or output fails
     */
    protected synchronized void aggregateEvalTask(Evaluation eval2, Classifier classifier, Instances testData, ClassifierErrorsPlotInstances plotInstances, int setNum, String evalLabel) throws Exception {
        this.m_eval.aggregate(eval2);
        if (this.getCollectPredictionsForVisAndAUC()) {
            this.accumulatePlotData(plotInstances);
        }
        this.getStepManager().statusMessage("Completed folds/sets " + this.m_setsToGo.incrementAndGet());
        if (this.m_setsToGo.get() == this.m_maxSetNum) {
            this.outputResults(classifier, testData, evalLabel);
            this.getStepManager().finished();
        }
        if (this.isStopRequested()) {
            this.getStepManager().interrupted();
        }
    }

    /** Appends one fold/set's plot instances, shapes and sizes to the aggregated plot data. */
    private void accumulatePlotData(ClassifierErrorsPlotInstances plotInstances) {
        // Snapshot shapes/sizes before calling getPlotInstances(), as the original
        // code did: that call may finalise (and rescale) the per-set plot data.
        if (this.m_aggregatedPlotInstances == null) {
            this.m_aggregatedPlotShapes = new ArrayList<Integer>(plotInstances.getPlotShapes());
            this.m_aggregatedPlotSizes = new ArrayList<Object>(plotInstances.getPlotSizes());
            this.m_aggregatedPlotInstances = new Instances(plotInstances.getPlotInstances());
            return;
        }
        List<Object> sizes = new ArrayList<Object>(plotInstances.getPlotSizes());
        List<Integer> shapes = new ArrayList<Integer>(plotInstances.getPlotShapes());
        Instances setInstances = plotInstances.getPlotInstances();
        for (int i = 0; i < setInstances.numInstances(); ++i) {
            this.m_aggregatedPlotInstances.add(setInstances.get(i));
            this.m_aggregatedPlotShapes.add(shapes.get(i));
            this.m_aggregatedPlotSizes.add(sizes.get(i));
        }
    }

    /** Emits the text summary and, where enabled and connected, error-plot and threshold data. */
    private void outputResults(Classifier classifier, Instances testData, String evalLabel) throws Exception {
        if (this.getCollectPredictionsForVisAndAUC()) {
            AggregateableClassifierErrorsPlotInstances aggPlot = new AggregateableClassifierErrorsPlotInstances();
            aggPlot.setInstances(testData);
            aggPlot.setPlotInstances(this.m_aggregatedPlotInstances);
            aggPlot.setPlotShapes(this.m_aggregatedPlotShapes);
            aggPlot.setPlotSizes(this.m_aggregatedPlotSizes);
            aggPlot.setPointSizeProportionalToMargin(this.m_errorPlotPointSizeProportionalToMargin);
            // Result deliberately unused. NOTE(review): presumably this call triggers
            // finishUp() below, which may rescale the shared aggregated point sizes
            // in place - TODO confirm against AbstractPlotInstances.
            aggPlot.getPlotInstances();
        }
        String textTitle = classifier.getClass().getName();
        textTitle = textTitle.substring(textTitle.lastIndexOf('.') + 1);
        String textOptions = classifier instanceof OptionHandler ? Utils.joinOptions(((OptionHandler)classifier).getOptions()) : "";
        if (evalLabel != null && evalLabel.length() > 0 && !textTitle.toLowerCase().startsWith(evalLabel.toLowerCase())) {
            textTitle = evalLabel + " : " + textTitle;
        }
        // m_matrix was parsed (and validated) in stepInit(); re-parsing an empty cost
        // string here would fail.
        CostMatrix cm = this.m_costSensitiveEval ? this.m_matrix : null;
        StringBuilder resultT = new StringBuilder("=== Evaluation result ===\n\nScheme: ").append(textTitle).append("\n");
        if (textOptions.length() > 0) {
            resultT.append("Options: ").append(textOptions).append("\n");
        }
        resultT.append("Relation: ").append(testData.relationName()).append("\n\n");
        if (cm != null) {
            resultT.append("Cost matrix:\n").append(cm.toString()).append("\n");
        }
        resultT.append(this.m_eval.toSummaryString(this.getOutputEntropyMetrics()));
        if (testData.classAttribute().isNominal()) {
            if (this.getOutputPerClassStats()) {
                resultT.append("\n").append(this.m_eval.toClassDetailsString());
            }
            if (this.getOutputConfusionMatrix()) {
                resultT.append("\n").append(this.m_eval.toMatrixString());
            }
        }
        Data text = new Data("text");
        text.setPayloadElement("text", resultT.toString());
        text.setPayloadElement("aux_textTitle", textTitle);
        this.getStepManager().outputData(text);
        if (this.getCollectPredictionsForVisAndAUC() && this.getStepManager().numOutgoingConnectionsOfType("visualizableError") > 0) {
            PlotData2D errorD = new PlotData2D(this.m_aggregatedPlotInstances);
            errorD.setShapeSize(this.m_aggregatedPlotSizes);
            errorD.setShapeType(this.m_aggregatedPlotShapes);
            errorD.setPlotName(textTitle + " " + textOptions);
            Data visErr = new Data("visualizableError");
            visErr.setPayloadElement("visualizableError", errorD);
            this.getStepManager().outputData(visErr);
        }
        if (testData.classAttribute().isNominal() && this.getCollectPredictionsForVisAndAUC() && this.getStepManager().numOutgoingConnectionsOfType("thresholdData") > 0) {
            this.outputThresholdData(classifier, testData, textTitle);
        }
    }

    /** Emits the threshold curve for the first class value (index 0) of the aggregated predictions. */
    private void outputThresholdData(Classifier classifier, Instances testData, String textTitle) throws Exception {
        ThresholdCurve tc = new ThresholdCurve();
        Instances result = tc.getCurve(this.m_eval.predictions(), 0);
        result.setRelationName(testData.relationName());
        PlotData2D pd = new PlotData2D(result);
        String newOptions = classifier instanceof OptionHandler ? ClassifierPerformanceEvaluator.optionsAsHtml(((OptionHandler)classifier).getOptions()) : "";
        String firstClassValue = testData.classAttribute().value(0);
        String htmlTitle = "<html><font size=-2>" + textTitle + " " + newOptions + "<br> (class: " + firstClassValue + ")</font></html>";
        pd.setPlotName(textTitle + " (class: " + firstClassValue + ")");
        pd.setPlotNameHTML(htmlTitle);
        // Connect each point to its predecessor; the first point has none.
        boolean[] connectPoints = new boolean[result.numInstances()];
        for (int jj = 1; jj < connectPoints.length; ++jj) {
            connectPoints[jj] = true;
        }
        pd.setConnectPoints(connectPoints);
        Data threshData = new Data("thresholdData");
        threshData.setPayloadElement("thresholdData", pd);
        threshData.setPayloadElement("class_attribute", testData.classAttribute());
        this.getStepManager().outputData(threshData);
    }

    /**
     * Concatenates non-empty options, inserting "&lt;br&gt;" before each option flag.
     * A "-" followed by a digit is treated as a negative number, not a flag. A lone
     * "-" no longer triggers a StringIndexOutOfBoundsException.
     */
    private static String optionsAsHtml(String[] options) {
        StringBuilder html = new StringBuilder();
        for (String option : options) {
            if (option.length() == 0) {
                continue;
            }
            if (option.length() > 1 && option.charAt(0) == '-' && (option.charAt(1) < '0' || option.charAt(1) > '9')) {
                html.append("<br>");
            }
            html.append(option);
        }
        return html.toString();
    }

    @Override
    public String getCustomEditorForStep() {
        return "weka.gui.knowledgeflow.steps.ClassifierPerformanceEvaluatorStepEditorDialog";
    }

    /**
     * Adapts an evaluation for an {@link InputMappedClassifier} whose model header
     * differs from the data.
     *
     * <p>In that case a new Evaluation is built on the model header, using the
     * supplied cost matrix. Its priors come from the data mapped into the model's
     * format, and {@code plotInstances} (if non-null) is re-pointed at the mapped
     * data and new evaluation. In every other case {@code eval2} is returned
     * unchanged, keeping its priors and cost matrix.
     *
     * @param eval2 evaluation built on {@code inst}
     * @param classifier classifier being evaluated
     * @param inst data the evaluation is based on
     * @param plotInstances plot helper to update, or null
     * @param matrix cost matrix for a replacement evaluation, or null
     * @return the evaluation to use
     * @throws Exception if building the header, mapping or evaluation fails
     */
    protected static Evaluation adjustForInputMappedClassifier(Evaluation eval2, Classifier classifier, Instances inst, ClassifierErrorsPlotInstances plotInstances, CostMatrix matrix) throws Exception {
        if (!(classifier instanceof InputMappedClassifier)) {
            return eval2;
        }
        InputMappedClassifier mappedClassifier = (InputMappedClassifier)classifier;
        Instances mappedClassifierHeader = mappedClassifier.getModelHeader(new Instances(inst, 0));
        if (mappedClassifierHeader.equalHeaders(inst)) {
            // No mapping needed: the caller's evaluation already matches the data.
            return eval2;
        }
        Evaluation mappedEval = new Evaluation(new Instances(mappedClassifierHeader, 0), matrix);
        Instances mappedClassifierDataset = mappedClassifier.getModelHeader(new Instances(mappedClassifierHeader, 0));
        for (int zz = 0; zz < inst.numInstances(); ++zz) {
            Instance mapped = mappedClassifier.constructMappedInstance(inst.instance(zz));
            mappedClassifierDataset.add(mapped);
        }
        mappedEval.setPriors(mappedClassifierDataset);
        if (plotInstances != null) {
            plotInstances.setInstances(mappedClassifierDataset);
            plotInstances.setClassifier(classifier);
            plotInstances.setClassIndex(mappedClassifierDataset.classIndex());
            plotInstances.setEvaluation(mappedEval);
        }
        return mappedEval;
    }

    /**
     * Receives the results of {@link EvaluationTask}s. It forwards successful ones to
     * {@link #aggregateEvalTask} and always releases the task from
     * {@link #m_taskCount}.
     */
    protected class EvaluationCallback
    implements StepTaskCallback<Object[]> {
        protected EvaluationCallback() {
        }

        @Override
        public void taskFinished(ExecutionResult<Object[]> result) throws Exception {
            try {
                if (!ClassifierPerformanceEvaluator.this.isStopRequested()) {
                    Object[] r = result.getResult();
                    Evaluation eval2 = (Evaluation)r[0];
                    Classifier classifier = (Classifier)r[1];
                    Instances testData = (Instances)r[2];
                    ClassifierErrorsPlotInstances plotInstances = (ClassifierErrorsPlotInstances)r[3];
                    int setNum = (Integer)r[4];
                    String evalLabel = r[5] != null ? r[5].toString() : null;
                    ClassifierPerformanceEvaluator.this.aggregateEvalTask(eval2, classifier, testData, plotInstances, setNum, evalLabel);
                } else {
                    ClassifierPerformanceEvaluator.this.getStepManager().interrupted();
                }
            }
            finally {
                ClassifierPerformanceEvaluator.this.m_taskCount.decrementAndGet();
            }
        }

        @Override
        public void taskFailed(StepTask<Object[]> failedTask, ExecutionResult<Object[]> failedResult) throws Exception {
            try {
                Object[] r = failedResult.getResult();
                Object setNum = r != null ? r[4] : null;
                ClassifierPerformanceEvaluator.this.getStepManager().logError("Evaluation for fold " + setNum + " failed", failedResult.getError());
            }
            finally {
                ClassifierPerformanceEvaluator.this.m_taskCount.decrementAndGet();
            }
        }
    }

    /**
     * Evaluates one classifier on one fold/set.
     *
     * <p>The result array holds: [0] Evaluation, [1] Classifier, [2] test Instances,
     * [3] ClassifierErrorsPlotInstances (null unless collecting predictions),
     * [4] set number, and [5] label.
     */
    protected static class EvaluationTask
    extends StepTask<Object[]> {
        private static final long serialVersionUID = -686972773536075889L;
        protected Classifier m_classifier;
        protected CostMatrix m_cMatrix;
        protected Instances m_trainData;
        protected Instances m_testData;
        protected int m_setNum;
        protected List<String> m_metricsList;
        protected boolean m_errPlotPtSizePropToMarg;
        protected String m_evalLabel;
        protected String m_classifierDesc = "";
        protected boolean m_collectPreds;

        public EvaluationTask(Step source, Classifier classifier, Instances trainData, Instances testData, int setNum, List<String> metricsList, boolean errPlotPtSizePropToMarg, String evalLabel, EvaluationCallback callback, CostMatrix matrix, boolean collectPreds) {
            super(source, callback);
            this.m_classifier = classifier;
            this.m_cMatrix = matrix;
            this.m_trainData = trainData;
            this.m_testData = testData;
            this.m_setNum = setNum;
            this.m_metricsList = metricsList;
            this.m_errPlotPtSizePropToMarg = errPlotPtSizePropToMarg;
            this.m_evalLabel = evalLabel;
            this.m_collectPreds = collectPreds;
            // getCanonicalName() is null for anonymous/local classes; fall back to getName().
            Class<?> classifierClass = this.m_classifier.getClass();
            String className = classifierClass.getCanonicalName() != null ? classifierClass.getCanonicalName() : classifierClass.getName();
            this.m_classifierDesc = className.substring(className.lastIndexOf(".") + 1);
            if (this.m_classifier instanceof OptionHandler) {
                String optsString = Utils.joinOptions(((OptionHandler)this.m_classifier).getOptions());
                this.m_classifierDesc = this.m_classifierDesc + " " + optsString;
            }
        }

        @Override
        public void process() throws Exception {
            Object[] r = new Object[6];
            // Publish the set number first so the failure callback can report it.
            r[4] = this.m_setNum;
            this.getExecutionResult().setResult(r);
            this.getLogHandler().statusMessage("Evaluating " + this.m_classifierDesc + " on fold/set " + this.m_setNum);
            this.getLogHandler().logDetailed("Evaluating " + this.m_classifierDesc + " on " + this.m_testData.relationName() + " fold/set " + this.m_setNum);
            ClassifierErrorsPlotInstances plotInstances = this.m_collectPreds ? ExplorerDefaults.getClassifierErrorsPlotInstances() : null;
            // Without a training set the evaluation is initialised from the test set
            // and priors are disabled.
            Instances evalData = this.m_trainData != null ? this.m_trainData : this.m_testData;
            Evaluation eval2 = new Evaluation(evalData, this.m_cMatrix);
            if (this.m_collectPreds) {
                plotInstances.setInstances(evalData);
                plotInstances.setClassifier(this.m_classifier);
                plotInstances.setClassIndex(evalData.classIndex());
                plotInstances.setEvaluation(eval2);
                plotInstances.setPointSizeProportionalToMargin(this.m_errPlotPtSizePropToMarg);
            }
            eval2 = ClassifierPerformanceEvaluator.adjustForInputMappedClassifier(eval2, this.m_classifier, evalData, plotInstances, this.m_cMatrix);
            if (this.m_trainData == null) {
                eval2.useNoPriors();
            }
            eval2.setMetricsToDisplay(this.m_metricsList);
            eval2.setDiscardPredictions(!this.m_collectPreds);
            if (this.m_collectPreds) {
                plotInstances.setUp();
            }
            if (this.m_classifier instanceof BatchPredictor && ((BatchPredictor)this.m_classifier).implementsMoreEfficientBatchPrediction()) {
                double[][] predictions = ((BatchPredictor)this.m_classifier).distributionsForInstances(this.m_testData);
                if (this.m_collectPreds) {
                    plotInstances.process(this.m_testData, predictions, eval2);
                } else {
                    for (int i = 0; i < this.m_testData.numInstances(); ++i) {
                        eval2.evaluationForSingleInstance(predictions[i], this.m_testData.instance(i), false);
                    }
                }
            } else {
                for (int i = 0; i < this.m_testData.numInstances(); ++i) {
                    Instance temp = this.m_testData.instance(i);
                    if (this.m_collectPreds) {
                        plotInstances.process(temp, this.m_classifier, eval2);
                    } else {
                        eval2.evaluateModelOnce(this.m_classifier, temp);
                    }
                }
            }
            r[0] = eval2;
            r[1] = this.m_classifier;
            r[2] = this.m_testData;
            r[3] = this.m_collectPreds ? plotInstances : null;
            r[5] = this.m_evalLabel;
        }
    }

    /**
     * Plot-instances holder whose plot data, shapes and sizes are injected directly
     * (the aggregated data of all folds/sets) rather than computed from predictions.
     */
    protected static class AggregateableClassifierErrorsPlotInstances
    extends ClassifierErrorsPlotInstances {
        private static final long serialVersionUID = 2012744784036684168L;

        protected AggregateableClassifierErrorsPlotInstances() {
        }

        @Override
        public void setPlotShapes(ArrayList<Integer> plotShapes) {
            this.m_PlotShapes = plotShapes;
        }

        @Override
        public void setPlotSizes(ArrayList<Object> plotSizes) {
            this.m_PlotSizes = plotSizes;
        }

        public void setPlotInstances(Instances inst) {
            this.m_PlotInstances = inst;
        }

        /**
         * Marks finish-up as done. When saving for visualization, scales the point
         * sizes if the class is numeric or sizes are proportional to the margin.
         */
        @Override
        protected void finishUp() {
            this.m_FinishUpCalled = true;
            if (!this.m_SaveForVisualization) {
                return;
            }
            if (this.m_Instances.classAttribute().isNumeric() || this.m_pointSizeProportionalToMargin) {
                this.scaleNumericPredictions();
            }
        }
    }
}

