package org.nd4j.linalg.dataset.api.iterator;

import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;

/**
 * Iterator that splits a single {@link DataSet} into {@code k} contiguous folds for k-fold
 * cross-validation.
 *
 * <p>Each call to {@link #next()} advances to the next fold and returns the <em>training</em>
 * portion (all examples outside the current fold); {@link #testFold()} then returns the held-out
 * portion (the examples inside the current fold).
 *
 * <p>Fold sizes differ by at most one: the first {@code n % k} folds contain {@code n / k + 1}
 * examples and the remaining folds contain {@code n / k}, where {@code n} is the number of
 * examples. Every fold is therefore non-empty.
 */
public class KFoldIterator implements DataSetIterator {
    /** Private copy of the full data set; folds are sliced out of it. */
    private final DataSet singleFold;
    /** Number of folds. */
    private final int k;
    /** Size of the largest fold (the first fold). */
    private final int batch;
    /** Size of the last fold. */
    private final int lastBatch;
    /**
     * Example-index boundaries of the folds: fold {@code f} covers
     * {@code [foldBoundaries[f], foldBoundaries[f + 1])}; {@code foldBoundaries[k]} is the
     * total number of examples.
     */
    private final int[] foldBoundaries;
    /** Index of the next fold to be produced by {@link #next()}. */
    private int kCursor;
    private DataSet test;
    private DataSet train;
    protected DataSetPreProcessor preProcessor;

    /**
     * Creates a 10-fold iterator over {@code dataSet}.
     *
     * @param dataSet data to split; it is copied, so later changes to it are not seen
     * @throws IllegalArgumentException if {@code dataSet} has fewer than 10 examples
     */
    public KFoldIterator(DataSet dataSet) {
        this(10, dataSet);
    }

    /**
     * Creates a {@code k}-fold iterator over {@code dataSet}.
     *
     * @param i number of folds; must be at least 2 and at most the number of examples
     * @param dataSet data to split; it is copied, so later changes to it are not seen
     * @throws IllegalArgumentException if {@code i <= 1} or {@code i} exceeds the number of
     *     examples (which would produce empty folds)
     */
    public KFoldIterator(int i, DataSet dataSet) {
        if (i <= 1) {
            throw new IllegalArgumentException("k must be greater than 1, got " + i);
        }
        int numExamples = dataSet.numExamples();
        if (i > numExamples) {
            throw new IllegalArgumentException(
                    "k (" + i + ") must not exceed the number of examples (" + numExamples + ")");
        }
        this.k = i;
        this.kCursor = 0;
        this.singleFold = dataSet.copy();

        // Spread the remainder over the first folds so sizes differ by at most one.
        int baseSize = numExamples / i;
        int numLargerFolds = numExamples % i;
        this.foldBoundaries = new int[i + 1];
        for (int f = 0; f < i; f++) {
            int size = f < numLargerFolds ? baseSize + 1 : baseSize;
            foldBoundaries[f + 1] = foldBoundaries[f] + size;
        }
        this.batch = numLargerFolds > 0 ? baseSize + 1 : baseSize;
        // The last fold is never one of the enlarged ones because numLargerFolds < i.
        this.lastBatch = baseSize;
    }

    /**
     * Not supported: folds cannot be requested with an arbitrary size.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public DataSet next(int i) throws UnsupportedOperationException {
        throw new UnsupportedOperationException("KFoldIterator does not support next(int)");
    }

    /** Returns the number of examples, taken from the first dimension of the labels. */
    @Override
    public int totalExamples() {
        return this.singleFold.getLabels().size(0);
    }

    /** Returns the number of feature columns (dimension 1 of the features). */
    @Override
    public int inputColumns() {
        return this.singleFold.getFeatures().size(1);
    }

    /** Returns the number of label columns (dimension 1 of the labels). */
    @Override
    public int totalOutcomes() {
        return this.singleFold.getLabels().size(1);
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return false;
    }

    /**
     * Restarts iteration at the first fold after calling {@code shuffle()} on the internal copy
     * of the data, so folds after a reset generally differ from those of the previous pass.
     */
    @Override
    public void reset() {
        this.singleFold.shuffle();
        this.kCursor = 0;
    }

    /** Returns the size of the largest fold (the first fold). */
    @Override
    public int batch() {
        return this.batch;
    }

    /** Returns the size of the last fold. */
    public int lastBatch() {
        return this.lastBatch;
    }

    /** Returns the index of the next fold to be produced. */
    @Override
    public int cursor() {
        return this.kCursor;
    }

    @Override
    public int numExamples() {
        return totalExamples();
    }

    /**
     * Sets a pre-processor that is applied to copies of the train and test data of each fold
     * produced after this call.
     */
    @Override
    public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
        this.preProcessor = dataSetPreProcessor;
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return this.preProcessor;
    }

    @Override
    public List<String> getLabels() {
        return this.singleFold.getLabelNamesList();
    }

    /** Returns {@code true} while fewer than {@code k} folds have been produced. */
    @Override
    public boolean hasNext() {
        return this.kCursor < this.k;
    }

    /**
     * Advances to the next fold and returns its training data (every example outside the fold).
     * The matching held-out data is available from {@link #testFold()}.
     *
     * @throws NoSuchElementException if all {@code k} folds have already been returned
     */
    @Override
    public DataSet next() {
        if (!hasNext()) {
            throw new NoSuchElementException(
                    "All " + k + " folds have been returned; call reset() to iterate again");
        }
        nextFold();
        return this.train;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void remove() {
        throw new UnsupportedOperationException("KFoldIterator does not support remove()");
    }

    /** Slices fold {@code kCursor} into {@code test} and the remaining examples into {@code train}. */
    private void nextFold() {
        int left = foldBoundaries[kCursor];
        int right = foldBoundaries[kCursor + 1];
        int total = foldBoundaries[k];

        // Training data is whatever lies before and after the held-out range. Empty ranges are
        // skipped rather than sliced.
        List<DataSet> trainParts = new ArrayList<>(2);
        if (left > 0) {
            trainParts.add((DataSet) this.singleFold.getRange(0, left));
        }
        if (right < total) {
            trainParts.add((DataSet) this.singleFold.getRange(right, total));
        }
        // k <= n and k >= 2 guarantee at least one part is non-empty.
        this.train = trainParts.size() == 1 ? trainParts.get(0) : DataSet.merge(trainParts);
        this.test = (DataSet) this.singleFold.getRange(left, right);

        if (this.preProcessor != null) {
            // NOTE(review): getRange may return views onto singleFold; preprocess copies so the
            // shared data is never modified in place and later folds are not processed twice.
            this.train = this.train.copy();
            this.test = this.test.copy();
            this.preProcessor.preProcess(this.train);
            this.preProcessor.preProcess(this.test);
        }
        this.kCursor++;
    }

    /**
     * Returns the held-out (test) data of the fold most recently produced by {@link #next()}, or
     * {@code null} if {@code next()} has not been called yet.
     */
    public DataSet testFold() {
        return this.test;
    }
}
