/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.learning.legacy;

import java.io.Serializable;
import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Legacy AdaGrad updater state holder.
 *
 * <p>Keeps a running accumulator ({@link #historicalGradient}) that the
 * {@code getGradient} overloads grow with squared gradient values, and scales
 * incoming gradients by {@code learningRate / (sqrt(history) + epsilon)} or a
 * close variant of that expression (each overload documents its exact form).
 *
 * <p>The class is mutable and exposes its state through public fields and
 * setters; nothing in this class synchronizes access to that state.
 */
public class AdaGrad implements Serializable {
    /** Default value used for {@code epsilon} when none is supplied. */
    public static final double DEFAULT_ADAGRAD_EPSILON = 1.0E-6;

    /** Accumulated gradient history; {@code null} until initialized. */
    public INDArray historicalGradient;
    /** Shape used when the history has to be created lazily by {@link #createSubset(int)}. */
    public int[] shape;
    protected double learningRate = 0.1;
    /** Number of {@code getGradient} calls that completed on this instance. */
    protected int numIterations = 0;
    private double epsilon = DEFAULT_ADAGRAD_EPSILON;
    private char gradientReshapeOrder;

    /** Creates an updater with the default learning rate and epsilon and no shape. */
    public AdaGrad() {
    }

    /**
     * Creates an updater for a {@code rows x cols} shape.
     *
     * @param rows         first dimension of the stored shape
     * @param cols         second dimension of the stored shape
     * @param learningRate base learning rate
     */
    public AdaGrad(int rows, int cols, double learningRate) {
        this.shape = new int[]{rows, cols};
        this.learningRate = learningRate;
    }

    /**
     * Creates an updater for a {@code rows x cols} shape with a learning rate of {@code 0.1}.
     *
     * @param rows first dimension of the stored shape
     * @param cols second dimension of the stored shape
     */
    public AdaGrad(int rows, int cols) {
        this(rows, cols, 0.1);
    }

    /**
     * Creates an updater for an arbitrary shape. The array is stored as-is (not copied).
     *
     * @param shape        shape to remember
     * @param learningRate base learning rate
     */
    public AdaGrad(int[] shape, double learningRate) {
        this.shape = shape;
        this.learningRate = learningRate;
    }

    /**
     * Creates an updater with the given learning rate and no shape.
     *
     * @param learningRate base learning rate
     */
    public AdaGrad(double learningRate) {
        this.learningRate = learningRate;
    }

    /**
     * Creates an updater with the given learning rate and epsilon and no shape.
     *
     * @param learningRate base learning rate
     * @param epsilon      value added to the square-rooted history before dividing
     */
    public AdaGrad(double learningRate, double epsilon) {
        this.learningRate = learningRate;
        this.epsilon = epsilon;
    }

    /**
     * Returns the number of state elements needed for the given number of inputs.
     *
     * @param inputSize number of inputs
     * @return {@code inputSize}, unchanged
     */
    public int stateSizeForInputSize(int inputSize) {
        return inputSize;
    }

    /**
     * Installs {@code viewArray}, reshaped to {@code gradientShape}, as the gradient history.
     *
     * @param viewArray     row vector backing the state
     * @param gradientShape shape the view is reshaped to
     * @param gradientOrder ordering character; {@code 'f'} selects the "f" flag of the reshape,
     *                      any other value clears it
     * @param initialize    when {@code true}, every element of {@code viewArray} is first set
     *                      to the current epsilon
     * @throws IllegalArgumentException if {@code viewArray} is not a row vector
     * @throws IllegalStateException    if the reshape yields {@code null}; the history field is
     *                                  {@code null} afterwards in that case
     */
    public void setStateViewArray(INDArray viewArray, int[] gradientShape, char gradientOrder, boolean initialize) {
        if (!viewArray.isRowVector()) {
            throw new IllegalArgumentException("Invalid input: expect row vector input");
        }
        if (initialize) {
            viewArray.assign(this.epsilon);
        }
        boolean fortranOrder = gradientOrder == 'f';
        this.historicalGradient = Shape.newShapeNoCopy(viewArray, gradientShape, fortranOrder);
        if (this.historicalGradient == null) {
            throw new IllegalStateException("Could not correctly reshape gradient view array");
        }
        this.gradientReshapeOrder = gradientOrder;
    }

    /**
     * Updates the learning rate from the first argument, if there is one.
     *
     * <p>Any {@link Number} is accepted (not only {@link Double}); a {@code null} or empty
     * argument array leaves the learning rate untouched.
     *
     * @param args optional arguments; only {@code args[0]} is read
     * @throws ClassCastException   if {@code args[0]} is not a {@link Number}
     * @throws NullPointerException if {@code args[0]} is {@code null}
     */
    public void update(Object ... args) {
        if (args != null && args.length > 0) {
            this.learningRate = ((Number) args[0]).doubleValue();
        }
    }

    /**
     * Adds the element-wise square of {@code gradient} to the history and then scales
     * {@code gradient} in place by {@code learningRate / (sqrt(history) + epsilon)}.
     *
     * @param gradient  gradient to adjust; modified in place
     * @param iteration not read by this method
     * @return the result of the in-place multiplication on {@code gradient}
     * @throws IllegalStateException if no history has been set yet
     */
    public INDArray getGradient(INDArray gradient, int iteration) {
        if (this.historicalGradient == null) {
            throw new IllegalStateException("Updater has not been initialized with view state");
        }
        INDArray squaredGradient = gradient.mul(gradient);
        this.historicalGradient.addi(squaredGradient);

        // Work on a copy of the history so that the state itself only ever holds the raw sum.
        INDArray historyCopy = this.historicalGradient.dup(this.gradientReshapeOrder);
        INDArray denominator = Transforms.sqrt(historyCopy, false).addi(this.epsilon);
        INDArray perElementRate = denominator.rdivi(this.learningRate);
        INDArray adjusted = gradient.muli(perElementRate);
        ++this.numIterations;
        return adjusted;
    }

    /**
     * Scalar variant: adjusts a single gradient value using the history entry at {@code column}
     * and then adds {@code gradient * gradient} to that entry.
     *
     * <p>When no history exists it is created as an array of ones with the given shape, and the
     * raw entry (rather than its square root) is used as the denominator term for this call.
     *
     * @param gradient raw gradient value
     * @param column   linear index into the history
     * @param shape    shape used only if the history has to be created
     * @return {@code gradient * learningRate / (denominatorTerm + epsilon)}
     */
    public double getGradient(double gradient, int column, int[] shape) {
        boolean createdNow = this.historicalGradient == null;
        if (createdNow) {
            this.historicalGradient = Nd4j.ones(shape);
        }
        double stored = this.historicalGradient.getDouble(column);
        double denominatorTerm = createdNow ? stored : Math.sqrt(stored);
        double rate = this.learningRate / (denominatorTerm + this.epsilon);
        double adjusted = gradient * rate;
        this.historicalGradient.putScalar(column, this.historicalGradient.getDouble(column) + gradient * gradient);
        ++this.numIterations;
        return adjusted;
    }

    /**
     * Slice variant: scales {@code gradient} in place by per-element learning rates derived from
     * the history and then adds the squared result to slice {@code slice} of the history.
     *
     * <p>When no history exists it is created with the given shape and filled with epsilon.
     *
     * @param gradient gradient to adjust; modified in place
     * @param slice    index of the history slice this gradient belongs to
     * @param shape    shape used only if the history has to be created
     * @return {@code gradient}, after the in-place scaling
     * @throws IllegalArgumentException if the history is not a vector and the selected slice's
     *                                  length differs from the gradient's length
     */
    public INDArray getGradient(INDArray gradient, int slice, int[] shape) {
        boolean createdNow = false;
        if (this.historicalGradient == null) {
            this.historicalGradient = Nd4j.zeros(shape).add(this.epsilon);
            createdNow = true;
        } else if (!this.historicalGradient.isVector() && this.historicalGradient.slice(slice).length() != gradient.length()) {
            throw new IllegalArgumentException("Illegal gradient");
        }

        INDArray sqrtHistory;
        if (this.historicalGradient.isVector()) {
            sqrtHistory = Transforms.sqrt(this.historicalGradient);
        } else if (createdNow) {
            // Fix: previously the history array itself was used here, so the in-place rdivi
            // below overwrote the whole accumulator with learning-rate values. Divide a copy.
            sqrtHistory = this.historicalGradient.dup();
        } else {
            sqrtHistory = Transforms.sqrt(this.historicalGradient.slice(slice));
        }

        INDArray learningRates;
        try {
            learningRates = sqrtHistory.rdivi(this.learningRate);
        }
        catch (ArithmeticException ae) {
            // Fallback kept from the original code: retry with epsilon added to the numerator.
            learningRates = sqrtHistory.rdivi(this.learningRate + this.epsilon);
        }

        // A length mismatch means the rates cover the full history; pick the matching slice.
        INDArray rates = gradient.length() != learningRates.length() ? learningRates.slice(slice) : learningRates;
        gradient.muli(rates);

        // NOTE(review): this accumulates the square of the already-scaled gradient, unlike the
        // other two overloads which accumulate the raw gradient squared. Left unchanged to keep
        // numeric results stable for existing callers - confirm whether that is intended.
        this.historicalGradient.slice(slice).addi(gradient.mul(gradient));
        ++this.numIterations;
        return gradient;
    }

    /**
     * Builds a new updater holding a copy of one part of this updater's history.
     *
     * <p>If this updater has no history yet, one filled with ones is created from
     * {@link #shape} first (and kept on this instance). For a matrix shape the subset holds a
     * copy of slice {@code index}; otherwise it holds a scalar with the value at {@code index}.
     * Only the learning rate is carried over to the subset; its epsilon is the default.
     *
     * @param index slice index (matrix shape) or element index (otherwise)
     * @return the new updater
     */
    public AdaGrad createSubset(int index) {
        if (this.historicalGradient == null) {
            this.historicalGradient = Nd4j.ones(this.shape);
        }
        AdaGrad subset;
        if (Shape.isMatrix(this.shape)) {
            subset = new AdaGrad(1, this.historicalGradient.columns());
            subset.historicalGradient = this.historicalGradient.slice(index).dup();
        } else {
            subset = new AdaGrad(1, 1);
            subset.historicalGradient = Nd4j.scalar(this.historicalGradient.getDouble(index));
        }
        subset.setLearningRate(this.learningRate);
        return subset;
    }

    /** @return the gradient history, or {@code null} if none has been set */
    public INDArray getHistoricalGradient() {
        return this.historicalGradient;
    }

    /** @return the stored shape array (not a copy); may be {@code null} */
    public int[] getShape() {
        return this.shape;
    }

    /** @return the base learning rate */
    public double getLearningRate() {
        return this.learningRate;
    }

    /** @return how many {@code getGradient} calls have completed */
    public int getNumIterations() {
        return this.numIterations;
    }

    /** @return the epsilon value */
    public double getEpsilon() {
        return this.epsilon;
    }

    /** @return the ordering character recorded by {@link #setStateViewArray} or its setter */
    public char getGradientReshapeOrder() {
        return this.gradientReshapeOrder;
    }

    /** @param historicalGradient new gradient history; stored as-is */
    public void setHistoricalGradient(INDArray historicalGradient) {
        this.historicalGradient = historicalGradient;
    }

    /** @param shape new shape; stored as-is (not copied) */
    public void setShape(int[] shape) {
        this.shape = shape;
    }

    /** @param learningRate new base learning rate */
    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /** @param numIterations new iteration counter value */
    public void setNumIterations(int numIterations) {
        this.numIterations = numIterations;
    }

    /** @param epsilon new epsilon value */
    public void setEpsilon(double epsilon) {
        this.epsilon = epsilon;
    }

    /** @param gradientReshapeOrder new ordering character */
    public void setGradientReshapeOrder(char gradientReshapeOrder) {
        this.gradientReshapeOrder = gradientReshapeOrder;
    }

    /**
     * Compares history, shape, learning rate, iteration count, epsilon and reshape order.
     *
     * @param o object to compare with
     * @return {@code true} if {@code o} is an {@code AdaGrad} with equal values for all six
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof AdaGrad)) {
            return false;
        }
        AdaGrad that = (AdaGrad)o;
        if (!that.canEqual(this)) {
            return false;
        }
        INDArray mine = this.getHistoricalGradient();
        INDArray theirs = that.getHistoricalGradient();
        boolean sameHistory = mine == null ? theirs == null : mine.equals(theirs);
        return sameHistory
                && Arrays.equals(this.getShape(), that.getShape())
                && Double.compare(this.getLearningRate(), that.getLearningRate()) == 0
                && this.getNumIterations() == that.getNumIterations()
                && Double.compare(this.getEpsilon(), that.getEpsilon()) == 0
                && this.getGradientReshapeOrder() == that.getGradientReshapeOrder();
    }

    /**
     * Hook that lets subclasses refuse equality with instances of other types.
     *
     * @param other candidate object
     * @return {@code true} if {@code other} is an {@code AdaGrad}
     */
    protected boolean canEqual(Object other) {
        return other instanceof AdaGrad;
    }

    /**
     * Hashes the same six values that {@link #equals(Object)} compares.
     *
     * @return the combined hash
     */
    @Override
    public int hashCode() {
        final int PRIME = 59;
        int result = 1;
        INDArray history = this.getHistoricalGradient();
        result = result * PRIME + (history == null ? 43 : history.hashCode());
        result = result * PRIME + Arrays.hashCode(this.getShape());
        long learningRateBits = Double.doubleToLongBits(this.getLearningRate());
        result = result * PRIME + (int)(learningRateBits >>> 32 ^ learningRateBits);
        result = result * PRIME + this.getNumIterations();
        long epsilonBits = Double.doubleToLongBits(this.getEpsilon());
        result = result * PRIME + (int)(epsilonBits >>> 32 ^ epsilonBits);
        result = result * PRIME + this.getGradientReshapeOrder();
        return result;
    }

    /** @return a string listing all six state values */
    @Override
    public String toString() {
        return "AdaGrad(historicalGradient=" + this.getHistoricalGradient()
                + ", shape=" + Arrays.toString(this.getShape())
                + ", learningRate=" + this.getLearningRate()
                + ", numIterations=" + this.getNumIterations()
                + ", epsilon=" + this.getEpsilon()
                + ", gradientReshapeOrder=" + this.getGradientReshapeOrder() + ")";
    }
}

