package cc.mallet.optimize.tests;

import cc.mallet.optimize.ConjugateGradient;
import cc.mallet.optimize.GradientAscent;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.OrthantWiseLimitedMemoryBFGS;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/**
 * Tests for the gradient-based optimizers in {@code cc.mallet.optimize}.
 *
 * <p>Every test maximizes the one-dimensional concave quadratic f(x) = -3x^2 + 5x - 2
 * (implemented by {@link SimplePoly}). Each test starts from x = 0, the Java default for a
 * new {@code double[]}. It then checks that the optimizer converges to the analytic maximizer
 * within {@link #TOLERANCE}.
 */
public class TestOptimizer extends TestCase {

    /** Maximizer of f(x) = -3x^2 + 5x - 2: solving f'(x) = -6x + 5 = 0 gives x = 5/6. */
    private static final double UNREGULARIZED_OPTIMUM = 5.0d / 6.0d;

    /**
     * Expected optimum when an L1 weight of 3 is used. This assumes the orthant-wise optimizer
     * maximizes f(x) - 3|x|. For x > 0 that objective's derivative is -6x + 5 - 3, which is
     * zero at x = 1/3.
     */
    private static final double L1_OPTIMUM = 1.0d / 3.0d;

    /** L1 regularization weight passed to the orthant-wise L-BFGS optimizer. */
    private static final double L1_WEIGHT = 3.0d;

    /** Absolute tolerance when comparing the optimizer's result with the analytic optimum. */
    private static final double TOLERANCE = 0.001d;

    /**
     * A single-parameter objective f(x) = -3x^2 + 5x - 2 with analytic gradient f'(x) = -6x + 5.
     * Every call to {@link #getValue()} logs the current parameter and value to stdout.
     */
    static class SimplePoly implements Optimizable.ByGradientValue {
        /** The single parameter x. It starts at 0.0 and is read directly by the tests. */
        final double[] params = new double[1];

        @Override
        public int getNumParameters() {
            return params.length;
        }

        /** Copies the current parameters into {@code dArr}, which must hold at least one slot. */
        @Override
        public void getParameters(double[] dArr) {
            System.arraycopy(params, 0, dArr, 0, params.length);
        }

        /** Returns parameter {@code i}; out-of-range indices throw, as in {@link #setParameter}. */
        @Override
        public double getParameter(int i) {
            return params[i];
        }

        /** Overwrites the parameters with the leading entries of {@code dArr}. */
        @Override
        public void setParameters(double[] dArr) {
            System.arraycopy(dArr, 0, params, 0, params.length);
        }

        @Override
        public void setParameter(int i, double d) {
            params[i] = d;
        }

        /** Returns f(x) = -3x^2 + 5x - 2 and prints x and f(x) as a trace of optimizer progress. */
        @Override
        public double getValue() {
            double x = params[0];
            // Compute once so the printed value and the returned value cannot diverge.
            double value = -3.0d * x * x + 5.0d * x - 2.0d;
            System.out.println("param = " + x + " value = " + value);
            return value;
        }

        /** Writes the derivative f'(x) = -6x + 5 into {@code dArr[0]}. */
        @Override
        public void getValueGradient(double[] dArr) {
            dArr[0] = -6.0d * params[0] + 5.0d;
        }
    }

    public TestOptimizer(String name) {
        super(name);
    }

    public void testGradientAscent() {
        SimplePoly poly = new SimplePoly();
        new GradientAscent(poly).optimize();
        assertEquals(UNREGULARIZED_OPTIMUM, poly.params[0], TOLERANCE);
    }

    public void testLinearLBFGS() {
        SimplePoly poly = new SimplePoly();
        new LimitedMemoryBFGS(poly).optimize();
        assertEquals(UNREGULARIZED_OPTIMUM, poly.params[0], TOLERANCE);
    }

    /** Without an L1 weight, OWL-QN should find the same optimum as plain L-BFGS. */
    public void testOrthantWiseLBFGSWithoutL1() {
        SimplePoly poly = new SimplePoly();
        new OrthantWiseLimitedMemoryBFGS(poly).optimize();
        assertEquals(UNREGULARIZED_OPTIMUM, poly.params[0], TOLERANCE);
    }

    /** With an L1 weight, the optimum shifts toward zero (see {@link #L1_OPTIMUM}). */
    public void testOrthantWiseLBFGSWithL1() {
        SimplePoly poly = new SimplePoly();
        new OrthantWiseLimitedMemoryBFGS(poly, L1_WEIGHT).optimize();
        assertEquals(L1_OPTIMUM, poly.params[0], TOLERANCE);
    }

    public void testConjugateGradient() {
        SimplePoly poly = new SimplePoly();
        new ConjugateGradient(poly).optimize();
        assertEquals(UNREGULARIZED_OPTIMUM, poly.params[0], TOLERANCE);
    }

    /** Builds a suite containing every {@code test*} method of this class. */
    public static TestSuite suite() {
        return new TestSuite(TestOptimizer.class);
    }

    public static void main(String[] args) {
        TestRunner.run(suite());
    }
}
