package org.nd4j.linalg.solvers;

import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.solvers.api.OptimizableByGradientValueMatrix;
import org.nd4j.linalg.util.LinAlgExceptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:WEB-INF/lib/nd4j-api-0.0.3.5.5.jar:org/nd4j/linalg/solvers/VectorizedBackTrackLineSearchMinimum.class */
/**
 * Backtracking line search over the parameters of an {@link OptimizableByGradientValueMatrix},
 * using quadratic interpolation for the first backtrack and cubic interpolation afterwards.
 *
 * <p>NOTE(review): despite "Minimum" in the class name, {@link #optimize} accepts a step only
 * when the function value has sufficiently <em>increased</em>. It throws if the accepted value
 * is lower than the starting value. Presumably callers minimize by exposing a negated
 * objective; confirm against callers.
 */
public class VectorizedBackTrackLineSearchMinimum {
    private static final Logger logger =
            LoggerFactory.getLogger(VectorizedBackTrackLineSearchMinimum.class);

    /** Objective whose parameters are moved along the search direction. */
    OptimizableByGradientValueMatrix function;
    /** Maximum number of trial steps {@link #optimize} attempts before giving up. */
    final int maxIterations = 100;
    /**
     * Maximum step length. NOTE(review): stored and exposed through accessors, but not
     * currently applied anywhere in {@link #optimize}.
     */
    double stpmax = 100.0d;
    /** Not used within this class; kept for compatibility with package-level users. */
    final double EPS = 3.0E-12d;
    /** Relative tolerance used to derive the smallest step length worth trying. */
    private double relTolx = 1.0E-10d;
    /** Per-coordinate absolute tolerance below which a trial point counts as "unchanged". */
    private double absTolx = 1.0E-4d;
    /** Sufficient-increase constant used in the step acceptance test. */
    final double ALF = 1.0E-4d;

    /**
     * Creates a line search over the given objective.
     *
     * @param optimizableByGradientValueMatrix objective whose value/parameters are queried and set
     */
    public VectorizedBackTrackLineSearchMinimum(OptimizableByGradientValueMatrix optimizableByGradientValueMatrix) {
        this.function = optimizableByGradientValueMatrix;
    }

    /** Sets the maximum step length (see the review note on {@link #stpmax}). */
    public void setStpmax(double d) {
        this.stpmax = d;
    }

    /** Returns the maximum step length (see the review note on {@link #stpmax}). */
    public double getStpmax() {
        return this.stpmax;
    }

    /** Sets the relative tolerance used to compute the minimum allowed step length. */
    public void setRelTolx(double d) {
        this.relTolx = d;
    }

    /** Sets the per-coordinate absolute tolerance used to detect a negligible move. */
    public void setAbsTolx(double d) {
        this.absTolx = d;
    }

    /**
     * Searches along {@code iNDArray} from the point {@code iNDArray3} for a step length
     * {@code alam} such that {@code f(xOld + alam * line) >= f(xOld) + ALF * alam * slope},
     * where {@code slope = gradient . line}. The first trial is {@code alam = 1}. Each rejected
     * trial is replaced by an interpolated, smaller step, never less than a tenth of the
     * previous one.
     *
     * @param iNDArray  search direction ("line")
     * @param i         iteration number, forwarded to {@code function.setCurrentIteration}
     *                  before every trial
     * @param iNDArray2 gradient at the start point, used to compute the directional slope
     * @param iNDArray3 start point "x". It is updated in place to each trial point. On the
     *                  "step too small" exit the objective's parameters are reset to the start
     *                  point, but this array is left at the last trial point.
     * @return the accepted step length, or {@code 0.0} if the step became too small (objective
     *         parameters restored to the start point)
     * @throws IllegalStateException if an accepted step yields a lower value than the start
     *         point, or if no step is accepted within {@link #maxIterations} trials
     */
    public double optimize(INDArray iNDArray, int i, INDArray iNDArray2, INDArray iNDArray3) {
        // Descriptive aliases for the parameters.
        final INDArray line = iNDArray;
        final INDArray gradient = iNDArray2;
        final INDArray x = iNDArray3;

        final INDArray xOld = x.dup();
        final double fold = function.getValue();

        if (logger.isTraceEnabled()) {
            logger.trace("ENTERING BACKTRACK\n");
            logger.trace("Entering BackTrackLnSrch, value=" + fold + ",\ndirection.oneNorm:"
                    + line.norm1(Integer.MAX_VALUE) + "  direction.infNorm:"
                    + ((Double) Transforms.abs(line).max(Integer.MAX_VALUE).element()).doubleValue());
        }

        LinAlgExceptions.assertValidNum(gradient);
        final double slope = Nd4j.getBlasWrapper().dot(gradient, line);
        logger.debug("slope = {}", slope);

        final double alamin = minimumStep(line, xOld);

        double alam = 1.0;    // current trial step length
        double oldAlam = 0.0; // step length x currently sits at, measured from xOld
        double alam2 = 0.0;   // previous trial step length (input to cubic interpolation)
        double f2 = fold;     // function value at the previous trial step

        for (int trial = 0; trial < maxIterations; trial++) {
            function.setCurrentIteration(i);
            if (logger.isTraceEnabled()) {
                logger.trace("BackTrack loop iteration " + trial + " : alam=" + alam + " oldAlam=" + oldAlam);
                logger.trace("before step, x.1norm: " + x.norm1(Integer.MAX_VALUE)
                        + "\nalam: " + alam + "\noldAlam: " + oldAlam);
            }
            assert alam != oldAlam : "alam == oldAlam";

            // x sits at xOld + oldAlam * line; move it to xOld + alam * line.
            x.addi(line.mul(Double.valueOf(alam - oldAlam)));
            if (logger.isDebugEnabled()) {
                logger.debug("after step, x.1norm: " + x.norm1(Integer.MAX_VALUE));
            }

            // Step has become negligible (relative or absolute): fall back to the start point.
            if (alam < alamin || smallAbsDiff(xOld, x)) {
                function.setParameters(xOld);
                if (logger.isTraceEnabled()) {
                    // Guarded: getValue() may be a full objective evaluation.
                    logger.trace("EXITING BACKTRACK: Jump too small (alamin=" + alamin
                            + "). Exiting and using xold. Value=" + function.getValue());
                }
                return 0.0;
            }

            function.setParameters(x);
            oldAlam = alam;
            final double f = function.getValue();
            logger.debug("value = {}", f);

            // Sufficient-increase test.
            if (f >= fold + ALF * alam * slope) {
                logger.debug("EXITING BACKTRACK: value={}", f);
                if (f < fold) {
                    throw new IllegalStateException("Function did not increase: f=" + f + " < " + fold + "=fold");
                }
                return alam;
            }

            final double tmplam;
            if (Double.isInfinite(f) || Double.isInfinite(f2)) {
                logger.warn("Value is infinite after jump " + oldAlam + ". f=" + f + ", f2=" + f2
                        + ". Scaling back step size...");
                // No alam < alamin re-check is needed here: alam is unchanged since the
                // identical test above in this iteration.
                tmplam = 0.2 * alam;
            } else {
                tmplam = interpolateStep(alam, alam2, f, f2, fold, slope);
            }

            alam2 = alam;
            f2 = f;
            logger.debug("tmplam:{}", tmplam);
            // Never shrink the step by more than a factor of 10 in a single trial.
            alam = Math.max(tmplam, 0.1 * alam);
        }
        throw new IllegalStateException("Too many iterations.");
    }

    /**
     * Computes the smallest step length worth trying:
     * {@code relTolx / max_k(|line_k| / max(|xOld_k|, 1))}.
     * If {@code line} is all zeros this evaluates to {@code +Infinity}, so the search exits
     * immediately.
     */
    private double minimumStep(INDArray line, INDArray xOld) {
        final INDArray scale = Nd4j.create(line.length());
        for (int k = 0; k < line.length(); k++) {
            final double xk = ((Double) xOld.getScalar(k).element()).doubleValue();
            scale.putScalar(k, Math.max(Math.abs(xk), 1.0d));
        }
        final double maxRelative =
                ((Double) Transforms.abs(line).div(scale).max(Integer.MAX_VALUE).element()).doubleValue();
        return this.relTolx / maxRelative;
    }

    /**
     * Proposes the next trial step after a rejected, finite-valued trial.
     *
     * <p>If {@code alam == 1} (first backtrack), this minimizes a quadratic model of the
     * function along the line. Otherwise it fits a cubic through the last two trials
     * ({@code alam, f}) and ({@code alam2, f2}) and caps the result at {@code 0.5 * alam}.
     *
     * @param alam  step length just rejected
     * @param alam2 step length rejected before that (unused on the first backtrack)
     * @param f     function value at {@code alam}
     * @param f2    function value at {@code alam2}
     * @param fold  function value at the start point
     * @param slope directional derivative at the start point
     * @return the proposed step length (before the caller's 0.1 * alam floor)
     */
    private static double interpolateStep(double alam, double alam2, double f, double f2,
                                          double fold, double slope) {
        if (alam == 1.0d) {
            return -slope / (2.0d * (f - fold - slope));
        }
        final double rhs1 = f - fold - alam * slope;
        final double rhs2 = f2 - fold - alam2 * slope;
        assert alam - alam2 != 0.0d : "FAILURE: dividing by alam-alam2. alam=" + alam;
        final double a = (rhs1 / FastMath.pow(alam, 2) - rhs2 / FastMath.pow(alam2, 2)) / (alam - alam2);
        final double b = (-alam2 * rhs1 / (alam * alam) + alam * rhs2 / (alam2 * alam2)) / (alam - alam2);

        double tmplam;
        if (a == 0.0d) {
            tmplam = -slope / (2.0d * b);
        } else {
            final double disc = b * b - 3.0d * a * slope;
            if (disc < 0.0d) {
                tmplam = 0.5d * alam;
            } else if (b <= 0.0d) {
                tmplam = (-b + FastMath.sqrt(disc)) / (3.0d * a);
            } else {
                tmplam = -slope / (b + FastMath.sqrt(disc));
            }
        }
        if (tmplam > 0.5d * alam) {
            tmplam = 0.5d * alam;
        }
        return tmplam;
    }

    /**
     * Returns {@code true} when every coordinate of the two arrays differs by at most
     * {@link #absTolx}. A NaN difference never exceeds the tolerance, so NaN coordinates are
     * treated as "unchanged".
     */
    private boolean smallAbsDiff(INDArray iNDArray, INDArray iNDArray2) {
        for (int i = 0; i < iNDArray.length(); i++) {
            if (Math.abs(((Double) iNDArray.getScalar(i).sub(iNDArray2.getScalar(i)).element()).doubleValue()) > this.absTolx) {
                return false;
            }
        }
        return true;
    }
}
