package org.nd4j.linalg.convolution;

import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.convolution.Col2Im;
import org.nd4j.linalg.api.ops.impl.transforms.convolution.Im2col;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/convolution/Convolution.class */
/**
 * Static entry points for convolution-related operations.
 *
 * <p>The {@code im2col}/{@code col2im} methods build an {@link Im2col} or {@link Col2Im} op and
 * run it through {@code Nd4j.getExecutioner()}; the {@code conv2d}/{@code convn} methods forward
 * to the instance returned by {@code Nd4j.getConvolution()}. The class cannot be instantiated.
 *
 * <p>NOTE(review): parameter names in this file are machine-generated; where the Javadoc below
 * suggests a meaning for a positional argument it is marked as an assumption to be confirmed
 * against the {@code Im2col}/{@code Col2Im} constructors.
 */
public final class Convolution {
    private static final Logger log = LoggerFactory.getLogger(Convolution.class);

    /* loaded from: input_file:org/nd4j/linalg/convolution/Convolution$Type.class */
    /**
     * Mode selector handed through unchanged to the {@code conv2d}/{@code convn} implementations.
     * NOTE(review): presumably the usual full / valid / same output-size conventions — confirm
     * against the convolution instance implementation.
     */
    public enum Type {
        FULL,
        VALID,
        SAME
    }

    /** Not instantiable: every member is static. */
    private Convolution() {
    }

    /**
     * Array-based convenience overload of
     * {@link #col2im(INDArray, int, int, int, int, int, int)}.
     *
     * <p>Only elements {@code [0]} and {@code [1]} of each array are read.
     *
     * @param iNDArray column-form input; must be rank 6
     * @param iArr     supplies the 2nd and 3rd op arguments (presumably stride — TODO confirm)
     * @param iArr2    supplies the 4th and 5th op arguments (presumably padding — TODO confirm)
     * @param i        forwarded as the 6th op argument (presumably output height — TODO confirm)
     * @param i2       forwarded as the 7th op argument (presumably output width — TODO confirm)
     * @return the result array of the executed {@link Col2Im} op
     * @throws IllegalArgumentException if {@code iNDArray} is not rank 6
     */
    public static INDArray col2im(INDArray iNDArray, int[] iArr, int[] iArr2, int i, int i2) {
        final int first = iArr[0];
        final int second = iArr[1];
        final int third = iArr2[0];
        final int fourth = iArr2[1];
        return col2im(iNDArray, first, second, third, fourth, i, i2);
    }

    /**
     * Runs a {@link Col2Im} op on the given input and returns the op's result array.
     *
     * <p>The six integers are passed to the {@code Col2Im} constructor in the order given;
     * their individual meaning is defined by that constructor.
     *
     * @param iNDArray column-form input; must be rank 6
     * @return the {@code z()} array of the executed op
     * @throws IllegalArgumentException if {@code iNDArray} is not rank 6
     */
    public static INDArray col2im(INDArray iNDArray, int i, int i2, int i3, int i4, int i5, int i6) {
        requireRank(iNDArray, 6, "col2im input array must be rank 6");
        Col2Im op = new Col2Im(iNDArray, i, i2, i3, i4, i5, i6);
        return Nd4j.getExecutioner().exec(op).z();
    }

    /**
     * Runs a {@link Col2Im} op that is constructed with a caller-supplied output array.
     *
     * <p>The input rank is validated before the output rank, so an invalid input is reported
     * first when both are wrong.
     *
     * @param iNDArray  column-form input; must be rank 6
     * @param iNDArray2 output array handed to the op; must be rank 4
     * @return {@code iNDArray2}, the same instance that was passed in
     * @throws IllegalArgumentException if either rank requirement is violated
     */
    public static INDArray col2im(INDArray iNDArray, INDArray iNDArray2, int i, int i2, int i3, int i4, int i5, int i6) {
        requireRank(iNDArray, 6, "col2im input array must be rank 6");
        requireRank(iNDArray2, 4, "col2im output array must be rank 4");
        // The literal 'false' is the op's boolean constructor flag; its meaning is defined by Col2Im.
        Col2Im op = new Col2Im(iNDArray, i, i2, i3, i4, i5, i6, false, iNDArray2);
        Nd4j.getExecutioner().exec(op);
        return iNDArray2;
    }

    /**
     * Array-based convenience overload of {@code im2col}.
     *
     * <p>Only elements {@code [0]} and {@code [1]} of each array are read. The call is routed
     * through the nine-argument overload with {@code 0} and {@code false} as the trailing
     * arguments; that overload does not use the {@code 0}.
     *
     * @param iNDArray input array
     * @param iArr     supplies the 1st and 2nd op integers (presumably kernel size — TODO confirm)
     * @param iArr2    supplies the 3rd and 4th op integers (presumably stride — TODO confirm)
     * @param iArr3    supplies the 5th and 6th op integers (presumably padding — TODO confirm)
     * @return the result array of the executed {@link Im2col} op
     */
    public static INDArray im2col(INDArray iNDArray, int[] iArr, int[] iArr2, int[] iArr3) {
        final int a0 = iArr[0];
        final int a1 = iArr[1];
        final int b0 = iArr2[0];
        final int b1 = iArr2[1];
        final int c0 = iArr3[0];
        final int c1 = iArr3[1];
        return im2col(iNDArray, a0, a1, b0, b1, c0, c1, 0, false);
    }

    /**
     * Runs an {@link Im2col} op on the given input and returns the op's result array.
     *
     * <p>The six integers and the flag are passed to the {@code Im2col} constructor in the
     * order given; their individual meaning is defined by that constructor.
     *
     * @param iNDArray input array
     * @param z        boolean flag forwarded to the op (presumably "same" mode — TODO confirm)
     * @return the {@code z()} array of the executed op
     */
    public static INDArray im2col(INDArray iNDArray, int i, int i2, int i3, int i4, int i5, int i6, boolean z) {
        Im2col op = new Im2col(iNDArray, i, i2, i3, i4, i5, i6, z);
        return Nd4j.getExecutioner().exec(op).z();
    }

    /**
     * Runs an {@link Im2col} op that is constructed with a caller-supplied array as its last
     * constructor argument.
     *
     * @param iNDArray  input array
     * @param z         boolean flag forwarded to the op
     * @param iNDArray2 array handed to the op constructor (presumably the output buffer — TODO confirm)
     * @return the {@code z()} array of the executed op
     */
    public static INDArray im2col(INDArray iNDArray, int i, int i2, int i3, int i4, int i5, int i6, boolean z, INDArray iNDArray2) {
        Im2col op = new Im2col(iNDArray, i, i2, i3, i4, i5, i6, z, iNDArray2);
        return Nd4j.getExecutioner().exec(op).z();
    }

    /**
     * Same as {@link #im2col(INDArray, int, int, int, int, int, int, boolean)}.
     *
     * <p>NOTE(review): {@code i7} is accepted but never used — it is not passed to the op.
     * Presumably it was meant to be a padding value; confirm whether {@code Im2col} should
     * receive it before relying on it.
     *
     * @param i7 ignored
     * @return the {@code z()} array of the executed op
     */
    public static INDArray im2col(INDArray iNDArray, int i, int i2, int i3, int i4, int i5, int i6, int i7, boolean z) {
        return im2col(iNDArray, i, i2, i3, i4, i5, i6, z);
    }

    /**
     * Computes {@code (i + 2*i4 - i2) / i3 + 1} with integer division. When {@code z} is
     * {@code true}, {@code i3 - 1} is added to the numerator before dividing.
     *
     * <p>NOTE(review): this matches the usual convolution output-size formula with
     * {@code i} = input size, {@code i2} = kernel size, {@code i3} = stride and
     * {@code i4} = padding — confirm against callers.
     *
     * @param z selects the variant that adds {@code i3 - 1} to the numerator
     * @return the computed size
     * @throws ArithmeticException if {@code i3} is zero (integer division by zero)
     */
    public static int outSize(int i, int i2, int i3, int i4, boolean z) {
        final int span = (i + (i4 * 2)) - i2;
        if (z) {
            return (((span + i3) - 1) / i3) + 1;
        }
        return (span / i3) + 1;
    }

    /**
     * Forwards to {@code conv2d} on the instance returned by {@code Nd4j.getConvolution()}.
     *
     * @return whatever the underlying implementation returns
     */
    public static INDArray conv2d(INDArray iNDArray, INDArray iNDArray2, Type type) {
        return Nd4j.getConvolution().conv2d(iNDArray, iNDArray2, type);
    }

    /**
     * Complex-array variant; forwards to {@code conv2d} on the instance returned by
     * {@code Nd4j.getConvolution()}.
     *
     * @return whatever the underlying implementation returns
     */
    public static INDArray conv2d(IComplexNDArray iComplexNDArray, IComplexNDArray iComplexNDArray2, Type type) {
        return Nd4j.getConvolution().conv2d(iComplexNDArray, iComplexNDArray2, type);
    }

    /**
     * Forwards to {@code convn} on the instance returned by {@code Nd4j.getConvolution()}.
     *
     * @param iArr passed through unchanged (presumably the axes to convolve over — TODO confirm)
     * @return whatever the underlying implementation returns
     */
    public static INDArray convn(INDArray iNDArray, INDArray iNDArray2, Type type, int[] iArr) {
        return Nd4j.getConvolution().convn(iNDArray, iNDArray2, type, iArr);
    }

    /**
     * Complex-array variant; forwards to {@code convn} on the instance returned by
     * {@code Nd4j.getConvolution()}.
     *
     * @param iArr passed through unchanged (presumably the axes to convolve over — TODO confirm)
     * @return whatever the underlying implementation returns
     */
    public static IComplexNDArray convn(IComplexNDArray iComplexNDArray, IComplexNDArray iComplexNDArray2, Type type, int[] iArr) {
        return Nd4j.getConvolution().convn(iComplexNDArray, iComplexNDArray2, type, iArr);
    }

    /**
     * Forwards to the three-argument {@code convn} on the instance returned by
     * {@code Nd4j.getConvolution()}.
     *
     * @return whatever the underlying implementation returns
     */
    public static INDArray convn(INDArray iNDArray, INDArray iNDArray2, Type type) {
        return Nd4j.getConvolution().convn(iNDArray, iNDArray2, type);
    }

    /**
     * Complex-array variant; forwards to the three-argument {@code convn} on the instance
     * returned by {@code Nd4j.getConvolution()}.
     *
     * @return whatever the underlying implementation returns
     */
    public static IComplexNDArray convn(IComplexNDArray iComplexNDArray, IComplexNDArray iComplexNDArray2, Type type) {
        return Nd4j.getConvolution().convn(iComplexNDArray, iComplexNDArray2, type);
    }

    /** Throws {@link IllegalArgumentException} with {@code message} unless {@code array.rank()} equals {@code expectedRank}. */
    private static void requireRank(INDArray array, int expectedRank, String message) {
        if (array.rank() != expectedRank) {
            throw new IllegalArgumentException(message);
        }
    }
}
