package org.nd4j.linalg.convolution.test;

import java.util.Arrays;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:WEB-INF/lib/nd4j-api-0.0.3.5.5.jar:org/nd4j/linalg/convolution/test/ConvolutionTests.class */
/**
 * Tests for {@link Convolution}: 1-D n-dimensional convolution values, 2-D convolution
 * output shapes, and a smoke run of FULL/VALID convolution on small matrices.
 *
 * <p>The class is declared {@code abstract}, so it cannot be instantiated directly.
 * NOTE(review): presumably concrete subclasses supply the runtime backend — confirm.
 */
public abstract class ConvolutionTests {
    private static final Logger log = LoggerFactory.getLogger(ConvolutionTests.class);

    /**
     * Convolves the sequence 1..8 with the kernel 1..3 in VALID mode and checks the
     * six resulting values.
     */
    @Test
    public void convNTest() {
        INDArray input = Nd4j.linspace(1, 8, 8);
        INDArray kernel = Nd4j.linspace(1, 3, 3);
        INDArray expected = Nd4j.create(new double[]{10.0d, 16.0d, 22.0d, 28.0d, 34.0d, 40.0d});

        INDArray actual = Convolution.convn(input, kernel, Convolution.Type.VALID);

        Assert.assertEquals(expected, actual);
    }

    /**
     * Checks only the shape returned by {@code conv2d} when an array is convolved
     * with a copy of itself: FULL mode on a 2x2x2 input and VALID mode on a 2x4x2 input.
     */
    @Test
    public void testConv2d() {
        INDArray cube = Nd4j.create(Nd4j.linspace(1, 8, 8).data(), new int[]{2, 2, 2});
        INDArray fullResult = Convolution.conv2d(cube, cube.dup(), Convolution.Type.FULL);
        // assertArrayEquals reports the mismatching shape on failure.
        Assert.assertArrayEquals(new int[]{3, 3}, fullResult.shape());

        INDArray slab = Nd4j.create(Nd4j.linspace(1, 16, 16).data(), new int[]{2, 4, 2});
        INDArray validResult = Convolution.conv2d(slab, slab.dup(), Convolution.Type.VALID);
        Assert.assertArrayEquals(new int[]{2, 4}, validResult.shape());
    }

    /**
     * Runs {@code convn} in FULL and VALID mode on a 2x6 matrix with a 2x2 kernel and
     * logs the results. This test makes no assertions about the values; it only
     * verifies that both calls complete and produce a non-null result to print.
     */
    @Test
    public void testConvolution() {
        // Proper 2-D array literals (the decompiled form was not valid Java).
        double[][] inputValues = new double[][]{
            {3.0d, 2.0d, 5.0d, 6.0d, 7.0d, 8.0d},
            {5.0d, 4.0d, 2.0d, 10.0d, 8.0d, 1.0d}
        };
        double[][] kernelValues = new double[][]{
            {4.0d, 5.0d},
            {1.0d, 2.0d}
        };
        INDArray input = Nd4j.create(inputValues);
        INDArray kernel = Nd4j.create(kernelValues);

        log.info(Convolution.convn(input, kernel, Convolution.Type.FULL).toString());
        log.info(Convolution.convn(input, kernel, Convolution.Type.VALID).toString());
    }
}
