package cc.mallet.grmm.test;

import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.UniNormalFactor;
import cc.mallet.grmm.types.Variable;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.Randoms;
import gnu.trove.TDoubleArrayList;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/grmm/test/TestUniNormalFactor.class */
/**
 * JUnit 3 tests for {@link UniNormalFactor}. They cover its variable set, its density and
 * log-density at fixed points, and the mean and standard deviation of the values it samples.
 *
 * <p>The value and sampling tests build the factor with mean {@code -1.0} and third constructor
 * argument {@code 2.0}. They expect a standard deviation of {@code sqrt(2.0)}, so that argument
 * is treated as the variance.
 */
public class TestUniNormalFactor extends TestCase {

    /**
     * Argument passed to every {@link Variable} constructor in this class.
     * NOTE(review): presumably GRMM's marker for a continuous (real-valued) variable,
     * i.e. {@code Variable.CONTINUOUS}. Confirm, then switch to the named constant.
     */
    private static final int CONTINUOUS = -1;

    /** Tolerance for comparing densities and log-densities against hand-computed values. */
    private static final double DENSITY_EPS = 1.0E-4;

    /** Number of draws taken by {@link #testSample()}. */
    private static final int NUM_SAMPLES = 10000;

    /** Fixed seed so that {@link #testSample()} is reproducible. */
    private static final int SEED = 2343;

    /**
     * Creates a test case that runs the test method with the given name.
     *
     * @param name name of the test method to run
     */
    public TestUniNormalFactor(String name) {
        super(name);
    }

    /** The factor's variable set contains exactly the one variable it was built on. */
    public void testVarSet() {
        Variable x = new Variable(CONTINUOUS);
        UniNormalFactor factor = new UniNormalFactor(x, -1.0, 1.5);
        assertEquals(1, factor.varSet().size());
        assertTrue(factor.varSet().contains(x));
    }

    /** Density and log-density match hand-computed normal values at two points. */
    public void testValue() {
        Variable x = new Variable(CONTINUOUS);
        UniNormalFactor factor = new UniNormalFactor(x, -1.0, 2.0);
        // At the mean: 1 / sqrt(2 * pi * 2) ~= 0.2821.
        assertDensityAt(factor, new Assignment(x, -1.0), 0.2821);
        // At 1.5: 0.2821 * exp(-(2.5 * 2.5) / (2 * 2)) ~= 0.05913.
        assertDensityAt(factor, new Assignment(x, 1.5), 0.05913);
    }

    /**
     * Checks that both {@code value} and {@code logValue} of {@code factor} at {@code point} agree
     * with {@code expected} (in linear and log space respectively), within {@link #DENSITY_EPS}.
     */
    private static void assertDensityAt(UniNormalFactor factor, Assignment point, double expected) {
        assertEquals(expected, factor.value(point), DENSITY_EPS);
        assertEquals(Math.log(expected), factor.logValue(point), DENSITY_EPS);
    }

    /** The sample mean and standard deviation of seeded draws are close to -1 and sqrt(2). */
    public void testSample() {
        Variable x = new Variable(CONTINUOUS);
        Randoms rng = new Randoms(SEED);
        UniNormalFactor factor = new UniNormalFactor(x, -1.0, 2.0);
        double[] draws = drawSamples(factor, x, rng, NUM_SAMPLES);
        double sampleMean = MatrixOps.mean(draws);
        double sampleStd = MatrixOps.stddev(draws);
        assertEquals(-1.0, sampleMean, 0.025);
        assertEquals(Math.sqrt(2.0), sampleStd, 0.01);
    }

    /**
     * Draws {@code count} assignments from {@code factor} using {@code rng} and returns the value
     * each one assigns to {@code x}, in draw order.
     */
    private static double[] drawSamples(UniNormalFactor factor, Variable x, Randoms rng, int count) {
        double[] out = new double[count];
        for (int k = 0; k < count; k++) {
            out[k] = factor.sample(rng).getDouble(x);
        }
        return out;
    }

    /**
     * Builds a suite containing every {@code test*} method of this class.
     *
     * @return the full test suite
     */
    public static TestSuite suite() {
        return new TestSuite(TestUniNormalFactor.class);
    }

    /**
     * Runs the tests with the text-mode runner.
     *
     * @param args names of test methods to run; when empty, every test in {@link #suite()} runs
     */
    public static void main(String[] args) {
        TestRunner.run(args.length == 0 ? suite() : suiteOf(args));
    }

    /** Builds a suite holding one test case per given method name, in argument order. */
    private static TestSuite suiteOf(String[] testNames) {
        TestSuite named = new TestSuite();
        for (String testName : testNames) {
            named.addTest(new TestUniNormalFactor(testName));
        }
        return named;
    }
}
