package ai.djl.modality;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.DataType;
import ai.djl.util.JsonSerializable;
import ai.djl.util.JsonUtils;
import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonSerializationContext;
import com.google.gson.JsonSerializer;
import java.lang.reflect.Type;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;

/* loaded from: input_file:WEB-INF/lib/api-0.19.0.jar:ai/djl/modality/Classifications.class */
/**
 * Pairs a list of class names with a parallel list of probability values and exposes them as
 * {@link Classification} items, sorted "top K" views and a JSON rendering.
 *
 * <p>Element {@code i} of the class-name list corresponds to element {@code i} of the
 * probability list. The lists handed to the constructors are stored by reference; no defensive
 * copy is made.
 */
public class Classifications implements JsonSerializable {
    private static final long serialVersionUID = 1L;

    /** Number of items reported by {@link #topK()} when the caller does not specify one. */
    private static final int DEFAULT_TOP_K = 5;

    /** Gson instance that renders a {@code Classifications} through {@link ClassificationsSerializer}. */
    private static final Gson GSON =
            JsonUtils.builder()
                    .registerTypeAdapter(Classifications.class, new ClassificationsSerializer())
                    .create();

    protected List<String> classNames;
    protected List<Double> probabilities;
    private int topK;

    /**
     * Creates an instance from parallel lists, reporting the default number of top items.
     *
     * @param classNames the class names, index-aligned with {@code probabilities}
     * @param probabilities the probability of each class
     */
    public Classifications(List<String> classNames, List<Double> probabilities) {
        this.classNames = classNames;
        this.probabilities = probabilities;
        this.topK = DEFAULT_TOP_K;
    }

    /**
     * Creates an instance from an {@link NDArray} of probabilities, reporting the default number
     * of top items.
     *
     * @param classNames the class names, index-aligned with the array elements
     * @param probabilities the probability of each class
     */
    public Classifications(List<String> classNames, NDArray probabilities) {
        this(classNames, probabilities, DEFAULT_TOP_K);
    }

    /**
     * Creates an instance from an {@link NDArray} of probabilities.
     *
     * <p>The array is converted to {@code FLOAT64} and its values are copied into a Java list, so
     * this object does not keep a reference to the array.
     *
     * @param classNames the class names, index-aligned with the array elements
     * @param probabilities the probability of each class
     * @param topK the number of items {@link #topK()} should return
     */
    public Classifications(List<String> classNames, NDArray probabilities, int topK) {
        this.classNames = classNames;
        NDArray converted = probabilities.toType(DataType.FLOAT64, false);
        try {
            this.probabilities =
                    Arrays.stream(converted.toDoubleArray()).boxed().collect(Collectors.toList());
        } finally {
            // NOTE(review): with copy == false, toType presumably may hand back the receiver
            // itself when it is already FLOAT64 - confirm per engine. Only release an array
            // created here so that the caller's input is never closed behind its back.
            if (converted != probabilities) {
                converted.close();
            }
        }
        this.topK = topK;
    }

    /**
     * Sets the number of items returned by {@link #topK()} and used by {@link #toString()} and
     * {@link #toJson()}.
     *
     * @param topK the number of items to report
     */
    public final void setTopK(int topK) {
        this.topK = topK;
    }

    /**
     * Returns one item per class name, in the order of the class-name list.
     *
     * @param <T> the item type
     * @return a new mutable list of items
     */
    public <T extends Classification> List<T> items() {
        int count = classNames.size();
        List<T> result = new ArrayList<>(count);
        for (int index = 0; index < count; index++) {
            result.add(item(index));
        }
        return result;
    }

    /**
     * Returns the item at the given position.
     *
     * @param index the position in the class-name and probability lists
     * @param <T> the item type
     * @return a newly created item for that position
     */
    @SuppressWarnings("unchecked") // this base implementation always builds a plain Classification
    public <T extends Classification> T item(int index) {
        return (T) new Classification(classNames.get(index), probabilities.get(index));
    }

    /**
     * Returns the most probable items, using the configured default count.
     *
     * @param <T> the item type
     * @return the items with the highest probabilities, most probable first
     */
    public <T extends Classification> List<T> topK() {
        return topK(topK);
    }

    /**
     * Returns the {@code k} most probable items.
     *
     * @param k the maximum number of items to return; fewer are returned when fewer exist
     * @param <T> the item type
     * @return the items sorted by descending probability, limited to {@code k} entries
     * @throws IllegalArgumentException if {@code k} is negative (raised by {@code subList})
     */
    public <T extends Classification> List<T> topK(int k) {
        List<T> sorted = items();
        // List.sort is stable, so equally probable items keep their original relative order.
        sorted.sort(Comparator.comparingDouble(Classification::getProbability).reversed());
        return sorted.subList(0, Math.min(sorted.size(), k));
    }

    /**
     * Returns the item with the highest probability.
     *
     * @param <T> the item type
     * @return the item at the first position holding the maximum probability
     * @throws java.util.NoSuchElementException if there are no probabilities
     */
    public <T extends Classification> T best() {
        Double highest = Collections.max(probabilities);
        return item(probabilities.indexOf(highest));
    }

    /**
     * Looks up an item by class name.
     *
     * @param className the class name to search for
     * @param <T> the item type
     * @return the first item whose class name equals {@code className}, or {@code null} if none
     */
    public <T extends Classification> T get(String className) {
        int count = classNames.size();
        for (int index = 0; index < count; index++) {
            if (classNames.get(index).equals(className)) {
                return item(index);
            }
        }
        return null;
    }

    /**
     * Renders the top-K items as JSON followed by a newline character.
     *
     * @return the JSON text
     */
    @Override
    public String toJson() {
        return GSON.toJson(this) + '\n';
    }

    /**
     * Returns the same text as {@link #toJson()}.
     *
     * @return the JSON text
     */
    @Override
    public String getAsString() {
        return toJson();
    }

    /**
     * Returns the UTF-8 bytes of {@link #toJson()} wrapped in a buffer.
     *
     * @return a buffer over the encoded JSON
     */
    @Override
    public ByteBuffer toByteBuffer() {
        return ByteBuffer.wrap(toJson().getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Renders the top-K items, one per line and tab-indented, between square brackets.
     *
     * @return the multi-line text
     */
    @Override
    public String toString() {
        String newline = System.lineSeparator();
        StringBuilder sb = new StringBuilder();
        sb.append('[').append(newline);
        List<Classification> top = topK(this.topK);
        for (Classification entry : top) {
            sb.append('\t').append(entry).append(newline);
        }
        sb.append(']');
        return sb.toString();
    }

    /** A single class name together with its probability. */
    public static class Classification {
        private final String className;
        private final double probability;

        /**
         * Creates an item.
         *
         * @param className the class name
         * @param probability the probability of that class
         */
        public Classification(String className, double probability) {
            this.className = className;
            this.probability = probability;
        }

        /**
         * Returns the class name.
         *
         * @return the class name
         */
        public String getClassName() {
            return className;
        }

        /**
         * Returns the probability exactly as it was supplied to the constructor.
         *
         * @return the probability
         */
        public double getProbability() {
            return probability;
        }

        /**
         * Formats the item for display: values below {@code 1e-5} use scientific notation, all
         * others are truncated (not rounded) to five decimal places.
         *
         * <p>Formatting never changes the stored probability, so repeated calls yield the same
         * text and {@link #getProbability()} is unaffected.
         *
         * @return the display text
         */
        @Override
        public String toString() {
            if (probability < 1.0E-5d) {
                return String.format("class: \"%s\", probability: %.1e", className, probability);
            }
            // Truncate into a local instead of overwriting the field. A long cast avoids the
            // saturation an int cast suffers for values above roughly 21474; the float divisor
            // is kept so the text for int-range values is identical to the previous output.
            double truncated = ((long) (probability * 100000.0d)) / 100000.0f;
            return String.format("class: \"%s\", probability: %.5f", className, truncated);
        }
    }

    /** Gson serializer that writes a {@link Classifications} as its default top-K item list. */
    public static final class ClassificationsSerializer implements JsonSerializer<Classifications> {
        /**
         * Serializes the list returned by {@link Classifications#topK()}.
         *
         * @param classifications the object to serialize
         * @param type the declared type of the object (unused)
         * @param jsonSerializationContext the context used to serialize the item list
         * @return the JSON element for the top-K list
         */
        @Override
        public JsonElement serialize(
                Classifications classifications,
                Type type,
                JsonSerializationContext jsonSerializationContext) {
            List<Classification> top = classifications.topK();
            return jsonSerializationContext.serialize(top);
        }
    }
}
