package org.nd4j.linalg.util;

import java.util.ArrayList;
import java.util.Arrays;

/* loaded from: input_file:WEB-INF/lib/nd4j-api-0.0.3.5.5.jar:org/nd4j/linalg/util/Shape.class */
/**
 * Static helpers for inspecting and comparing shapes, where a shape is an {@code int[]}
 * holding the size of each dimension.
 */
public class Shape {

    /**
     * Removes every dimension of size 1 from {@code shape}.
     *
     * <p>The second argument is not used; this overload behaves exactly like
     * {@link #squeeze(int[])}. NOTE(review): it looks like it was meant to restrict
     * squeezing to particular axes — confirm intent before relying on that.
     *
     * @param shape the shape to squeeze
     * @param ignored unused
     * @return a new array containing the non-1 dimensions of {@code shape}, in order
     */
    public static int[] squeeze(int[] shape, int[] ignored) {
        return squeeze(shape);
    }

    /**
     * Looks up the size of each requested axis.
     *
     * @param axes indices into {@code shape}
     * @param shape the dimension sizes
     * @return an array whose element {@code i} is {@code shape[axes[i]]}
     * @throws ArrayIndexOutOfBoundsException if an axis is not a valid index of {@code shape}
     */
    public static int[] sizeForAxes(int[] axes, int[] shape) {
        int[] sizes = new int[axes.length];
        for (int i = 0; i < axes.length; i++) {
            sizes[i] = shape[axes[i]];
        }
        return sizes;
    }

    /**
     * Returns whether {@code shape} describes a vector: rank 1 or 2, with one of its
     * dimensions equal to the product of all dimensions.
     *
     * <p>Rank-1 shapes are always vectors. Note that a rank-2 shape containing a 0
     * (e.g. {@code [0, 5]}) also satisfies the test.
     *
     * @param shape the shape to inspect
     * @return true for rank-1 shapes and for rank-2 shapes such as {@code [1, n]} or {@code [n, 1]}
     */
    public static boolean isVector(int[] shape) {
        if (shape.length < 1 || shape.length > 2) {
            return false;
        }
        int length = product(shape);
        // Short-circuit matters: for rank 1, shape[0] == length always holds, so shape[1]
        // is never read.
        return shape[0] == length || shape[1] == length;
    }

    /**
     * Returns whether {@code shape} is rank 2 and not a vector according to
     * {@link #isVector(int[])}.
     *
     * @param shape the shape to inspect
     * @return true for rank-2, non-vector shapes
     */
    public static boolean isMatrix(int[] shape) {
        return shape.length == 2 && !isVector(shape);
    }

    /**
     * Removes every dimension of size 1 from {@code shape}.
     *
     * @param shape the shape to squeeze
     * @return a new array containing the non-1 dimensions of {@code shape}, in order
     *     (empty if every dimension is 1)
     */
    public static int[] squeeze(int[] shape) {
        int[] kept = new int[shape.length];
        int count = 0;
        for (int dim : shape) {
            if (dim != 1) {
                kept[count++] = dim;
            }
        }
        return Arrays.copyOf(kept, count);
    }

    /**
     * Returns the second dimension when the shape has rank greater than 1 and its first
     * dimension is 1; otherwise returns the first dimension.
     *
     * <p>Despite the name, only the two leading dimensions are considered, and the value
     * returned may itself be 1 (e.g. for {@code [1, 1]}).
     *
     * @param shape a non-empty shape
     * @return the selected dimension size
     * @throws ArrayIndexOutOfBoundsException if {@code shape} is empty
     */
    public static int nonZeroDimension(int[] shape) {
        if (shape.length > 1 && shape[0] == 1) {
            return shape[1];
        }
        return shape[0];
    }

    /**
     * Compares two shapes, treating equivalent row-vector and scalar forms as equal.
     *
     * <ul>
     *   <li>If both are column-vector shapes, they must be exactly equal.</li>
     *   <li>Otherwise, if both are row-vector shapes, they are compared after
     *       {@link #squeeze(int[])}, so {@code [1, n]} equals {@code [n]}.</li>
     *   <li>Otherwise they are equal if {@link #scalarEquals(int[], int[])} holds or the
     *       arrays are exactly equal.</li>
     * </ul>
     *
     * @param shape1 first shape
     * @param shape2 second shape
     * @return whether the shapes are considered equal
     */
    public static boolean shapeEquals(int[] shape1, int[] shape2) {
        if (isColumnVectorShape(shape1) && isColumnVectorShape(shape2)) {
            return Arrays.equals(shape1, shape2);
        }
        if (isRowVectorShape(shape1) && isRowVectorShape(shape2)) {
            return Arrays.equals(squeeze(shape1), squeeze(shape2));
        }
        return scalarEquals(shape1, shape2) || Arrays.equals(shape1, shape2);
    }

    /**
     * Returns whether one shape is empty ({@code []}) and the other is exactly {@code [1]}.
     *
     * <p>Only this cross-representation case is recognized: two empty shapes, or two
     * {@code [1]} shapes, return false here (callers combine this with
     * {@link Arrays#equals(int[], int[])} to cover identical shapes).
     *
     * @param shape1 first shape
     * @param shape2 second shape
     * @return true iff one argument is {@code []} and the other is {@code [1]}
     */
    public static boolean scalarEquals(int[] shape1, int[] shape2) {
        if (shape1.length == 0) {
            return shape2.length == 1 && shape2[0] == 1;
        }
        return shape2.length == 0 && shape1.length == 1 && shape1[0] == 1;
    }

    /**
     * Returns whether {@code shape} is rank 1, or rank 2 with a leading dimension of 1.
     *
     * @param shape the shape to inspect
     * @return true for shapes like {@code [n]} or {@code [1, n]}
     */
    public static boolean isRowVectorShape(int[] shape) {
        return shape.length == 1 || (shape.length == 2 && shape[0] == 1);
    }

    /**
     * Returns whether {@code shape} is rank 2 with a trailing dimension of 1.
     *
     * @param shape the shape to inspect
     * @return true for shapes like {@code [n, 1]}
     */
    public static boolean isColumnVectorShape(int[] shape) {
        return shape.length == 2 && shape[1] == 1;
    }

    /**
     * Compares two shapes after removing all size-1 dimensions from each.
     *
     * @param shape1 first shape
     * @param shape2 second shape
     * @return whether the squeezed shapes are equal, or are {@code []} vs {@code [1]}
     */
    public static boolean squeezeEquals(int[] shape1, int[] shape2) {
        int[] squeezed1 = squeeze(shape1);
        int[] squeezed2 = squeeze(shape2);
        return scalarEquals(squeezed1, squeezed2) || Arrays.equals(squeezed1, squeezed2);
    }

    /** Multiplies all entries of {@code shape} with plain int arithmetic (1 for an empty array). */
    private static int product(int[] shape) {
        int result = 1;
        for (int dim : shape) {
            result *= dim;
        }
        return result;
    }
}
