package org.nd4j.linalg.api.activation;

import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/api/activation/SoftMaxTest.class */
/**
 * Exercises {@code Activations.softmax()} and {@code Activations.softMaxRows()} on a
 * 2x3 array holding the values 1..6, once with the factory order set to {@code 'f'}
 * and once with it set to {@code 'c'}.
 *
 * <p>The class is abstract, so it is not run directly.
 * NOTE(review): presumably concrete subclasses exist per ND4J backend - confirm.
 */
public abstract class SoftMaxTest {
    private static final Logger log = LoggerFactory.getLogger(SoftMaxTest.class);

    /** Absolute tolerance used when comparing the summed activations. */
    private static final double TOLERANCE = 0.1d;

    /** Runs the softmax sum checks with the factory order set to {@code 'f'}. */
    @Test
    public void testSoftMax() {
        assertSoftMaxSums('f');
    }

    /** Runs the softmax sum checks with the factory order set to {@code 'c'}. */
    @Test
    public void testSoftMaxCOrder() {
        assertSoftMaxSums('c');
    }

    /**
     * Sets the factory order, applies both softmax variants to a 2x3 linspace array
     * and asserts the grand total of each result.
     *
     * <p>The expected totals are 3.0 for {@code softmax()} and 2.0 for
     * {@code softMaxRows()}.
     * NOTE(review): presumably these correspond to 3 columns / 2 rows each
     * normalising to 1 - confirm against the Activations implementation.
     *
     * @param order ordering character handed to {@code Nd4j.factory().setOrder(...)}
     */
    private static void assertSoftMaxSums(char order) {
        // The order is set on the shared factory and is not restored afterwards.
        Nd4j.factory().setOrder(order);

        INDArray input = Nd4j.linspace(1, 6, 6).reshape(2, 3);

        // Both activations are applied before any reduction, in the same order as
        // before, in case apply(...) has side effects on its argument.
        INDArray softMax = (INDArray) Activations.softmax().apply(input);
        INDArray softMaxRows = (INDArray) Activations.softMaxRows().apply(input);

        INDArray softMaxAlongDim0 = softMax.sum(0);
        INDArray softMaxRowsAlongDim1 = softMaxRows.sum(1);

        // NOTE(review): sum(Integer.MAX_VALUE) appears to reduce over all remaining
        // elements - confirm against INDArray.sum.
        Assert.assertEquals(
                "softmax total, order '" + order + "'",
                3.0d,
                softMaxAlongDim0.sum(Integer.MAX_VALUE).getFloat(0),
                TOLERANCE);
        Assert.assertEquals(
                "softMaxRows total, order '" + order + "'",
                2.0d,
                softMaxRowsAlongDim1.sum(Integer.MAX_VALUE).getFloat(0),
                TOLERANCE);
    }
}
