package com.github.stanfordfuturedata.momentsketch.optimizer;

import com.github.stanfordfuturedata.momentsketch.MathUtil;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.CholeskyDecomposition;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularValueDecomposition;

/* loaded from: input_file:com/github/stanfordfuturedata/momentsketch/optimizer/NewtonOptimizer.class */
/**
 * Newton's-method minimizer with a backtracking (damped) line search.
 *
 * <p>Each iteration reads the current value, gradient and Hessian from the wrapped
 * {@link FunctionWithHessian}, solves {@code H * s = g} for the step (Cholesky first, SVD as a
 * fallback), moves along {@code -s}, and shrinks the step by {@code beta} while a
 * sufficient-decrease test scaled by {@code alpha} fails.
 *
 * <p>The statistics exposed by {@link #getStepCount()}, {@link #getDampedStepCount()} and
 * {@link #isConverged()} describe the most recent {@link #solve(double[], double)} call.
 * Instances hold mutable state and make no attempt at synchronization.
 */
public class NewtonOptimizer implements GenericOptimizer {
    /** Smallest step scale the line search will shrink to before accepting the step as-is. */
    private static final double MIN_STEP_SCALE = 0.001d;

    /** Objective being minimized; supplies value, gradient and Hessian. */
    protected FunctionWithHessian P;
    /** Fraction of the predicted linear decrease that a step must achieve to be accepted. */
    private final double alpha = 0.3d;
    /** Multiplicative shrink factor applied to the step scale on each backtracking round. */
    private final double beta = 0.25d;
    /** When true, per-iteration progress is printed to standard output. */
    private boolean verbose = false;
    /** Upper bound on the number of Newton iterations per solve. */
    protected int maxIter = 200;
    /** Number of Newton steps taken by the most recent solve. */
    protected int stepCount = 0;
    /** Number of steps in the most recent solve that were shortened by the line search. */
    protected int dampedStepCount = 0;
    /** Whether the most recent solve met the gradient tolerance. */
    protected boolean converged = false;

    /**
     * Creates an optimizer for the given objective.
     *
     * @param functionWithHessian objective to minimize; stored by reference
     */
    public NewtonOptimizer(FunctionWithHessian functionWithHessian) {
        this.P = functionWithHessian;
    }

    /**
     * Enables or disables progress output on standard output.
     *
     * @param z {@code true} to print one line per iteration (and the step size of damped steps)
     */
    @Override
    public void setVerbose(boolean z) {
        this.verbose = z;
    }

    /**
     * Sets the iteration cap used by {@link #solve(double[], double)}.
     *
     * @param i maximum number of Newton iterations
     */
    @Override
    public void setMaxIter(int i) {
        this.maxIter = i;
    }

    /**
     * @return number of Newton steps taken by the most recent solve
     */
    @Override
    public int getStepCount() {
        return this.stepCount;
    }

    /**
     * @return {@code true} if the most recent solve stopped because the gradient's mean squared
     *     error dropped below the squared tolerance, {@code false} if it ran out of iterations
     */
    @Override
    public boolean isConverged() {
        return this.converged;
    }

    /**
     * @return number of steps in the most recent solve whose length was reduced by the line search
     */
    public int getDampedStepCount() {
        return this.dampedStepCount;
    }

    /**
     * @return the objective this optimizer operates on
     */
    @Override
    public FunctionWithHessian getP() {
        return this.P;
    }

    /**
     * Runs damped Newton iterations from the given starting point.
     *
     * <p>Iteration stops when {@code MathUtil.getMSE(gradient) < d * d} or after
     * {@code maxIter} steps, whichever comes first. The objective is always evaluated with a
     * precision argument of {@code d / 10}.
     *
     * @param dArr starting point; it is cloned and never modified
     * @param d gradient tolerance
     * @return the final iterate, in a freshly allocated array
     */
    @Override
    public double[] solve(double[] dArr, double d) {
        final int dim = this.P.dim();
        final double gradTol = d;
        final double gradTolSquared = gradTol * gradTol;
        final double precision = gradTol / 10.0d;
        final double[] x = dArr.clone();

        // All statistics describe this call only; previously the damped-step counter kept
        // accumulating across calls while the other two were per-call.
        this.converged = false;
        this.stepCount = 0;
        this.dampedStepCount = 0;

        this.P.computeAll(x, precision);

        int step = 0;
        while (step < this.maxIter) {
            final double value = this.P.getValue();
            final double[] gradient = this.P.getGradient();
            final double[][] hessian = this.P.getHessian();
            final double gradMse = MathUtil.getMSE(gradient);
            if (this.verbose) {
                System.out.println(String.format("Step: %3d GradRMSE: %10.5g P: %10.5g", step, Math.sqrt(gradMse), value));
            }
            if (gradMse < gradTolSquared) {
                this.converged = true;
                break;
            }

            final RealVector direction = newtonDirection(hessian, gradient);

            // Directional derivative of the objective along the Newton direction.
            // NOTE(review): this is computed before P is re-evaluated below; presumably
            // getGradient() may hand out a buffer that computeAll overwrites, so keep this order.
            double slope = 0.0d;
            for (int k = 0; k < dim; k++) {
                slope += direction.getEntry(k) * gradient[k];
            }

            // Try the full Newton step first.
            double stepScale = 1.0d;
            final double[] candidate = new double[dim];
            fillTrialPoint(x, direction, stepScale, candidate);
            this.P.computeAll(candidate, precision);

            // Backtrack only when the predicted change is non-negligible; otherwise the full
            // step is accepted unconditionally.
            if (slope * slope > gradTolSquared) {
                while (value + this.alpha * stepScale * slope - this.P.getValue() < -gradTol
                        && stepScale >= MIN_STEP_SCALE) {
                    stepScale *= this.beta;
                    fillTrialPoint(x, direction, stepScale, candidate);
                    this.P.computeAll(candidate, precision);
                }
            }

            if (stepScale < 1.0d) {
                this.dampedStepCount++;
                if (this.verbose) {
                    System.out.println("Step Size: " + stepScale);
                }
            }

            System.arraycopy(candidate, 0, x, 0, dim);
            step++;
        }

        this.stepCount = step;
        return x;
    }

    /**
     * Solves {@code hessian * s = gradient} and returns {@code -s}.
     *
     * <p>The Hessian is wrapped without copying. Cholesky is attempted first; if it throws
     * (it rejects matrices it does not consider symmetric positive definite), the system is
     * solved through a singular value decomposition instead.
     */
    private static RealVector newtonDirection(double[][] hessian, double[] gradient) {
        final Array2DRowRealMatrix hessianMatrix = new Array2DRowRealMatrix(hessian, false);
        RealVector direction;
        try {
            direction = new CholeskyDecomposition(hessianMatrix, 0.0d, 0.0d)
                    .getSolver()
                    .solve(new ArrayRealVector(gradient));
        } catch (RuntimeException choleskyFailure) {
            // Deliberate fallback, not a swallowed error: the SVD-based solver is used whenever
            // the Cholesky path is unavailable for this matrix.
            direction = new SingularValueDecomposition(hessianMatrix)
                    .getSolver()
                    .solve(new ArrayRealVector(gradient));
        }
        direction.mapMultiplyToSelf(-1.0d);
        return direction;
    }

    /** Writes {@code origin + scale * direction} into {@code out}, element by element. */
    private static void fillTrialPoint(double[] origin, RealVector direction, double scale, double[] out) {
        for (int k = 0; k < out.length; k++) {
            out[k] = origin[k] + (scale * direction.getEntry(k));
        }
    }
}
