package org.cleartk.classifier.feature.transform.extractor;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.tcas.Annotation;
import org.cleartk.classifier.Feature;
import org.cleartk.classifier.Instance;
import org.cleartk.classifier.feature.extractor.CleartkExtractorException;
import org.cleartk.classifier.feature.extractor.simple.SimpleFeatureExtractor;
import org.cleartk.classifier.feature.transform.TransformableFeature;
import org.cleartk.classifier.feature.transform.extractor.TfidfExtractor;

/* loaded from: input_file:org/cleartk/classifier/feature/transform/extractor/CentroidTfidfSimilarityExtractor.class */
/**
 * Feature extractor that replaces a bag of term-count features with a single feature holding
 * the value returned by a {@link SimilarityFunction} when comparing the instance's tf*idf
 * vector against a centroid tf*idf vector computed during training.
 *
 * <p>Before training, {@link #extract(JCas, Annotation)} wraps the sub-extractor's features in a
 * {@link TransformableFeature}. After {@link #train(Iterable)} or {@link #load(URI)}, it emits the
 * similarity feature directly.
 *
 * @param <OUTCOME_T> outcome type of the instances being transformed
 */
public class CentroidTfidfSimilarityExtractor<OUTCOME_T> extends TfidfExtractor<OUTCOME_T> {
    /** Suffix appended to the extractor name to form the document-frequency data file name. */
    private static final String DOC_FREQ_FILE_SUFFIX = "_tfidf-centroid-extractor_idfmap.dat";

    /** Suffix appended to the extractor name to form the centroid data file name. */
    private static final String CENTROID_MAP_FILE_SUFFIX = "_tfidf-centroid-extractor_centroidmap.dat";

    /** Term name to centroid tf*idf weight; set by {@link #train(Iterable)} or {@link #load(URI)}. */
    private Map<String, Double> centroidMap;

    /** Similarity function bound to {@link #centroidMap}; set together with it. */
    private SimilarityFunction simFunction;

    /**
     * Resolves the location of the document-frequency data file for the named extractor.
     *
     * @param str extractor name used as the file name prefix
     * @param uri base location the file name is resolved against
     * @return the URI of the document-frequency data file
     * @throws MalformedURLException if {@code uri} cannot be converted to a URL
     * @throws URISyntaxException if the resolved URL is not a valid URI
     */
    public static URI getDocumentFrequencyDataURI(String str, URI uri) throws MalformedURLException, URISyntaxException {
        return new URL(uri.toURL(), str + DOC_FREQ_FILE_SUFFIX).toURI();
    }

    /**
     * Resolves the location of the centroid data file for the named extractor.
     *
     * @param str extractor name used as the file name prefix
     * @param uri base location the file name is resolved against
     * @return the URI of the centroid data file
     * @throws MalformedURLException if {@code uri} cannot be converted to a URL
     * @throws URISyntaxException if the resolved URL is not a valid URI
     */
    public static URI getCentroidDataURI(String str, URI uri) throws MalformedURLException, URISyntaxException {
        return new URL(uri.toURL(), str + CENTROID_MAP_FILE_SUFFIX).toURI();
    }

    /**
     * Creates an extractor with the given name; initialization is delegated to the superclass.
     *
     * @param str name of the feature produced by this extractor
     */
    public CentroidTfidfSimilarityExtractor(String str) {
        super(str);
    }

    /**
     * Creates an untrained extractor that draws its raw term features from a sub-extractor.
     *
     * @param str name of the feature produced by this extractor
     * @param simpleFeatureExtractor extractor supplying the term-count features
     */
    public CentroidTfidfSimilarityExtractor(String str, SimpleFeatureExtractor simpleFeatureExtractor) {
        super(str);
        this.subExtractor = simpleFeatureExtractor;
        this.isTrained = false;
        this.idfMap = new TfidfExtractor.IDFMap();
    }

    /**
     * Replaces every transformable feature of the instance with one similarity feature.
     * Non-transformable features are passed through unchanged, in their original order, and the
     * similarity feature is appended last.
     *
     * @param instance instance whose features are to be transformed
     * @return a new instance with the same outcome and the transformed feature list
     */
    @Override
    public Instance<OUTCOME_T> transform(Instance<OUTCOME_T> instance) {
        List<Feature> passThrough = new ArrayList<Feature>();
        List<Feature> termFeatures = new ArrayList<Feature>();
        for (Feature feature : instance.getFeatures()) {
            if (isTransformable(feature)) {
                termFeatures.addAll(((TransformableFeature) feature).getFeatures());
            } else {
                passThrough.add(feature);
            }
        }
        double similarity = this.simFunction.distance(featuresToFeatureMap(termFeatures), this.centroidMap);
        passThrough.add(new Feature(this.name, Double.valueOf(similarity)));
        return new Instance<OUTCOME_T>(instance.getOutcome(), passThrough);
    }

    /**
     * Converts term-count features into a map from term name to tf*idf weight. Each feature value
     * is cast to {@link Integer}. If a name occurs more than once, the last occurrence wins.
     *
     * @param list term features whose values are integer counts
     * @return map of term name to count multiplied by the term's IDF
     */
    public Map<String, Double> featuresToFeatureMap(List<Feature> list) {
        Map<String, Double> weights = new HashMap<String, Double>();
        for (Feature term : list) {
            String termName = term.getName();
            int termCount = ((Integer) term.getValue()).intValue();
            weights.put(termName, Double.valueOf(termCount * this.idfMap.getIDF(termName)));
        }
        return weights;
    }

    /**
     * Extracts features via the sub-extractor and returns a single-element list: the similarity
     * feature when trained, otherwise a {@link TransformableFeature} wrapping the raw features.
     *
     * @param jCas CAS passed to the sub-extractor
     * @param annotation annotation passed to the sub-extractor
     * @return a list containing exactly one feature
     * @throws CleartkExtractorException if the sub-extractor fails
     */
    @Override
    public List<Feature> extract(JCas jCas, Annotation annotation) throws CleartkExtractorException {
        List<Feature> rawFeatures = this.subExtractor.extract(jCas, annotation);
        List<Feature> result = new ArrayList<Feature>();
        if (this.isTrained) {
            double similarity = this.simFunction.distance(featuresToFeatureMap(rawFeatures), this.centroidMap);
            result.add(new Feature(this.name, Double.valueOf(similarity)));
        } else {
            result.add(new TransformableFeature(this.name, rawFeatures));
        }
        return result;
    }

    /**
     * Computes the centroid: for every term, the sum of its tf*idf weight over all instances
     * divided by the total document count reported by {@code iDFMap}.
     *
     * @param iterable training instances
     * @param iDFMap IDF statistics supplying per-term IDF values and the document count
     * @return map of term name to mean tf*idf weight
     */
    protected Map<String, Double> computeCentroid(Iterable<Instance<OUTCOME_T>> iterable, TfidfExtractor.IDFMap iDFMap) {
        int totalDocumentCount = iDFMap.getTotalDocumentCount();
        Map<String, Double> centroid = new HashMap<String, Double>();
        for (Instance<OUTCOME_T> instance : iterable) {
            for (Feature feature : instance.getFeatures()) {
                if (!isTransformable(feature)) {
                    continue;
                }
                for (Feature term : ((TransformableFeature) feature).getFeatures()) {
                    String termName = term.getName();
                    double tfidf = ((Integer) term.getValue()).intValue() * iDFMap.getIDF(termName);
                    Double runningSum = centroid.get(termName);
                    double previous = (runningSum == null) ? 0.0d : runningSum.doubleValue();
                    centroid.put(termName, Double.valueOf(previous + tfidf));
                }
            }
        }
        // Normalize once, after all instances have been accumulated. Dividing inside the
        // instance loop would scale earlier contributions down repeatedly and skew the centroid.
        for (Map.Entry<String, Double> entry : centroid.entrySet()) {
            entry.setValue(Double.valueOf(entry.getValue().doubleValue() / totalDocumentCount));
        }
        return centroid;
    }

    /**
     * Builds the IDF map and centroid from the training instances and marks the extractor trained.
     *
     * @param iterable training instances
     */
    @Override
    public void train(Iterable<Instance<OUTCOME_T>> iterable) {
        this.idfMap = createIdfMap(iterable);
        this.centroidMap = computeCentroid(iterable, this.idfMap);
        this.simFunction = new FixedCosineSimilarity(this.centroidMap);
        this.isTrained = true;
    }

    /**
     * Writes the IDF map and the centroid to files located relative to {@code uri}. The centroid
     * file holds one {@code term<TAB>value} line per entry.
     *
     * @param uri base location for the data files
     * @throws IOException if writing fails or a data file URI cannot be formed
     */
    @Override
    public void save(URI uri) throws IOException {
        try {
            URI documentFrequencyDataURI = getDocumentFrequencyDataURI(this.name, uri);
            URI centroidDataURI = getCentroidDataURI(this.name, uri);
            this.idfMap.save(documentFrequencyDataURI);
            // try-with-resources closes the writer even if a write fails part-way.
            try (BufferedWriter writer = new BufferedWriter(new FileWriter(new File(centroidDataURI)))) {
                for (Map.Entry<String, Double> entry : this.centroidMap.entrySet()) {
                    // Double.toString round-trips exactly through Double.parseDouble; "%f" would
                    // truncate to six fractional digits and could write small weights as zero.
                    writer.append(entry.getKey())
                            .append('\t')
                            .append(Double.toString(entry.getValue().doubleValue()))
                            .append('\n');
                }
            }
        } catch (URISyntaxException e) {
            throw new IOException(e);
        }
    }

    /**
     * Reads the IDF map and centroid written by {@link #save(URI)} and marks the extractor trained.
     * The centroid is only installed once the whole file has been read successfully.
     *
     * @param uri base location of the data files
     * @throws IOException if reading fails, a data file URI cannot be formed, or a centroid line
     *     is not of the form {@code term<TAB>value}
     */
    @Override
    public void load(URI uri) throws IOException {
        try {
            URI documentFrequencyDataURI = getDocumentFrequencyDataURI(this.name, uri);
            URI centroidDataURI = getCentroidDataURI(this.name, uri);
            this.idfMap.load(documentFrequencyDataURI);
            Map<String, Double> loaded = new HashMap<String, Double>();
            // try-with-resources closes the reader even if parsing fails part-way.
            try (BufferedReader reader = new BufferedReader(new FileReader(new File(centroidDataURI)))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    String[] fields = line.split("\\t");
                    if (fields.length < 2) {
                        throw new IOException("Malformed centroid entry (expected term<TAB>value): " + line);
                    }
                    try {
                        loaded.put(fields[0], Double.valueOf(Double.parseDouble(fields[1])));
                    } catch (NumberFormatException e) {
                        throw new IOException("Malformed centroid value in entry: " + line, e);
                    }
                }
            }
            this.centroidMap = loaded;
            this.simFunction = new FixedCosineSimilarity(this.centroidMap);
            this.isTrained = true;
        } catch (URISyntaxException e) {
            throw new IOException(e);
        }
    }
}
