package org.nd4j.linalg.dataset;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import com.hp.hpl.jena.util.FileManager;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.BlasWrapper;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:WEB-INF/lib/nd4j-api-0.0.3.5.5.jar:org/nd4j/linalg/dataset/DataSet.class */
public class DataSet implements org.nd4j.linalg.dataset.api.DataSet {
    private static final long serialVersionUID = 1935520764586513365L;
    private static final Logger log = LoggerFactory.getLogger(DataSet.class);
    private List<String> columnNames;
    private List<String> labelNames;
    private INDArray features;
    private INDArray labels;

    /** Creates a placeholder data set whose features and labels are 1-element zero arrays. */
    public DataSet() {
        this(Nd4j.zeros(new int[]{1}), Nd4j.zeros(new int[]{1}));
    }

    /**
     * Creates a data set from a feature matrix and a label matrix.
     *
     * @param features one example per row
     * @param labels   one label row per example; must have as many rows as {@code features}
     * @throws IllegalStateException if the two matrices have a different number of rows
     */
    public DataSet(INDArray features, INDArray labels) {
        this.columnNames = new ArrayList<String>();
        this.labelNames = new ArrayList<String>();
        if (features.rows() != labels.rows()) {
            throw new IllegalStateException("Invalid data applyTransformToDestination; first and second do not have equal rows. First was " + features.rows() + " second was " + labels.rows());
        }
        this.features = features;
        this.labels = labels;
    }

    @Override // org.nd4j.linalg.dataset.api.DataSet
    public INDArray getFeatures() {
        return this.features;
    }

    /**
     * Counts how many examples carry each label.
     *
     * <p>A row's label is the index returned by the BLAS {@code iamax} routine on that label row.
     *
     * @return label index to number of examples; empty when no labels are set
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public Map<Integer, Double> labelCounts() {
        Map<Integer, Double> counts = new HashMap<Integer, Double>();
        if (this.labels == null) {
            return counts;
        }
        for (int row = 0; row < this.labels.rows(); row++) {
            // No cast needed: iamax takes the row itself (the old (BlasWrapper) cast was a decompiler artifact).
            int label = Nd4j.getBlasWrapper().iamax(this.labels.getRow(row));
            Double seen = counts.get(label);
            counts.put(label, seen == null ? 1.0d : seen + 1.0d);
        }
        return counts;
    }

    /** Applies {@code function} in place to every feature element that satisfies {@code condition}. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void apply(Condition condition, Function<Number, Number> function) {
        BooleanIndexing.applyWhere(getFeatureMatrix(), condition, function);
    }

    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void setFeatures(INDArray features) {
        this.features = features;
    }

    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    /** Returns a new data set backed by duplicates of the feature and label matrices (names are not copied). */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSet copy() {
        return new DataSet(getFeatures().dup(), getLabels().dup());
    }

    /** Returns a placeholder data set whose features and labels are 1-element zero arrays. */
    public static DataSet empty() {
        return new DataSet(Nd4j.zeros(new int[]{1}), Nd4j.zeros(new int[]{1}));
    }

    /**
     * Concatenates the rows of several data sets into one newly allocated data set.
     *
     * <p>Column counts are taken from the first data set.
     *
     * @param list the data sets to merge, in order
     * @return a data set whose rows are all rows of {@code list}, in order
     * @throws IllegalArgumentException if {@code list} is empty
     */
    public static DataSet merge(List<DataSet> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("Unable to merge empty dataset");
        }
        DataSet first = list.get(0);
        int total = totalExamples(list);
        INDArray mergedFeatures = Nd4j.create(total, first.getFeatures().columns());
        INDArray mergedLabels = Nd4j.create(total, first.getLabels().columns());
        int target = 0;
        for (DataSet part : list) {
            for (int row = 0; row < part.numExamples(); row++) {
                mergedFeatures.putRow(target, part.getFeatures().getRow(row));
                mergedLabels.putRow(target, part.getLabels().getRow(row));
                target++;
            }
        }
        return new DataSet(mergedFeatures, mergedLabels);
    }

    /** Returns a data set with the features reshaped to {@code rows x columns} and the same labels. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSet reshape(int rows, int columns) {
        return new DataSet(getFeatures().reshape(rows, columns), getLabels());
    }

    /** Multiplies every feature in place by {@code num}. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void multiplyBy(double num) {
        getFeatures().muli(Nd4j.scalar(num));
    }

    /** Divides every feature in place by {@code num}. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void divideBy(int num) {
        getFeatures().divi(Nd4j.scalar(num));
    }

    /** Randomly permutes the examples (feature and label rows move together). */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void shuffle() {
        List<DataSet> examples = asList();
        Collections.shuffle(examples);
        DataSet shuffled = merge(examples);
        setFeatures(shuffled.getFeatures());
        setLabels(shuffled.getLabels());
    }

    /**
     * Clamps every feature in place to the closed range {@code [min, max]}.
     *
     * @param min values below this are set to it
     * @param max values above this are set to it
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void squishToRange(double min, double max) {
        INDArray data = getFeatures();
        for (int k = 0; k < data.length(); k++) {
            // Go through Number: element() may be a Float or a Double depending on the backing dtype.
            double value = ((Number) data.getScalar(k).element()).doubleValue();
            if (value < min) {
                data.put(k, Nd4j.scalar(min));
            } else if (value > max) {
                data.put(k, Nd4j.scalar(max));
            }
        }
    }

    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void scaleMinAndMax(double min, double max) {
        FeatureUtil.scaleMinMax(min, max, getFeatureMatrix());
    }

    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void scale() {
        FeatureUtil.scaleByMax(getFeatures());
    }

    /**
     * Appends {@code toAdd} to the feature matrix as extra columns (horizontal stack).
     *
     * @param toAdd columns to append; presumably must have one row per example (confirm against hstack)
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void addFeatureVector(INDArray toAdd) {
        setFeatures(Nd4j.hstack(new INDArray[]{getFeatures(), toAdd}));
    }

    /**
     * Replaces the feature row of example {@code example} with {@code feature}.
     *
     * @param feature the new feature row
     * @param example the row index to overwrite
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void addFeatureVector(INDArray feature, int example) {
        getFeatures().putRow(example, feature);
    }

    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void normalize() {
        FeatureUtil.normalizeMatrix(getFeatures());
    }

    /** Binarizes features in place around 0: values {@code > 0} become 1, all others 0. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void binarize() {
        binarize(0.0d);
    }

    /**
     * Binarizes features in place: values strictly above {@code cutoff} become 1, all others 0.
     *
     * @param cutoff the threshold
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void binarize(double cutoff) {
        INDArray data = getFeatures();
        for (int k = 0; k < data.length(); k++) {
            double value = ((Number) data.getScalar(k).element()).doubleValue();
            data.put(k, Nd4j.scalar(value > cutoff ? 1.0f : 0.0f));
        }
    }

    /**
     * Standardizes features in place: subtracts the mean along dimension 0 from every row,
     * then divides every row by the standard deviation along dimension 0 plus {@code 1e-6}.
     * The epsilon guards against dividing by zero.
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void normalizeZeroMeanZeroUnitVariance() {
        INDArray mean = getFeatures().mean(0);
        INDArray std = getFeatureMatrix().std(0);
        setFeatures(getFeatures().subiRowVector(mean));
        std.addi(Nd4j.scalar(1.0E-6d));
        setFeatures(getFeatures().diviRowVector(std));
    }

    /** Sums {@link #numExamples()} over every data set in {@code collection}. */
    private static int totalExamples(Collection<DataSet> collection) {
        int total = 0;
        for (DataSet part : collection) {
            total += part.numExamples();
        }
        return total;
    }

    /** Returns the number of feature columns. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public int numInputs() {
        return getFeatures().columns();
    }

    /**
     * Checks that features and labels have the same number of rows.
     *
     * @throws IllegalStateException if the row counts differ
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void validate() {
        if (getFeatures().rows() != getLabels().rows()) {
            throw new IllegalStateException("Invalid dataset");
        }
    }

    /**
     * Returns the label index of a single-example data set, as given by BLAS {@code iamax}.
     *
     * @throws IllegalStateException if this data set has more than one example
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public int outcome() {
        if (numExamples() > 1) {
            throw new IllegalStateException("Unable to derive outcome for dataset greater than one row");
        }
        return Nd4j.getBlasWrapper().iamax(getLabels());
    }

    /** Replaces the labels with a fresh {@code numExamples() x labels} matrix from {@code Nd4j.create}. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void setNewNumberOfLabels(int labels) {
        setLabels(Nd4j.create(numExamples(), labels));
    }

    /**
     * Overwrites the label row of an example with the outcome vector for {@code label}.
     *
     * @param example row index, in {@code [0, numExamples())}
     * @param label   label index, in {@code [0, numOutcomes())}
     * @throws IllegalArgumentException if either index is out of range
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void setOutcome(int example, int label) {
        if (example >= numExamples() || example < 0) {
            throw new IllegalArgumentException("No example at " + example);
        }
        if (label >= numOutcomes() || label < 0) {
            throw new IllegalArgumentException("Illegal label");
        }
        getLabels().putRow(example, FeatureUtil.toOutcomeVector(label, numOutcomes()));
    }

    /**
     * Returns a single-example data set whose matrices are the {@code getRow} results for row {@code i}.
     *
     * @throws IllegalArgumentException if {@code i} is outside {@code [0, numExamples())}
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSet get(int i) {
        if (i >= numExamples() || i < 0) {
            throw new IllegalArgumentException("invalid example number");
        }
        return new DataSet(getFeatures().getRow(i), getLabels().getRow(i));
    }

    /** Returns a data set containing the given rows, in the given order. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSet get(int[] rows) {
        return new DataSet(getFeatures().getRows(rows), getLabels().getRows(rows));
    }

    /** Partitions the single-example view of this data set into consecutive lists of up to {@code num} examples. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public List<List<DataSet>> batchBy(int num) {
        return Lists.partition(asList(), num);
    }

    /**
     * Returns the examples whose label index is one of {@code labels}, merged into a new data set.
     *
     * @throws IllegalArgumentException (from {@link #merge}) if no example matches
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSet filterBy(int[] labels) {
        Set<Integer> wanted = new HashSet<Integer>();
        for (int label : labels) {
            wanted.add(label);
        }
        List<DataSet> kept = new ArrayList<DataSet>();
        for (DataSet example : asList()) {
            if (wanted.contains(example.outcome())) {
                kept.add(example);
            }
        }
        return merge(kept);
    }

    /**
     * Keeps only examples whose label is in {@code labels} and re-encodes their labels.
     * {@code labels[k]} is mapped to outcome {@code k} of a {@code labels.length}-wide label matrix.
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void filterAndStrip(int[] labels) {
        DataSet filtered = filterBy(labels);
        Map<Integer, Integer> newIndex = new HashMap<Integer, Integer>();
        for (int k = 0; k < labels.length; k++) {
            newIndex.put(labels[k], k);
        }
        List<Integer> remapped = new ArrayList<Integer>();
        for (int row = 0; row < filtered.numExamples(); row++) {
            remapped.add(newIndex.get(filtered.get(row).outcome()));
        }
        INDArray newLabels = Nd4j.create(filtered.numExamples(), labels.length);
        if (newLabels.rows() != remapped.size()) {
            throw new IllegalStateException("Inconsistent label sizes");
        }
        for (int row = 0; row < newLabels.rows(); row++) {
            Integer label = remapped.get(row);
            if (label == null) {
                throw new IllegalStateException("Label not found on row " + row);
            }
            newLabels.putRow(row, FeatureUtil.toOutcomeVector(label, labels.length));
        }
        setFeatures(filtered.getFeatures());
        setLabels(newLabels);
    }

    /** Splits the examples into consecutive batches of up to {@code num} examples, each merged into its own data set. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public List<DataSet> dataSetBatches(int num) {
        List<DataSet> batches = new ArrayList<DataSet>();
        for (List<DataSet> batch : Lists.partition(asList(), num)) {
            batches.add(merge(batch));
        }
        return batches;
    }

    /** Calls {@link #sortByLabel()}, then partitions the examples into lists of {@code numOutcomes()} examples. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public List<List<DataSet>> sortAndBatchByNumLabels() {
        sortByLabel();
        return Lists.partition(asList(), numOutcomes());
    }

    /** Partitions the examples (in current order) into lists of {@code numOutcomes()} examples. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public List<List<DataSet>> batchByNumLabels() {
        return Lists.partition(asList(), numOutcomes());
    }

    /**
     * Returns one single-example data set per row.
     *
     * <p>Each element wraps {@code getRow(i)} of this data set's matrices and is NOT a copy.
     * (Assumes nd4j returns row views; confirm for the backend in use.)
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public List<DataSet> asList() {
        int n = numExamples();
        List<DataSet> examples = new ArrayList<DataSet>(n);
        for (int row = 0; row < n; row++) {
            examples.add(new DataSet(getFeatures().getRow(row), getLabels().getRow(row)));
        }
        return examples;
    }

    /**
     * Shuffles the examples and splits them in two.
     *
     * @param splitAt number of examples in the first part given to {@link SplitTestAndTrain}
     * @throws IllegalArgumentException if {@code splitAt} is not in {@code [1, numExamples())}
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public SplitTestAndTrain splitTestAndTrain(int splitAt) {
        if (splitAt >= numExamples()) {
            throw new IllegalArgumentException("Unable to split on size larger than the number of rows");
        }
        if (splitAt <= 0) {
            throw new IllegalArgumentException("Split size must be positive, was " + splitAt);
        }
        List<DataSet> examples = asList();
        Collections.shuffle(examples);
        return new SplitTestAndTrain(merge(examples.subList(0, splitAt)), merge(examples.subList(splitAt, examples.size())));
    }

    @Override // org.nd4j.linalg.dataset.api.DataSet
    public INDArray getLabels() {
        return this.labels;
    }

    /** Alias for {@link #getFeatures()}. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public INDArray getFeatureMatrix() {
        return getFeatures();
    }

    /**
     * Reorders the examples in place so labels are interleaved.
     *
     * <p>While every label 0..numOutcomes()-1 still has examples left, one example of each label
     * is emitted per round, in label order. After the first label runs out, the remaining
     * examples are drained from the per-label queues in map iteration order.
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void sortByLabel() {
        Map<Integer, Queue<DataSet>> byLabel = new HashMap<Integer, Queue<DataSet>>();
        int numOutcomes = numOutcomes();
        int numExamples = numExamples();
        for (DataSet example : asList()) {
            int label = example.outcome();
            Queue<DataSet> queue = byLabel.get(label);
            if (queue == null) {
                queue = new ArrayDeque<DataSet>();
                byLabel.put(label, queue);
            }
            // Copy: asList() rows wrap this data set's own matrices, which addRow overwrites below;
            // without the copy, queued examples would be corrupted before they are written back.
            queue.add(example.copy());
        }
        for (Map.Entry<Integer, Queue<DataSet>> entry : byLabel.entrySet()) {
            log.info("Label " + entry.getKey() + " has " + entry.getValue().size() + " elements");
        }

        boolean roundRobin = true;
        int row = 0;
        while (row < numExamples) {
            if (roundRobin) {
                for (int label = 0; label < numOutcomes; label++) {
                    Queue<DataSet> queue = byLabel.get(label);
                    DataSet next = queue == null ? null : queue.poll();
                    if (next == null) {
                        roundRobin = false;
                        break;
                    }
                    addRow(next, row++);
                }
            } else {
                DataSet next = null;
                for (Queue<DataSet> queue : byLabel.values()) {
                    if (!queue.isEmpty()) {
                        next = queue.poll();
                        break;
                    }
                }
                // Queues hold exactly numExamples items, so next is non-null here; addRow rejects null anyway.
                addRow(next, row++);
            }
        }
    }

    /**
     * Overwrites row {@code i} of the features and labels with the (single-row) contents of {@code dataSet}.
     *
     * @throws IllegalArgumentException if {@code dataSet} is null or {@code i} is outside {@code [0, numExamples())}
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void addRow(DataSet dataSet, int i) {
        if (i >= numExamples() || i < 0 || dataSet == null) {
            throw new IllegalArgumentException("Invalid index for adding a row");
        }
        getFeatures().putRow(i, dataSet.getFeatures());
        getLabels().putRow(i, dataSet.getLabels());
    }

    /** Returns the sum of each example's features (reduction along dimension 1). */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public INDArray exampleSums() {
        return getFeatures().sum(1);
    }

    /** Returns the max of each example's features (reduction along dimension 1). */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public INDArray exampleMaxs() {
        return getFeatures().max(1);
    }

    /** Returns the mean of each example's features (reduction along dimension 1). */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public INDArray exampleMeans() {
        return getFeatures().mean(1);
    }

    /** Samples {@code numSamples} examples without replacement, using a time-seeded Mersenne Twister. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSet sample(int numSamples) {
        return sample(numSamples, new MersenneTwister(System.currentTimeMillis()));
    }

    /** Samples {@code numSamples} examples without replacement using {@code rng}. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSet sample(int numSamples, RandomGenerator rng) {
        return sample(numSamples, rng, false);
    }

    /** Samples {@code numSamples} examples using a time-seeded Mersenne Twister. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSet sample(int numSamples, boolean withReplacement) {
        return sample(numSamples, new MersenneTwister(System.currentTimeMillis()), withReplacement);
    }

    /**
     * Draws a random sample of examples into a newly allocated data set.
     *
     * @param numSamples      number of examples to draw
     * @param rng             source of randomness
     * @param withReplacement when false, each example is drawn at most once
     * @return {@code this} if {@code numSamples >= numExamples()}, otherwise the sample
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSet sample(int numSamples, RandomGenerator rng, boolean withReplacement) {
        int total = numExamples();
        if (numSamples >= total) {
            return this;
        }
        INDArray sampledFeatures = Nd4j.create(numSamples, getFeatures().columns());
        INDArray sampledLabels = Nd4j.create(numSamples, numOutcomes());
        Set<Integer> used = new HashSet<Integer>();
        for (int row = 0; row < numSamples; row++) {
            int picked = rng.nextInt(total);
            if (!withReplacement) {
                // add() returns false for an index already drawn, so redraw until a new one is found.
                while (!used.add(picked)) {
                    picked = rng.nextInt(total);
                }
            }
            DataSet example = get(picked);
            sampledFeatures.putRow(row, example.getFeatures());
            sampledLabels.putRow(row, example.getLabels());
        }
        return new DataSet(sampledFeatures, sampledLabels);
    }

    /** Rounds every feature in place with {@code MathUtils.roundDouble(value, roundTo)}. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void roundToTheNearest(int roundTo) {
        INDArray data = getFeatures();
        for (int k = 0; k < data.length(); k++) {
            double value = ((Number) data.getScalar(k).element()).doubleValue();
            data.put(k, Nd4j.scalar(MathUtils.roundDouble(value, roundTo)));
        }
    }

    /** Returns the number of label columns. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public int numOutcomes() {
        return getLabels().columns();
    }

    /** Returns the number of feature rows. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public int numExamples() {
        return getFeatures().rows();
    }

    /** Renders features and labels, with each {@code ;} in their string form replaced by a newline. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("===========INPUT===================\n")
                .append(getFeatures().toString().replaceAll(";", "\n"))
                .append("\n=================OUTPUT==================\n")
                .append(getLabels().toString().replaceAll(";", "\n"));
        return sb.toString();
    }

    @Override // org.nd4j.linalg.dataset.api.DataSet
    public List<String> getLabelNames() {
        return this.labelNames;
    }

    /**
     * Sets the label names.
     *
     * @throws IllegalArgumentException if {@code labelNames} is null or its size differs from {@link #numOutcomes()}
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void setLabelNames(List<String> labelNames) {
        if (labelNames == null || labelNames.size() != numOutcomes()) {
            throw new IllegalArgumentException("Unable to applyTransformToDestination label names, does not match number of possible outcomes");
        }
        this.labelNames = labelNames;
    }

    @Override // org.nd4j.linalg.dataset.api.DataSet
    public List<String> getColumnNames() {
        return this.columnNames;
    }

    /**
     * Sets the feature column names.
     *
     * @throws IllegalArgumentException if the size differs from {@link #numInputs()}
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void setColumnNames(List<String> columnNames) {
        if (columnNames.size() != numInputs()) {
            throw new IllegalArgumentException("Column names don't match input");
        }
        this.columnNames = columnNames;
    }

    /** Iterates over the single-example data sets produced by {@link #asList()}. */
    @Override // org.nd4j.linalg.dataset.api.DataSet, java.lang.Iterable
    public Iterator<DataSet> iterator() {
        return asList().iterator();
    }
}
