package dm.algorithms;

import dm.data.DataObject;
import dm.data.database.Database;
import dm.data.featureVector.FeatureVector;
import dm.data.text.FeatureSelector;
import dm.data.texttype.TextDoc;
import dm.util.MathUtil;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Iterator;
import java.util.Map;
import weka.classifiers.functions.Logistic;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;

/* loaded from: input_file:dm/algorithms/WekaLogistic.class */
/**
 * {@link Classifier} implementation that delegates to Weka's {@link Logistic} model.
 *
 * <p>An instance is either trained from one {@link Database} per class (the position of the
 * database in the array is used as the class label) or restored from a file previously written
 * by {@link #saveClassifier(String)}. Text documents are mapped onto attributes through a
 * {@link FeatureSelector}; feature vectors are copied positionally.
 */
public class WekaLogistic implements Classifier {
    /** Suffix appended to the base name given to the load constructor and to saveClassifier. */
    private static final String FILE_SUFFIX = ".clf";

    private Logistic function;
    private Instances instances;
    private FeatureSelector fs;
    private int classCount;
    private FastVector header;

    /**
     * Builds the nominal class attribute named "class" whose labels are "0" .. "numClasses - 1".
     */
    private static Attribute buildClassAttribute(int numClasses) {
        FastVector labels = new FastVector();
        for (int label = 0; label < numClasses; label++) {
            labels.addElement(Integer.toString(label));
        }
        return new Attribute("class", labels);
    }

    /**
     * Wraps the given attribute values in an instance bound to {@code dataset} and sets its
     * last attribute to the string form of {@code classLabel}.
     */
    private static Instance toLabelledInstance(double[] values, Instances dataset, int classLabel) {
        Instance instance = new Instance(1.0d, values);
        instance.setDataset(dataset);
        instance.setValue(dataset.numAttributes() - 1, Integer.toString(classLabel));
        return instance;
    }

    /** Sums {@code getCount()} over all databases; used to pre-size the training set. */
    private static int countObjects(Database[] databaseArr) {
        int total = 0;
        for (Database database : databaseArr) {
            total += database.getCount();
        }
        return total;
    }

    /**
     * Marks the last attribute of the collected instances as the class attribute and fits a
     * fresh {@link Logistic} model on them. A training failure is printed and otherwise
     * ignored, which leaves the object holding an unbuilt model.
     */
    private void train() {
        this.instances.setClassIndex(this.instances.numAttributes() - 1);
        this.function = new Logistic();
        try {
            this.function.buildClassifier(this.instances);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Creates the attribute header for text data: one numeric attribute per word reported by
     * the feature selector, followed by the class attribute. As a side effect each word is
     * registered with the selector under its attribute position via {@code fs.update}.
     *
     * @param i number of classes
     */
    private FastVector getAttributeInfo(int i) {
        FastVector attributes = new FastVector(this.fs.getWordNumber());
        int position = 0;
        for (String word : this.fs.getWords()) {
            attributes.addElement(new Attribute(word));
            this.fs.update(word, position);
            position++;
        }
        attributes.addElement(buildClassAttribute(i));
        return attributes;
    }

    /**
     * Creates the attribute header for vector data: {@code i2} numeric attributes named
     * "Attribute0", "Attribute1", ... followed by the class attribute.
     *
     * @param i  number of classes
     * @param i2 number of numeric attributes
     */
    private FastVector getAttributeInfo(int i, int i2) {
        FastVector attributes = new FastVector(i2);
        for (int position = 0; position < i2; position++) {
            attributes.addElement(new Attribute("Attribute" + position));
        }
        attributes.addElement(buildClassAttribute(i));
        return attributes;
    }

    /**
     * Converts a text document into an instance. Words for which the feature selector returns
     * no index are skipped; all other attribute values default to 0.
     *
     * @param textDoc   document whose word/value entries are read
     * @param instances dataset that supplies the attribute layout
     * @param i         class label to store in the last attribute
     */
    private Instance generateInstancefromTextDoc(TextDoc textDoc, Instances instances, int i) {
        double[] values = new double[instances.numAttributes()];
        for (Map.Entry entry : textDoc.getWords().entrySet()) {
            Integer index = this.fs.getIndex((String) entry.getKey());
            if (index != null) {
                values[index.intValue()] = ((Double) entry.getValue()).doubleValue();
            }
        }
        return toLabelledInstance(values, instances, i);
    }

    /**
     * Converts a feature vector into an instance by copying its values positionally.
     *
     * @param featureVector vector whose {@code values} are copied
     * @param instances     dataset that supplies the attribute layout
     * @param i             class label to store in the last attribute
     */
    private Instance generateInstancefromEuc(FeatureVector featureVector, Instances instances, int i) {
        double[] values = new double[instances.numAttributes()];
        // Element-wise copy on purpose: the element type of featureVector.values is not
        // visible here, so a widening assignment is the only safe way to fill the array.
        for (int position = 0; position < featureVector.values.length; position++) {
            values[position] = featureVector.values[position];
        }
        return toLabelledInstance(values, instances, i);
    }

    /**
     * Trains a model on text documents.
     *
     * @param databaseArr     one database per class; the array index is the class label and
     *                        every stored object is cast to {@link TextDoc}
     * @param featureSelector selector that supplies the vocabulary and word indices
     * @param i               not used by this constructor
     */
    public WekaLogistic(Database[] databaseArr, FeatureSelector featureSelector, int i) {
        this.fs = featureSelector;
        this.classCount = databaseArr.length;
        this.header = getAttributeInfo(databaseArr.length);
        this.instances = new Instances("Dataset", this.header, countObjects(databaseArr));
        for (int label = 0; label < databaseArr.length; label++) {
            Iterator objectIterator = databaseArr[label].objectIterator();
            while (objectIterator.hasNext()) {
                TextDoc doc = (TextDoc) objectIterator.next();
                this.instances.add(generateInstancefromTextDoc(doc, this.instances, label));
            }
        }
        train();
    }

    /**
     * Restores a classifier from the file {@code str + ".clf"}, reading the fields in the order
     * in which {@link #saveClassifier(String)} writes them. On any failure the stack trace is
     * printed and the JVM is terminated with exit status 1.
     *
     * <p>NOTE(review): this uses native Java deserialization; only load files from a trusted
     * source.
     *
     * @param str base file name without the ".clf" suffix
     */
    public WekaLogistic(String str) {
        try {
            FileInputStream fileInputStream = new FileInputStream(str + FILE_SUFFIX);
            try {
                ObjectInputStream objectInputStream = new ObjectInputStream(fileInputStream);
                this.classCount = objectInputStream.readInt();
                this.fs = (FeatureSelector) objectInputStream.readObject();
                this.instances = (Instances) objectInputStream.readObject();
                this.function = (Logistic) objectInputStream.readObject();
                this.header = (FastVector) objectInputStream.readObject();
                objectInputStream.close();
            } finally {
                // Releases the file handle even when the object stream could not be created.
                fileInputStream.close();
            }
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Trains a model on feature vectors. No feature selector is set by this constructor.
     *
     * @param databaseArr one database per class; the array index is the class label and every
     *                    stored object is cast to {@link FeatureVector}
     * @param i           number of numeric attributes in the header
     */
    public WekaLogistic(Database[] databaseArr, int i) {
        this.classCount = databaseArr.length;
        this.header = getAttributeInfo(databaseArr.length, i);
        this.instances = new Instances("Dataset", this.header, countObjects(databaseArr));
        for (int label = 0; label < databaseArr.length; label++) {
            Iterator objectIterator = databaseArr[label].objectIterator();
            while (objectIterator.hasNext()) {
                FeatureVector vector = (FeatureVector) objectIterator.next();
                this.instances.add(generateInstancefromEuc(vector, this.instances, label));
            }
        }
        train();
    }

    /**
     * Returns the model's class distribution for a feature vector.
     *
     * @return the distribution, or an all-zero array of length {@code classCount} if the
     *         conversion or the model throws (the stack trace is printed)
     */
    public double[] getDistribution(FeatureVector featureVector) {
        try {
            Instance instance = generateInstancefromEuc(featureVector, this.instances, 0);
            return this.function.distributionForInstance(instance);
        } catch (Exception e) {
            e.printStackTrace();
            return new double[this.classCount];
        }
    }

    /**
     * Returns the model's class distribution for a text document.
     *
     * @return the distribution, or an all-zero array of length {@code classCount} if the
     *         conversion or the model throws (the stack trace is printed)
     */
    public double[] getDistribution(TextDoc textDoc) {
        try {
            Instance instance = generateInstancefromTextDoc(textDoc, this.instances, 0);
            return this.function.distributionForInstance(instance);
        } catch (Exception e) {
            e.printStackTrace();
            return new double[this.classCount];
        }
    }

    /** Returns the feature selector, which is unset for models trained on feature vectors. */
    public FeatureSelector getFeatureSelector() {
        return this.fs;
    }

    /** Returns the index of the largest entry of the distribution for {@code dataObject}. */
    @Override // dm.algorithms.Classifier
    public int classify(DataObject dataObject) {
        return MathUtil.argmax(getDistribution(dataObject));
    }

    /**
     * Dispatches to the {@link TextDoc} or {@link FeatureVector} overload according to the
     * runtime type of the argument.
     *
     * @return the class distribution, or an all-zero array of length {@code classCount} when
     *         the argument is null or of neither supported type
     */
    @Override // dm.algorithms.Classifier
    public double[] getDistribution(DataObject dataObject) {
        // Explicit type tests instead of catching ClassCastException: no exception is
        // created on the normal FeatureVector path.
        if (dataObject instanceof TextDoc) {
            return getDistribution((TextDoc) dataObject);
        }
        if (dataObject instanceof FeatureVector) {
            return getDistribution((FeatureVector) dataObject);
        }
        String typeName = dataObject == null ? "null" : dataObject.getClass().getName();
        System.err.println("WekaLogistic: unsupported data object type: " + typeName);
        return new double[this.classCount];
    }

    /** Returns the number of classes the model was trained with. */
    public int getClassNumber() {
        return this.classCount;
    }

    /**
     * Writes the classifier state to the file {@code str + ".clf"} in the order: class count,
     * feature selector, instances, model, header. An {@link IOException} is printed and
     * otherwise ignored; the file handle is released in every case.
     *
     * @param str base file name without the ".clf" suffix
     */
    @Override // dm.algorithms.Classifier
    public void saveClassifier(String str) {
        try {
            System.out.println("Saving " + str + ".clf");
            FileOutputStream fileOutputStream = new FileOutputStream(str + FILE_SUFFIX);
            try {
                ObjectOutputStream objectOutputStream = new ObjectOutputStream(fileOutputStream);
                objectOutputStream.writeInt(this.classCount);
                objectOutputStream.writeObject(this.fs);
                objectOutputStream.writeObject(this.instances);
                objectOutputStream.writeObject(this.function);
                objectOutputStream.writeObject(this.header);
                objectOutputStream.flush();
                objectOutputStream.close();
            } finally {
                // Runs on the failure path too, so a failed write no longer leaks the handle.
                fileOutputStream.close();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
