package org.cleartk.classifier.feature.transform.extractor;

import com.google.common.collect.LinkedHashMultiset;
import com.google.common.collect.Multiset;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.tcas.Annotation;
import org.cleartk.classifier.Feature;
import org.cleartk.classifier.Instance;
import org.cleartk.classifier.feature.extractor.CleartkExtractorException;
import org.cleartk.classifier.feature.extractor.simple.SimpleFeatureExtractor;
import org.cleartk.classifier.feature.transform.OneToOneTrainableExtractor_ImplBase;
import org.cleartk.classifier.feature.transform.TransformableFeature;

/* loaded from: input_file:org/cleartk/classifier/feature/transform/extractor/TfidfExtractor.class */
/**
 * Trainable extractor that re-weights the count-valued features of a sub-extractor by TF-IDF.
 *
 * <p>Before training, {@link #extract} wraps the sub-extractor's features in a single
 * {@link TransformableFeature}. {@link #train} (or {@link #load}) builds document-frequency
 * statistics from those wrapped features. After that, {@link #extract} emits one
 * {@code "TF-IDF_<name>"} feature per sub-feature, with value {@code count * idf(name)}.
 */
public class TfidfExtractor<OUTCOME_T> extends OneToOneTrainableExtractor_ImplBase<OUTCOME_T> implements SimpleFeatureExtractor {
    /** Extractor whose numeric features are re-weighted; {@code null} when built via the one-argument constructor. */
    protected SimpleFeatureExtractor subExtractor;
    /** True once IDF statistics are available (after {@link #train} or {@link #load}). */
    protected boolean isTrained;
    /** Document-frequency statistics used to compute IDF weights. */
    protected IDFMap idfMap;

    /**
     * Document-frequency table.
     *
     * <p>For each feature name it stores the number of documents (instances) in which that name
     * occurred. It also stores the total number of documents seen.
     */
    // NOTE(review): the decompiler reports this class was originally declared protected.
    public static class IDFMap {
        private Multiset<String> documentFreqMap = LinkedHashMultiset.create();
        private int totalDocumentCount = 0;

        /** Records one more document that contains the feature name {@code str}. */
        public void add(String str) {
            this.documentFreqMap.add(str);
        }

        /** Increments the total number of documents seen. */
        public void incTotalDocumentCount() {
            this.totalDocumentCount++;
        }

        /** Returns the total number of documents seen. */
        public int getTotalDocumentCount() {
            return this.totalDocumentCount;
        }

        /** Returns the number of documents containing {@code str} (0 if never seen). */
        public int getDF(String str) {
            return this.documentFreqMap.count(str);
        }

        /**
         * Returns the add-one smoothed inverse document frequency, {@code log((N + 1) / (df + 1))}.
         *
         * <p>The value is finite even for unseen names (df == 0). It is non-negative whenever
         * {@code df <= N}.
         *
         * @param str feature name
         * @return natural-log IDF weight
         */
        public double getIDF(String str) {
            // Divide in floating point: int/int division truncated the ratio (e.g. 11/4 == 2),
            // yielding coarse and wrong IDF values.
            return Math.log((this.totalDocumentCount + 1) / (double) (getDF(str) + 1));
        }

        /**
         * Writes the statistics to {@code uri} as tab-separated text.
         *
         * <p>The first line is {@code NUM DOCUMENTS<TAB>N}. Each following line is
         * {@code name<TAB>df}. The file uses the platform default charset (via {@link FileWriter}).
         *
         * @throws IOException if the file cannot be written
         */
        public void save(URI uri) throws IOException {
            BufferedWriter writer = new BufferedWriter(new FileWriter(new File(uri)));
            try {
                writer.append(String.format(Locale.ROOT, "NUM DOCUMENTS\t%d\n", this.totalDocumentCount));
                for (Multiset.Entry<String> entry : this.documentFreqMap.entrySet()) {
                    writer.append(String.format(Locale.ROOT, "%s\t%d\n", entry.getElement(), entry.getCount()));
                }
            } finally {
                // Always release the file handle, even if a write fails.
                writer.close();
            }
        }

        /**
         * Replaces the current statistics with those read from a file in the {@link #save} format.
         *
         * <p>Existing counts are discarded rather than added to. On any failure the current
         * statistics are left unchanged.
         *
         * @throws IOException if the file cannot be read, is empty, or a line has no tab
         * @throws NumberFormatException if a count field is not an integer
         */
        public void load(URI uri) throws IOException {
            BufferedReader reader = new BufferedReader(new FileReader(new File(uri)));
            try {
                String header = reader.readLine();
                if (header == null) {
                    throw new IOException("Empty IDF file: " + uri);
                }
                int total = Integer.parseInt(header.substring(lastTabIndex(header, uri) + 1));

                Multiset<String> counts = LinkedHashMultiset.create();
                String line;
                while ((line = reader.readLine()) != null) {
                    if (line.isEmpty()) {
                        continue; // save() never writes blank lines; tolerate stray ones
                    }
                    // Split on the LAST tab so feature names that themselves contain tabs survive.
                    int tab = lastTabIndex(line, uri);
                    counts.add(line.substring(0, tab), Integer.parseInt(line.substring(tab + 1)));
                }

                // Swap in only after a complete parse, so loading twice (or after train())
                // does not double-count document frequencies.
                this.totalDocumentCount = total;
                this.documentFreqMap = counts;
            } finally {
                reader.close();
            }
        }

        /** Returns the index of the last tab in {@code line}, or throws if there is none. */
        private static int lastTabIndex(String line, URI uri) throws IOException {
            int tab = line.lastIndexOf('\t');
            if (tab < 0) {
                throw new IOException("Malformed line in IDF file " + uri + ": " + line);
            }
            return tab;
        }
    }

    /**
     * Creates an extractor with no sub-extractor.
     *
     * <p>{@link #extract} cannot be used unless {@link #subExtractor} is assigned later.
     */
    public TfidfExtractor(String str) {
        this(str, null);
    }

    /**
     * Creates an untrained extractor.
     *
     * @param str name used for the {@link TransformableFeature} emitted before training
     * @param simpleFeatureExtractor extractor whose numeric features will be TF-IDF weighted
     */
    public TfidfExtractor(String str, SimpleFeatureExtractor simpleFeatureExtractor) {
        super(str);
        this.subExtractor = simpleFeatureExtractor;
        this.isTrained = false;
        this.idfMap = new IDFMap();
    }

    /**
     * Multiplies the feature's numeric value (its term count) by the IDF of its name.
     *
     * @return a new feature named {@code "TF-IDF_" + name} with a {@link Double} value
     */
    @Override // org.cleartk.classifier.feature.transform.OneToOneTrainableExtractor_ImplBase
    protected Feature transform(Feature feature) {
        // Accept any Number (Integer, Long, Double, ...); the former cast to Integer threw
        // ClassCastException for other numeric types. Integer inputs give identical results.
        double termCount = ((Number) feature.getValue()).doubleValue();
        return new Feature("TF-IDF_" + feature.getName(),
                Double.valueOf(termCount * this.idfMap.getIDF(feature.getName())));
    }

    /**
     * Extracts features from the sub-extractor.
     *
     * <p>When trained, each feature is TF-IDF weighted. Otherwise they are all wrapped in one
     * {@link TransformableFeature} for later training.
     *
     * @throws IllegalStateException if no sub-extractor has been set
     */
    public List<Feature> extract(JCas jCas, Annotation annotation) throws CleartkExtractorException {
        if (this.subExtractor == null) {
            throw new IllegalStateException("TfidfExtractor " + this.name
                    + " has no sub-extractor; construct it with TfidfExtractor(name, subExtractor)");
        }
        List<Feature> subFeatures = this.subExtractor.extract(jCas, annotation);
        List<Feature> result = new ArrayList<Feature>();
        if (this.isTrained) {
            for (Feature subFeature : subFeatures) {
                result.add(transform(subFeature));
            }
        } else {
            result.add(new TransformableFeature(this.name, subFeatures));
        }
        return result;
    }

    /**
     * Builds document-frequency statistics, treating each instance as one document.
     *
     * <p>A feature name counts at most once per instance, however often it appears in that
     * instance's transformable features.
     */
    // NOTE(review): the decompiler reports this method was originally declared protected.
    public IDFMap createIdfMap(Iterable<Instance<OUTCOME_T>> iterable) {
        IDFMap map = new IDFMap();
        for (Instance<OUTCOME_T> instance : iterable) {
            HashSet<String> namesInDocument = new HashSet<String>();
            for (Feature feature : instance.getFeatures()) {
                if (isTransformable(feature)) {
                    for (Feature subFeature : ((TransformableFeature) feature).getFeatures()) {
                        namesInDocument.add(subFeature.getName());
                    }
                }
            }
            for (String featureName : namesInDocument) {
                map.add(featureName);
            }
            map.incTotalDocumentCount();
        }
        return map;
    }

    /** Computes IDF statistics from the training instances and marks this extractor trained. */
    @Override // org.cleartk.classifier.feature.transform.TrainableExtractor
    public void train(Iterable<Instance<OUTCOME_T>> iterable) {
        this.idfMap = createIdfMap(iterable);
        this.isTrained = true;
    }

    /** Writes the IDF statistics to {@code uri}. */
    @Override // org.cleartk.classifier.feature.transform.TrainableExtractor
    public void save(URI uri) throws IOException {
        this.idfMap.save(uri);
    }

    /** Loads IDF statistics from {@code uri} and marks this extractor trained. */
    @Override // org.cleartk.classifier.feature.transform.TrainableExtractor
    public void load(URI uri) throws IOException {
        this.idfMap.load(uri);
        this.isTrained = true;
    }
}
