package com.github.stanfordfuturedata.momentsketch;

import com.github.stanfordfuturedata.momentsketch.optimizer.FunctionWithHessian;
import java.util.Objects;
import org.apache.commons.math3.util.FastMath;

/* loaded from: input_file:com/github/stanfordfuturedata/momentsketch/DMaxentLoss.class */
/**
 * Loss function for a discrete maximum-entropy fit on a uniform grid over [-1, 1].
 *
 * <p>The grid has {@code nGrid} equally spaced points {@code x_j}, including both endpoints -1
 * and 1. The basis functions are Chebyshev polynomials {@code T_k}, tabulated with the recurrence
 * {@code T_0 = 1}, {@code T_1 = x}, {@code T_k = 2x T_{k-1} - T_{k-2}}. For a coefficient vector
 * {@code lambda} of length {@code dim} the loss is
 *
 * <pre>
 *   L(lambda) = sum_j w_j - sum_k lambda_k * d_mus[k],   w_j = exp(sum_k lambda_k T_k(x_j))
 * </pre>
 *
 * <p>Its gradient is {@code mus[k] - d_mus[k]}, where {@code mus[k] = sum_j T_k(x_j) w_j}, so at a
 * stationary point the weighted grid moments equal {@code d_mus}. Hessian entry {@code (a, b)} is
 * {@code sum_j T_a(x_j) T_b(x_j) w_j}. It is evaluated through the product identity
 * {@code T_a T_b = (T_{a+b} + T_{|a-b|}) / 2}, which is why {@code 2 * dim} basis rows are
 * tabulated even though only {@code dim} coefficients are optimized.
 *
 * <p>Instances hold scratch buffers that every evaluation overwrites in place. The getters return
 * those buffers directly, without copying. No synchronization is performed.
 */
public class DMaxentLoss implements FunctionWithHessian {
    /** Number of optimized coefficients; equals {@code d_mus.length}. */
    protected int dim;
    /** Number of grid points on [-1, 1]. */
    protected int nGrid;
    /** Target moment values; stored by reference as passed to the constructor. */
    protected double[] d_mus;
    /** Grid points, uniformly spaced from -1 to 1 inclusive. */
    protected double[] xs;
    /** {@code cpVals[k][j] = T_k(xs[j])} for {@code k < 2 * dim}. */
    protected double[][] cpVals;
    /** Coefficients of the most recent evaluation; stored by reference. */
    protected double[] lambd;
    /** Unnormalized grid weights {@code exp(sum_k lambd[k] * T_k(xs[j]))}. */
    protected double[] weights;
    /** Weighted moments {@code sum_j T_k(xs[j]) * weights[j]} for {@code k < 2 * dim}. */
    protected double[] mus;
    /** Gradient buffer, length {@code dim}. */
    protected double[] grad;
    /** Hessian buffer, {@code dim x dim}. */
    protected double[][] hess;

    /**
     * Builds the uniform grid and tabulates the Chebyshev basis on it.
     *
     * @param dMus target moment values. Its length fixes the number of coefficients
     *     ({@code dim}). The array is stored by reference, not copied.
     * @param nGrid number of grid points, at least 2 (the endpoints -1 and 1 are both included)
     * @throws NullPointerException if {@code dMus} is null
     * @throws IllegalArgumentException if {@code dMus} is empty or {@code nGrid < 2}
     */
    public DMaxentLoss(double[] dMus, int nGrid) {
        Objects.requireNonNull(dMus, "dMus");
        if (dMus.length == 0) {
            // Rows T_0 and T_1 are always tabulated, so at least one moment is required.
            throw new IllegalArgumentException("dMus must contain at least one moment");
        }
        if (nGrid < 2) {
            // The spacing below divides by (nGrid - 1); a single point would give a NaN grid.
            throw new IllegalArgumentException("nGrid must be at least 2, got " + nGrid);
        }
        this.dim = dMus.length;
        this.nGrid = nGrid;
        this.d_mus = dMus;

        this.xs = new double[nGrid];
        for (int j = 0; j < nGrid; j++) {
            this.xs[j] = ((j * 2.0d) / (nGrid - 1)) - 1.0d;
        }

        // The Hessian reads moments up to order 2 * dim - 2, so 2 * dim rows are tabulated.
        int nBasis = 2 * this.dim;
        this.cpVals = new double[nBasis][nGrid];
        for (int j = 0; j < nGrid; j++) {
            this.cpVals[0][j] = 1.0d;
            this.cpVals[1][j] = this.xs[j];
        }
        for (int k = 2; k < nBasis; k++) {
            for (int j = 0; j < nGrid; j++) {
                this.cpVals[k][j] =
                        ((2.0d * this.xs[j]) * this.cpVals[k - 1][j]) - this.cpVals[k - 2][j];
            }
        }

        this.weights = new double[nGrid];
        this.mus = new double[nBasis];
        this.grad = new double[this.dim];
        this.hess = new double[this.dim][this.dim];
    }

    /**
     * Sets the coefficient vector used by later reads of {@link #getValue()}.
     *
     * @param lambd coefficients; stored by reference, so later mutation of the array is visible here
     */
    public void setLambd(double[] lambd) {
        this.lambd = lambd;
    }

    /**
     * Evaluates the loss at {@code point}.
     *
     * <p>This delegates to {@link #computeAll(double[], double)}, so the gradient and Hessian are
     * refreshed as well.
     *
     * @param point coefficient vector, length at least {@code dim}
     * @param tol unused by this implementation
     */
    @Override
    public void computeOnlyValue(double[] point, double tol) {
        computeAll(point, tol);
    }

    /**
     * Evaluates weights, moments, gradient and Hessian at {@code point}, overwriting the buffers.
     *
     * @param point coefficient vector, length at least {@code dim}; stored by reference
     * @param tol unused by this implementation
     * @throws NullPointerException if {@code point} is null
     * @throws IllegalArgumentException if {@code point} has fewer than {@code dim} entries
     */
    @Override
    public void computeAll(double[] point, double tol) {
        Objects.requireNonNull(point, "point");
        if (point.length < this.dim) {
            throw new IllegalArgumentException(
                    "point has length " + point.length + " but dim is " + this.dim);
        }
        setLambd(point);
        computeWeights();
        computeMoments();
        computeGradient();
        computeHessian();
    }

    /** Fills {@code weights[j] = exp(sum_k lambd[k] * T_k(xs[j]))}. */
    private void computeWeights() {
        for (int j = 0; j < this.nGrid; j++) {
            double exponent = 0.0d;
            for (int k = 0; k < this.dim; k++) {
                exponent += this.lambd[k] * this.cpVals[k][j];
            }
            this.weights[j] = FastMath.exp(exponent);
        }
    }

    /** Fills {@code mus[k] = sum_j T_k(xs[j]) * weights[j]} for every tabulated order. */
    private void computeMoments() {
        for (int k = 0; k < 2 * this.dim; k++) {
            double moment = 0.0d;
            for (int j = 0; j < this.nGrid; j++) {
                moment += this.cpVals[k][j] * this.weights[j];
            }
            this.mus[k] = moment;
        }
    }

    /** Fills {@code grad[k] = mus[k] - d_mus[k]}. */
    private void computeGradient() {
        for (int k = 0; k < this.dim; k++) {
            this.grad[k] = this.mus[k] - this.d_mus[k];
        }
    }

    /** Fills {@code hess[a][b] = (mus[a + b] + mus[|a - b|]) / 2} (Chebyshev product identity). */
    private void computeHessian() {
        for (int a = 0; a < this.dim; a++) {
            for (int b = 0; b < this.dim; b++) {
                this.hess[a][b] = 0.5d * (this.mus[a + b] + this.mus[FastMath.abs(a - b)]);
            }
        }
    }

    /**
     * Returns the unnormalized grid weights from the last evaluation.
     *
     * @return the internal weight buffer (not a copy)
     */
    public double[] getWeights() {
        return this.weights;
    }

    @Override
    public int dim() {
        return this.dim;
    }

    /**
     * Returns {@code mus[0] - sum_k lambd[k] * d_mus[k]}, the loss at the last evaluated point.
     *
     * <p>{@code lambd} is read at call time. If the caller mutated that array after
     * {@link #computeAll(double[], double)}, the linear term reflects the new contents while
     * {@code mus[0]} does not.
     *
     * @throws IllegalStateException if no coefficients have been set yet
     */
    @Override
    public double getValue() {
        if (this.lambd == null) {
            throw new IllegalStateException("computeAll must be called before getValue");
        }
        double linearTerm = 0.0d;
        for (int k = 0; k < this.d_mus.length; k++) {
            linearTerm += this.lambd[k] * this.d_mus[k];
        }
        return this.mus[0] - linearTerm;
    }

    /**
     * Returns the gradient from the last evaluation.
     *
     * @return the internal gradient buffer (not a copy)
     */
    @Override
    public double[] getGradient() {
        return this.grad;
    }

    /**
     * Returns the Hessian from the last evaluation.
     *
     * @return the internal Hessian buffer (not a copy)
     */
    @Override
    public double[][] getHessian() {
        return this.hess;
    }
}
