package org.nd4j.linalg.dataset.api;

import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/DataSetUtil.class */
/**
 * Static helpers that reshape rank-3 and rank-4 {@link INDArray}s (or the features/labels of a
 * {@link DataSet}) into 2d arrays.
 */
public class DataSetUtil {

    /**
     * Reshapes either the features or the labels of a {@link DataSet} into a 2d array.
     *
     * <p>When {@code useFeatures} is {@code true}, the features and features mask are used.
     * Otherwise the labels and labels mask are used. The selected pair is passed to
     * {@link #tailor2d(INDArray, INDArray)}.
     *
     * @param dataSet the data set to read from; must not be {@code null}
     * @param useFeatures {@code true} to tailor the features, {@code false} for the labels
     * @return the selected array reshaped to at most 2 dimensions
     * @throws NullPointerException if {@code dataSet} is {@code null}
     */
    public static INDArray tailor2d(@NonNull DataSet dataSet, boolean useFeatures) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet");
        }
        INDArray data = useFeatures ? dataSet.getFeatures() : dataSet.getLabels();
        INDArray mask = useFeatures ? dataSet.getFeaturesMaskArray() : dataSet.getLabelsMaskArray();
        return tailor2d(data, mask);
    }

    /**
     * Reshapes {@code data} into an array of at most 2 dimensions, dispatching on its rank.
     *
     * <ul>
     *   <li>Rank 1 or 2: {@code data} is returned unchanged.</li>
     *   <li>Rank 3: delegates to {@link #tailor3d2d(INDArray, INDArray)} with {@code mask}.</li>
     *   <li>Rank 4: delegates to {@link #tailor4d2d(INDArray)}; {@code mask} is ignored.</li>
     * </ul>
     *
     * @param data the array to reshape; must not be {@code null}
     * @param mask optional mask, only used for rank-3 data; may be {@code null}
     * @return a 1d/2d view or copy of {@code data}
     * @throws NullPointerException if {@code data} is {@code null}
     * @throws IllegalArgumentException if the rank of {@code data} is not 1, 2, 3 or 4
     */
    public static INDArray tailor2d(@NonNull INDArray data, INDArray mask) {
        if (data == null) {
            throw new NullPointerException("data");
        }
        int rank = data.rank();
        switch (rank) {
            case 1:
            case 2:
                // Already at most 2d: nothing to tailor.
                return data;
            case 3:
                return tailor3d2d(data, mask);
            case 4:
                // NOTE(review): masks are not applied for rank-4 input.
                return tailor4d2d(data);
            default:
                throw new IllegalArgumentException("Unsupported data rank: " + rank);
        }
    }

    /**
     * Reshapes the rank-3 features or labels of a {@link DataSet} into a 2d array.
     *
     * @param dataSet the data set to read from; must not be {@code null}
     * @param useFeatures {@code true} to use features and features mask, {@code false} for
     *     labels and labels mask
     * @return see {@link #tailor3d2d(INDArray, INDArray)}
     * @throws NullPointerException if {@code dataSet} is {@code null}
     */
    public static INDArray tailor3d2d(@NonNull DataSet dataSet, boolean useFeatures) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet");
        }
        INDArray data = useFeatures ? dataSet.getFeatures() : dataSet.getLabels();
        INDArray mask = useFeatures ? dataSet.getFeaturesMaskArray() : dataSet.getLabelsMaskArray();
        return tailor3d2d(data, mask);
    }

    /**
     * Reshapes a rank-3 array of sizes {@code [d0, d1, d2]} into a 2d array with {@code d1} columns.
     *
     * <p>The sizes are presumably {@code [miniBatch, featureCount, timeSeriesLength]}. The local
     * names below follow that assumption, but it should be confirmed against callers.
     *
     * <p>Without a mask, the result has shape {@code [d0 * d2, d1]}. With a mask of length
     * {@code d0 * d2}, the values are multiplied by the mask and only the rows whose mask entry
     * (in 'c'-flattened order) is non-zero are kept, in their original order. The input
     * {@code data} is never modified.
     *
     * @param data rank-3 input; must not be {@code null}
     * @param mask optional mask with {@code d0 * d2} entries; may be {@code null}
     * @return the reshaped (and, if masked, filtered) 2d array
     * @throws NullPointerException if {@code data} is {@code null}
     * @throws IllegalArgumentException if {@code mask} does not have {@code d0 * d2} entries
     */
    public static INDArray tailor3d2d(@NonNull INDArray data, INDArray mask) {
        if (data == null) {
            throw new NullPointerException("data");
        }
        int miniBatch = data.size(0);
        int featureCount = data.size(1);
        int seriesLength = data.size(2);
        boolean hasMask = mask != null;
        int rowCount = miniBatch * seriesLength;
        if (hasMask && mask.length() != rowCount) {
            throw new IllegalArgumentException("Mask length " + mask.length()
                    + " does not match data.size(0) * data.size(2) = " + rowCount);
        }

        // Build one row per index along dimension 1. Each row is the matching tensor along
        // dimensions (2, 0), flattened in 'c' order.
        INDArray byFeature = Nd4j.create(featureCount, seriesLength * miniBatch);
        int tadCount = data.tensorssAlongDimension(2, 0);
        for (int i = 0; i < tadCount; i++) {
            INDArray tad = data.tensorAlongDimension(i, 2, 0);
            if (hasMask) {
                // Use mul, not muli. The TAD is assumed to be a view into `data` (ND4J TADs
                // are), so an in-place multiply would write masked values into the caller's
                // input array.
                tad = tad.mul(mask);
            }
            byFeature.putRow(i, Nd4j.toFlattened('c', tad));
        }
        INDArray stacked = byFeature.transpose();
        if (!hasMask) {
            return stacked;
        }
        return keepUnmaskedRows(stacked, mask, rowCount, featureCount);
    }

    /**
     * Copies into a new array the rows of {@code stacked} whose mask entry, in 'c'-flattened
     * order, is non-zero.
     */
    private static INDArray keepUnmaskedRows(INDArray stacked, INDArray mask, int rowCount, int featureCount) {
        INDArray flatMask = Nd4j.toFlattened('c', mask);
        int[] kept = new int[rowCount];
        int keptCount = 0;
        for (int row = 0; row < rowCount; row++) {
            if (flatMask.getDouble(row) != 0.0) {
                kept[keptCount++] = row;
            }
        }
        // Size the output by the rows actually selected. Sizing it by the mask sum would leave
        // trailing all-zero rows (or overflow) whenever a mask entry is not exactly 0 or 1.
        INDArray result = Nd4j.create(keptCount, featureCount);
        for (int k = 0; k < keptCount; k++) {
            result.putRow(k, stacked.getRow(kept[k]));
        }
        return result;
    }

    /**
     * Reshapes the rank-4 features or labels of a {@link DataSet} into a 2d array. Masks are
     * not used.
     *
     * @param dataSet the data set to read from; must not be {@code null}
     * @param useFeatures {@code true} to use the features, {@code false} for the labels
     * @return see {@link #tailor4d2d(INDArray)}
     * @throws NullPointerException if {@code dataSet} is {@code null}
     */
    public static INDArray tailor4d2d(@NonNull DataSet dataSet, boolean useFeatures) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet");
        }
        return tailor4d2d(useFeatures ? dataSet.getFeatures() : dataSet.getLabels());
    }

    /**
     * Reshapes a rank-4 array of sizes {@code [d0, d1, d2, d3]} into a 2d array of shape
     * {@code [d0 * d2 * d3, d1]}.
     *
     * <p>Each column corresponds to one index along dimension 1. It holds the matching tensor
     * along dimensions (3, 2, 0), flattened in 'c' order.
     *
     * @param data rank-4 input; must not be {@code null}
     * @return the reshaped 2d array
     * @throws NullPointerException if {@code data} is {@code null}
     */
    public static INDArray tailor4d2d(@NonNull INDArray data) {
        if (data == null) {
            throw new NullPointerException("data");
        }
        INDArray byChannel = Nd4j.create(data.size(1), data.size(2) * data.size(3) * data.size(0));
        int tadCount = data.tensorssAlongDimension(3, 2, 0);
        for (int i = 0; i < tadCount; i++) {
            // Use explicit 'c' order, as tailor3d2d does, rather than the factory's default
            // order, so the row layout is fixed.
            byChannel.putRow(i, Nd4j.toFlattened('c', data.tensorAlongDimension(i, 3, 2, 0)));
        }
        return byChannel.transposei();
    }
}
