package edu.stanford.nlp.optimization;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractStochasticCachingDiffFunction;
import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.util.Timing;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import org.springframework.core.task.AsyncTaskExecutor;

/* loaded from: input_file:WEB-INF/lib/stanford-corenlp-3.4.1.jar:edu/stanford/nlp/optimization/SGDWithAdaGradAndFOBOS.class */
public class SGDWithAdaGradAndFOBOS<T extends DiffFunction> implements Minimizer<T>, HasEvaluators {
    /** Working parameter vector; re-allocated from the initial point and updated in place by {@code minimize}. */
    protected double[] x;
    /** Base step size: numerator of the per-coordinate learning rate when AdaDelta is disabled. */
    protected double initRate;
    /** Regularization strength multiplied into every shrinkage threshold. */
    protected double lambda;
    /** sg-lasso mixing weight: {@code alpha} scales the element-wise threshold, {@code 1 - alpha} the group-wise one. */
    protected double alpha;
    /** When true, {@code say} and {@code sayln} print nothing. */
    protected boolean quiet;
    /** Number of passes used when the constructor is given a negative pass count. */
    private static final int DEFAULT_NUM_PASSES = 50;
    /** Number of passes over the data; bounds the outer loop of {@code minimize}. */
    protected final int numPasses;
    /** Mini-batch size; {@code minimize} clamps it to the sample count for stochastic update functions. */
    protected int bSize;
    // Not referenced anywhere in this class.
    private static final int DEFAULT_TUNING_SAMPLES = Integer.MAX_VALUE;
    // Not referenced anywhere in this class.
    private static final int DEFAULT_BATCH_SIZE = 1000;
    /** Smoothing constant added under the square roots in {@code computeLearningRate}. */
    private double eps;
    /** Threshold for the average-improvement stopping rule (see {@code terminateOnAvgImprovement}). */
    private double TOL;
    // yList, sList, diag, s and y are never read in this class.
    // NOTE(review): presumably leftovers from a quasi-Newton variant; three of them are public, so confirm
    // no external user depends on them before removing.
    public List<double[]> yList;
    public List<double[]> sList;
    public double[] diag;
    /** Only written (constructor and {@code setHessSampleSize}); never read in this class. */
    private int hessSampleSize;
    private double[] s;
    private double[] y;
    /** Created with seed 1 in the constructor; not otherwise used in this class. */
    protected Random gen;
    /** Time budget compared with {@code Timing.report()} after each pass (units presumably milliseconds; TODO confirm). */
    protected long maxTime;
    /** Evaluators run every {@code evaluateIters} passes; a value {@code <= 0} disables periodic evaluation. */
    private int evaluateIters;
    /** Evaluators invoked by {@code doEvaluation}; may be null, in which case no evaluation happens. */
    private Evaluator[] evaluators;
    /** Regularizer selecting which update rule {@code minimize} applies after each gradient step. */
    private Prior prior;
    /** Enables early stopping driven by {@code toContinue} and makes {@code minimize} return the best snapshot. */
    private boolean useEvalImprovement;
    /** Enables early stopping on the averaged relative change of the objective (full-batch mode only). */
    private boolean useAvgImprovement;
    /** Suppresses the "Evaluating: ..." line printed before each evaluator runs. */
    private boolean suppressTestPrompt;
    /** Number of consecutive non-improving evaluations tolerated by {@code toContinue}. */
    private int terminateOnEvalImprovementNumOfEpoch;
    /** Highest evaluation score seen so far by {@code toContinue}. */
    private double bestEvalSoFar;
    /** Copy of the parameters that produced {@code bestEvalSoFar}; null until {@code toContinue} records one. */
    private double[] xBest;
    /** Consecutive evaluations that failed to reach {@code bestEvalSoFar}. */
    private int noImproveItrCount;
    /** Use the AdaDelta rate (ratio of decayed RMS step to decayed RMS gradient) instead of AdaGrad. */
    private boolean useAdaDelta;
    /** Accumulate squared gradient differences (current minus previous) instead of squared gradients. */
    private boolean useAdaDiff;
    /** Decay factor of the AdaDelta running averages. */
    private double rho;
    /** Per-coordinate accumulator of squared gradients (or gradient differences). */
    private double[] sumGradSquare;
    /** Gradient of the previous batch, used by the AdaDiff variant. */
    private double[] prevGrad;
    /** Last change applied to each coordinate, recorded by {@code updateX}. */
    private double[] prevDeltaX;
    /** Per-coordinate decayed accumulator of squared steps; allocated only when AdaDelta is enabled. */
    private double[] sumDeltaXSquare;
    /** Scientific-notation formatter used for lambda, alpha and objective values in names and log lines. */
    private static final NumberFormat nf = new DecimalFormat("0.000E0");

    /* loaded from: input_file:WEB-INF/lib/stanford-corenlp-3.4.1.jar:edu/stanford/nlp/optimization/SGDWithAdaGradAndFOBOS$Prior.class */
    public enum Prior {
        LASSO,
        RIDGE,
        GAUSSIAN,
        aeLASSO,
        gLASSO,
        sgLASSO,
        NONE
    }

    /**
     * Stores the Hessian sample size. The value is only recorded; nothing in this class reads it back.
     *
     * @param i the sample size to remember
     */
    public void setHessSampleSize(int i) {
        hessSampleSize = i;
    }

    /**
     * Switches evaluation-driven early stopping on or off.
     *
     * @param z true to stop once the evaluation score stops improving and to return the best parameters seen
     */
    public void terminateOnEvalImprovement(boolean z) {
        useEvalImprovement = z;
    }

    /**
     * Configures the stopping rule based on the averaged relative change of the objective.
     *
     * @param z true to enable the rule
     * @param d tolerance below which the averaged relative change ends the optimization
     */
    public void terminateOnAvgImprovement(boolean z, double d) {
        TOL = d;
        useAvgImprovement = z;
    }

    /**
     * Controls whether the "Evaluating: ..." line is printed before each evaluator runs.
     *
     * @param z true to omit the line
     */
    public void suppressTestPrompt(boolean z) {
        suppressTestPrompt = z;
    }

    /**
     * Sets how many consecutive non-improving evaluations are tolerated before early stopping.
     *
     * @param i the number of tolerated non-improving evaluations
     */
    public void setTerminateOnEvalImprovementNumOfEpoch(int i) {
        terminateOnEvalImprovementNumOfEpoch = i;
    }

    /**
     * Early-stopping bookkeeping: records a new best evaluation score or counts a miss.
     *
     * <p>A score that is at least as good as the best one seen so far resets the miss counter and
     * snapshots {@code candidate} into {@code xBest}. Any other score, including NaN, counts as a miss.
     *
     * @param candidate parameters that produced {@code evalScore}; copied, never retained
     * @param evalScore evaluation score of {@code candidate}, higher is better
     * @return true while optimization should go on; false once the number of consecutive misses
     *         exceeds {@code terminateOnEvalImprovementNumOfEpoch}
     */
    public boolean toContinue(double[] candidate, double evalScore) {
        // NaN compares false against everything, so without the explicit check it would be stored as
        // the "best" score and no later score could ever register as a miss.
        if (Double.isNaN(evalScore) || evalScore < this.bestEvalSoFar) {
            this.noImproveItrCount++;
            return this.noImproveItrCount <= this.terminateOnEvalImprovementNumOfEpoch;
        }
        this.bestEvalSoFar = evalScore;
        this.noImproveItrCount = 0;
        if (this.xBest == null || this.xBest.length != candidate.length) {
            // Also covers a snapshot left over from a run with a different dimensionality.
            this.xBest = Arrays.copyOf(candidate, candidate.length);
        } else {
            System.arraycopy(candidate, 0, this.xBest, 0, candidate.length);
        }
        return true;
    }

    /**
     * Maps a prior name to its {@link Prior} constant. Matching is exact and case-sensitive.
     *
     * @param str one of "none", "lasso", "ridge", "gaussian", "ae-lasso", "g-lasso" or "sg-lasso"
     * @return the corresponding prior
     * @throws IllegalArgumentException if {@code str} is null or not one of the supported names
     */
    private static Prior getPrior(String str) {
        // Constant-first equals keeps a null name on the IllegalArgumentException path below.
        if ("none".equals(str)) {
            return Prior.NONE;
        }
        if ("lasso".equals(str)) {
            return Prior.LASSO;
        }
        if ("ridge".equals(str)) {
            return Prior.RIDGE;
        }
        if ("gaussian".equals(str)) {
            return Prior.GAUSSIAN;
        }
        if ("ae-lasso".equals(str)) {
            return Prior.aeLASSO;
        }
        if ("g-lasso".equals(str)) {
            return Prior.gLASSO;
        }
        if ("sg-lasso".equals(str)) {
            return Prior.sgLASSO;
        }
        throw new IllegalArgumentException("prior type " + str + " not recognized; supported priors are: none, lasso, ridge, gaussian, ae-lasso, g-lasso, and sg-lasso");
    }

    /**
     * Creates a minimizer with the lasso prior and a batch size of -1, which {@code minimize}
     * replaces with the full sample count for stochastic update functions.
     *
     * @param initRate  base step size
     * @param lambda    regularization strength
     * @param numPasses number of passes over the data; a negative value selects the default of 50
     */
    public SGDWithAdaGradAndFOBOS(double initRate, double lambda, int numPasses) {
        this(initRate, lambda, numPasses, -1);
    }

    /**
     * Creates a minimizer with the lasso prior, alpha 1.0, plain AdaGrad (no AdaDelta, no AdaDiff),
     * smoothing constant 0.001 and decay 0.95.
     *
     * @param initRate  base step size
     * @param lambda    regularization strength
     * @param numPasses number of passes over the data; a negative value selects the default of 50
     * @param batchSize mini-batch size
     */
    public SGDWithAdaGradAndFOBOS(double initRate, double lambda, int numPasses, int batchSize) {
        this(initRate, lambda, numPasses, batchSize, "lasso", 1.0d, false, false, 0.001d, 0.95d);
    }

    /**
     * Creates a fully configured minimizer.
     *
     * @param initRate    base step size (numerator of the AdaGrad rate)
     * @param lambda      regularization strength
     * @param numPasses   number of passes over the data; a negative value selects the default of 50
     * @param batchSize   mini-batch size
     * @param priorType   prior name: "none", "lasso", "ridge", "gaussian", "ae-lasso", "g-lasso" or "sg-lasso"
     * @param alpha       sg-lasso mixing weight between element-wise and group-wise shrinkage
     * @param useAdaDelta true to use the AdaDelta learning rate instead of AdaGrad
     * @param useAdaDiff  true to accumulate squared gradient differences instead of squared gradients
     * @param adaGradEps  smoothing constant added under the square roots of the learning rate
     * @param adaDeltaRho decay factor of the AdaDelta running averages
     * @throws IllegalArgumentException if {@code priorType} is not a supported prior name
     */
    public SGDWithAdaGradAndFOBOS(double initRate, double lambda, int numPasses, int batchSize, String priorType, double alpha, boolean useAdaDelta, boolean useAdaDiff, double adaGradEps, double adaDeltaRho) {
        // Fixed defaults that can only be changed through setters afterwards.
        this.quiet = false;
        this.TOL = 1.0E-4d;
        this.yList = null;
        this.sList = null;
        this.hessSampleSize = -1;
        this.y = null;
        this.gen = new Random(1L);
        // Effectively "no time limit". A plain JDK constant keeps the optimizer free of framework types.
        this.maxTime = Long.MAX_VALUE;
        this.evaluateIters = 0;
        this.useEvalImprovement = false;
        this.useAvgImprovement = false;
        this.suppressTestPrompt = false;
        this.terminateOnEvalImprovementNumOfEpoch = 1;
        this.bestEvalSoFar = Double.NEGATIVE_INFINITY;
        this.noImproveItrCount = 0;

        // Caller-supplied configuration.
        this.initRate = initRate;
        this.prior = getPrior(priorType);
        this.bSize = batchSize;
        this.lambda = lambda;
        this.eps = adaGradEps;
        this.rho = adaDeltaRho;
        this.useAdaDelta = useAdaDelta;
        this.useAdaDiff = useAdaDiff;
        this.alpha = alpha;
        if (numPasses >= 0) {
            this.numPasses = numPasses;
        } else {
            this.numPasses = DEFAULT_NUM_PASSES;
            sayln("  SGDWithAdaGradAndFOBOS: numPasses=" + numPasses + ", defaulting to " + this.numPasses);
        }
    }

    /** Silences all progress output produced through {@code say} and {@code sayln}. */
    public void shutUp() {
        quiet = true;
    }

    /**
     * Builds a descriptive name containing the batch size and the formatted lambda and alpha values.
     *
     * @return a string of the form {@code SGDWithAdaGradAndFOBOS<bSize>_lambda<lambda>_alpha<alpha>}
     */
    protected String getName() {
        StringBuilder name = new StringBuilder("SGDWithAdaGradAndFOBOS");
        name.append(this.bSize);
        name.append("_lambda").append(nf.format(this.lambda));
        name.append("_alpha").append(nf.format(this.alpha));
        return name.toString();
    }

    /**
     * Registers the evaluators and how often they run.
     *
     * @param i            evaluate every {@code i} passes; a value {@code <= 0} disables evaluation
     * @param evaluatorArr evaluators to run; the array is stored as-is, not copied
     */
    @Override // edu.stanford.nlp.optimization.HasEvaluators
    public void setEvaluators(int i, Evaluator[] evaluatorArr) {
        evaluators = evaluatorArr;
        evaluateIters = i;
    }

    /**
     * Computes the Euclidean (L2) norm of a vector.
     *
     * @param dArr the vector; must not be null
     * @return the square root of the sum of squared entries, 0 for an empty vector
     */
    private static double getNorm(double[] dArr) {
        double sumOfSquares = 0.0d;
        for (double component : dArr) {
            sumOfSquares += component * component;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Runs every registered evaluator on the given parameters.
     *
     * @param dArr parameters to evaluate
     * @return the score of the last evaluator that returned something other than negative infinity,
     *         or negative infinity when no evaluators are registered or none produced a score
     */
    private double doEvaluation(double[] dArr) {
        final Evaluator[] registered = this.evaluators;
        double lastScore = Double.NEGATIVE_INFINITY;
        if (registered != null) {
            for (int k = 0; k < registered.length; k++) {
                final Evaluator current = registered[k];
                if (!this.suppressTestPrompt) {
                    sayln("  Evaluating: " + current.toString());
                }
                final double score = current.evaluate(dArr);
                // Negative infinity means "no score"; keep whatever an earlier evaluator reported.
                if (score != Double.NEGATIVE_INFINITY) {
                    lastScore = score;
                }
            }
        }
        return lastScore;
    }

    /**
     * Returns the positive part of a number.
     *
     * @param d any value
     * @return {@code d} when it is strictly positive, otherwise 0 (this includes NaN and negative zero)
     */
    private static double pospart(double d) {
        // Deliberately not Math.max: that would propagate NaN instead of mapping it to 0.
        return d > 0.0d ? d : 0.0d;
    }

    /**
     * Updates the per-coordinate accumulators and returns the learning rate for one coordinate.
     *
     * <p>With AdaGrad the rate is {@code initRate / sqrt(sumGradSquare[i] + eps)}. With AdaDelta it is
     * {@code sqrt(sumDeltaXSquare[i] + eps) / sqrt(sumGradSquare[i] + eps)}, both sums being running
     * averages decayed by {@code rho}. When AdaDiff is on, the accumulated quantity is the change of
     * the gradient since the previous batch instead of the gradient itself.
     *
     * @param i index of the coordinate
     * @param d current gradient of that coordinate
     * @return the step size to multiply the gradient with
     */
    private double computeLearningRate(int i, double d) {
        final double gradChange = d - this.prevGrad[i];
        final double signal = this.useAdaDiff ? gradChange : d;

        if (!this.useAdaDelta) {
            this.sumGradSquare[i] += signal * signal;
            return this.initRate / Math.sqrt(this.sumGradSquare[i] + this.eps);
        }

        final double lastStep = this.prevDeltaX[i];
        this.sumDeltaXSquare[i] = (this.sumDeltaXSquare[i] * this.rho) + ((1.0d - this.rho) * lastStep * lastStep);
        this.sumGradSquare[i] = (this.sumGradSquare[i] * this.rho) + ((1.0d - this.rho) * signal * signal);
        final double rmsStep = Math.sqrt(this.sumDeltaXSquare[i] + this.eps);
        final double rmsGrad = Math.sqrt(this.sumGradSquare[i] + this.eps);
        return rmsStep / rmsGrad;
    }

    /**
     * Writes a new value into one coordinate and remembers how far it moved.
     *
     * @param dArr parameter vector to modify
     * @param i    index of the coordinate
     * @param d    new value of the coordinate
     */
    private void updateX(double[] dArr, int i, double d) {
        final double delta = d - dArr[i];
        // The step is recorded before the coordinate is overwritten.
        this.prevDeltaX[i] = delta;
        dArr[i] = d;
    }

    /**
     * Minimizes the function without an explicit iteration limit; delegates to the four-argument
     * overload with a limit of -1.
     *
     * @param diffFunction function to minimize
     * @param d            function tolerance, passed through unchanged
     * @param dArr         starting point
     * @return the parameters produced by the four-argument overload
     */
    @Override // edu.stanford.nlp.optimization.Minimizer
    public double[] minimize(DiffFunction diffFunction, double d, double[] dArr) {
        final int noIterationLimit = -1;
        return this.minimize(diffFunction, d, dArr, noIterationLimit);
    }

    /**
     * Minimizes {@code function} with mini-batch AdaGrad (or AdaDelta) followed, for every batch, by
     * the shrinkage step of the configured {@link Prior}.
     *
     * <p>For an {@code AbstractStochasticCachingDiffUpdateFunction} the data is visited in shuffled
     * batches of {@code bSize} samples; a non-positive or too large batch size is replaced by the
     * sample count, in which case the full objective value is tracked as well. Any other
     * {@code DiffFunction} contributes one full gradient per pass.
     *
     * <p>The run ends after {@code numPasses} passes, or earlier when evaluation-based early stopping
     * triggers, the averaged objective change drops below the tolerance, the iteration or time limit
     * is reached, or the parameters become non-finite (they are then set to NaN).
     *
     * @param function          function to minimize
     * @param functionTolerance not consulted by this implementation
     * @param initial           starting point; copied, not modified
     * @param maxIterations     iteration limit; the effective limit is the larger of this value and
     *                          {@code numPasses} times the number of batches per pass
     * @return the final parameters, or the best-scoring snapshot when evaluation-based early stopping
     *         is enabled and at least one snapshot was recorded
     * @throws UnsupportedOperationException if AdaDelta is combined with a prior other than NONE or
     *         GAUSSIAN, if a group prior is used with a function lacking feature grouping, or if
     *         neither {@code maxIterations} nor {@code numPasses} is positive
     */
    @Override // edu.stanford.nlp.optimization.Minimizer
    public double[] minimize(DiffFunction function, double functionTolerance, double[] initial, int maxIterations) {
        sayln("Using lambda=" + this.lambda);

        final AbstractStochasticCachingDiffUpdateFunction stochastic = (function instanceof AbstractStochasticCachingDiffUpdateFunction)
                ? (AbstractStochasticCachingDiffUpdateFunction) function
                : null;
        int totalSamples = 0;
        if (stochastic != null) {
            stochastic.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.Shuffled;
            totalSamples = stochastic.dataDimension();
            if (this.bSize > totalSamples) {
                System.err.println("WARNING: Total number of samples=" + totalSamples + " is smaller than requested batch size=" + this.bSize + "!!!");
                this.bSize = totalSamples;
                sayln("Using batch size=" + this.bSize);
            }
            if (this.bSize <= 0) {
                System.err.println("WARNING: Requested batch size=" + this.bSize + " <= 0 !!!");
                this.bSize = totalSamples;
                sayln("Using batch size=" + this.bSize);
            }
        }

        final int dimension = initial.length;
        this.x = new double[dimension];
        this.sumGradSquare = new double[dimension];
        this.prevGrad = new double[dimension];
        this.prevDeltaX = new double[dimension];
        if (this.useAdaDelta) {
            this.sumDeltaXSquare = new double[dimension];
            if (this.prior != Prior.NONE && this.prior != Prior.GAUSSIAN) {
                throw new UnsupportedOperationException("useAdaDelta is currently only supported for Prior.NONE or Prior.GAUSSIAN");
            }
        }

        // Only the three group priors read the feature grouping; NONE must not require it.
        final boolean groupPrior = this.prior == Prior.gLASSO || this.prior == Prior.aeLASSO || this.prior == Prior.sgLASSO;
        // Scratch buffers for priors whose shrinkage depends on a norm over several coordinates.
        double[] steppedCache = null;
        double[] rateCache = null;
        if (groupPrior || this.prior == Prior.RIDGE) {
            steppedCache = new double[dimension];
            rateCache = new double[dimension];
        }
        int[][] featureGroups = null;
        if (groupPrior) {
            if (!(function instanceof HasFeatureGrouping)) {
                throw new UnsupportedOperationException("prior is specified to be ae-lasso or g-lasso, but function does not support feature grouping");
            }
            featureGroups = ((HasFeatureGrouping) function).getFeatureGrouping();
        }
        final double[] elementShrunk = this.prior == Prior.sgLASSO ? new double[dimension] : null;
        System.arraycopy(initial, 0, this.x, 0, this.x.length);

        int batchesPerPass = 1;
        if (stochastic != null && totalSamples > 0) {
            // bSize is within [1, totalSamples] here because of the clamping above.
            batchesPerPass = totalSamples / this.bSize;
        }
        if (maxIterations <= 0 && this.numPasses <= 0) {
            throw new UnsupportedOperationException("No maximum number of iterations has been specified.");
        }
        final int iterationLimit = Math.max(maxIterations, this.numPasses * batchesPerPass);
        sayln("       Batch size of: " + this.bSize);
        sayln("       Data dimension of: " + totalSamples);
        sayln("       Batches per pass through data:  " + batchesPerPass);
        sayln("       Number of passes is = " + this.numPasses);
        sayln("       Max iterations is = " + iterationLimit);

        final Timing total = new Timing();
        total.start();
        int iterations = 0;
        List<Double> objectiveHistory = null;
        double previousObjective = 0.0d;
        // Index set used when the function does not restrict the regularized range; built once on demand.
        Set<Integer> allIndices = null;

        for (int pass = 0; pass < this.numPasses; pass++) {
            double evalScore = Double.NEGATIVE_INFINITY;
            if (pass > 0 && this.evaluateIters > 0 && pass % this.evaluateIters == 0) {
                evalScore = doEvaluation(this.x);
                if (this.useEvalImprovement && !toContinue(this.x, evalScore)) {
                    break;
                }
            }
            double objective = Double.NEGATIVE_INFINITY;
            double objectiveDelta = Double.NEGATIVE_INFINITY;
            say("Iter: " + iterations + " pass " + pass + " batch 1 ... ");
            // Both counters accumulate over all batches of the pass.
            int nonZeroFeatures = 0;
            int nonZeroGroups = 0;

            for (int batch = 0; batch < batchesPerPass; batch++) {
                iterations++;

                // 1. Gradient of the current batch.
                final double[] gradient;
                if (stochastic != null) {
                    if (this.bSize == totalSamples) {
                        // Full batch: the exact objective is available, so track it for reporting
                        // and for the average-improvement stopping rule.
                        objective = stochastic.valueAt(this.x);
                        gradient = stochastic.getDerivative();
                        objectiveDelta = objective - previousObjective;
                        previousObjective = objective;
                        if (objectiveHistory == null) {
                            objectiveHistory = new ArrayList<Double>();
                        }
                        objectiveHistory.add(Double.valueOf(objective));
                    } else {
                        stochastic.calculateStochasticGradient(this.x, this.bSize);
                        gradient = stochastic.getDerivative();
                    }
                } else {
                    // Any differentiable function can supply a full gradient.
                    gradient = function.derivativeAt(this.x);
                }

                // 2. Adaptive gradient step followed by the shrinkage of the prior.
                if (this.prior == Prior.NONE || this.prior == Prior.GAUSSIAN) {
                    for (int j = 0; j < this.x.length; j++) {
                        final double g = gradient[j];
                        updateX(this.x, j, this.x[j] - (computeLearningRate(j, g) * g));
                    }
                } else if (this.prior == Prior.LASSO || this.prior == Prior.RIDGE) {
                    final Set<Integer> regularized;
                    if (function instanceof HasRegularizerParamRange) {
                        regularized = ((HasRegularizerParamRange) function).getRegularizerParamRange(this.x);
                    } else {
                        if (allIndices == null) {
                            allIndices = new HashSet<Integer>();
                            for (int j = 0; j < this.x.length; j++) {
                                allIndices.add(Integer.valueOf(j));
                            }
                        }
                        regularized = allIndices;
                    }
                    double squaredNorm = 0.0d;
                    for (int index : regularized) {
                        final double g = gradient[index];
                        final double rate = computeLearningRate(index, g);
                        final double stepped = this.x[index] - (rate * g);
                        if (this.prior == Prior.LASSO) {
                            // Soft-threshold towards zero by rate * lambda.
                            final double threshold = rate * this.lambda;
                            final double shrunk = Math.signum(stepped) * pospart(Math.abs(stepped) - threshold);
                            updateX(this.x, index, shrunk);
                            if (shrunk != 0.0d) {
                                nonZeroFeatures++;
                            }
                        } else {
                            squaredNorm += stepped * stepped;
                            steppedCache[index] = stepped;
                            rateCache[index] = rate;
                        }
                    }
                    if (this.prior == Prior.RIDGE) {
                        final double norm = Math.sqrt(squaredNorm);
                        // Only the coordinates stepped above are written back; the caches hold no
                        // meaningful value for anything outside the regularized range.
                        for (int index : regularized) {
                            final double shrunk = steppedCache[index] * pospart(1.0d - ((rateCache[index] * this.lambda) / norm));
                            updateX(this.x, index, shrunk);
                            if (shrunk != 0.0d) {
                                nonZeroFeatures++;
                            }
                        }
                    }
                } else {
                    for (int[] group : featureGroups) {
                        double squaredSum = 0.0d;
                        double absSum = 0.0d;
                        final double groupSize = group.length;
                        final double logGroupSize = Math.log(groupSize);
                        for (int index : group) {
                            final double g = gradient[index];
                            final double rate = computeLearningRate(index, g);
                            final double stepped = this.x[index] - (rate * g);
                            squaredSum += stepped * stepped;
                            absSum += Math.abs(stepped);
                            steppedCache[index] = stepped;
                            rateCache[index] = rate;
                        }
                        boolean groupActive = false;
                        if (this.prior == Prior.gLASSO) {
                            final double groupNorm = Math.sqrt(squaredSum);
                            for (int index : group) {
                                final double shrunk = steppedCache[index] * pospart(1.0d - (((rateCache[index] * this.lambda) * logGroupSize) / groupNorm));
                                updateX(this.x, index, shrunk);
                                if (shrunk != 0.0d) {
                                    nonZeroFeatures++;
                                    groupActive = true;
                                }
                            }
                        } else if (this.prior == Prior.aeLASSO) {
                            for (int index : group) {
                                final double threshold = ((rateCache[index] * this.lambda) / (1.0d + ((rateCache[index] * this.lambda) * groupSize))) * absSum;
                                final double shrunk = Math.signum(steppedCache[index]) * pospart(Math.abs(steppedCache[index]) - threshold);
                                updateX(this.x, index, shrunk);
                                if (shrunk != 0.0d) {
                                    nonZeroFeatures++;
                                    groupActive = true;
                                }
                            }
                        } else if (this.prior == Prior.sgLASSO) {
                            // First the element-wise soft-threshold weighted by alpha ...
                            double shrunkSquaredSum = 0.0d;
                            for (int index : group) {
                                final double shrunk = Math.signum(steppedCache[index]) * pospart(Math.abs(steppedCache[index]) - ((rateCache[index] * this.alpha) * this.lambda));
                                elementShrunk[index] = shrunk;
                                shrunkSquaredSum += shrunk * shrunk;
                            }
                            // ... then the group-wise scaling weighted by (1 - alpha).
                            final double shrunkNorm = Math.sqrt(shrunkSquaredSum);
                            for (int index : group) {
                                final double shrunk = elementShrunk[index] * pospart(1.0d - ((((rateCache[index] * (1.0d - this.alpha)) * this.lambda) * logGroupSize) / shrunkNorm));
                                updateX(this.x, index, shrunk);
                                if (shrunk != 0.0d) {
                                    nonZeroFeatures++;
                                    groupActive = true;
                                }
                            }
                        }
                        if (groupActive) {
                            nonZeroGroups++;
                        }
                    }
                }

                // 3. Remember the gradient for the AdaDiff variant.
                System.arraycopy(gradient, 0, this.prevGrad, 0, this.x.length);
            }

            try {
                ArrayMath.assertFinite(this.x, "x");
            } catch (ArrayMath.InvalidElementException e) {
                System.err.println(e.toString());
                Arrays.fill(this.x, Double.NaN);
                // Diverged: further passes could only reproduce the failure, so stop here.
                break;
            }

            final StringBuilder summary = new StringBuilder();
            summary.append(batchesPerPass).append(", n0-fCount:").append(nonZeroFeatures);
            if (this.prior != Prior.LASSO && this.prior != Prior.RIDGE) {
                summary.append(", n0-gCount:").append(nonZeroGroups);
            }
            if (evalScore != Double.NEGATIVE_INFINITY) {
                summary.append(", evalScore:").append(evalScore);
            }
            if (objective != Double.NEGATIVE_INFINITY) {
                summary.append(", obj_val:").append(nf.format(objective)).append(", obj_delta:").append(objectiveDelta);
            }
            sayln(summary.toString());

            if (objectiveHistory != null && this.useAvgImprovement && iterations > 5) {
                // Relative change averaged over the last (up to) ten recorded objective values.
                final int size = objectiveHistory.size();
                final int window = size >= 10 ? 10 : size;
                final double reference = objectiveHistory.get(size - window).doubleValue();
                if (Math.abs(((reference - objective) / window) / objective) < this.TOL) {
                    sayln("Online Optmization completed, due to average improvement: | newest_val - previous_val | / |newestVal| < TOL ");
                    break;
                }
            }
            if (iterations >= iterationLimit) {
                sayln("Online Optimization complete.  Stopped after max iterations");
                break;
            }
            if (total.report() >= this.maxTime) {
                sayln("Online Optimization complete.  Stopped after max time");
                break;
            }
        }

        // Early stopping may be enabled without a single evaluation having run (for example when
        // evaluateIters exceeds the number of passes); fall back to the current parameters then.
        final double[] result = (this.useEvalImprovement && this.xBest != null) ? this.xBest : this.x;
        if (this.evaluateIters > 0) {
            sayln("final evalScore is: " + doEvaluation(result));
        }
        sayln("Completed in: " + Timing.toSecondsString(total.report()) + " s");
        return result;
    }

    /**
     * Prints a progress line to standard error unless the minimizer has been silenced.
     *
     * @param str the line to print
     */
    protected void sayln(String str) {
        if (!this.quiet) {
            System.err.println(str);
        }
    }

    /**
     * Prints progress text to standard error, without a line break, unless the minimizer has been silenced.
     *
     * @param str the text to print
     */
    protected void say(String str) {
        if (!this.quiet) {
            System.err.print(str);
        }
    }
}
