/*
 * Decompiled with CFR 0.152.
 */
package org.jetbrains.kotlinx.dl.api.core.optimizer;

import java.util.ArrayList;
import java.util.List;
import kotlin.Metadata;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SourceDebugExtension;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.kotlinx.dl.api.core.KGraph;
import org.jetbrains.kotlinx.dl.api.core.optimizer.AdamKt;
import org.jetbrains.kotlinx.dl.api.core.optimizer.ClipGradientAction;
import org.jetbrains.kotlinx.dl.api.core.optimizer.NoClipGradient;
import org.jetbrains.kotlinx.dl.api.core.optimizer.Optimizer;
import org.jetbrains.kotlinx.dl.api.core.util.DtypeConversionUtilKt;
import org.jetbrains.kotlinx.dl.api.core.util.NameConventionsKt;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.Shape;
import org.tensorflow.op.MathOps;
import org.tensorflow.op.Ops;
import org.tensorflow.op.TrainOps;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Fill;
import org.tensorflow.op.core.Gradients;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyAdam;

@Metadata(mv={1, 8, 0}, k=1, xi=48, d1={"\u0000b\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0007\n\u0002\b\u0004\n\u0002\u0010\u000b\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\n\n\u0002\u0010\u000e\n\u0002\b\u0004\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0003\u0018\u00002\u00020\u0001BA\u0012\b\b\u0002\u0010\u0002\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0004\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0005\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0006\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0007\u001a\u00020\b\u0012\b\b\u0002\u0010\t\u001a\u00020\n\u00a2\u0006\u0002\u0010\u000bJ@\u0010!\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030#0\"2\u0006\u0010$\u001a\u00020%2\u0006\u0010&\u001a\u00020'2\u0012\u0010(\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030\u00120\"2\u0006\u0010)\u001a\u00020*H\u0014J&\u0010+\u001a\u00020,2\u0006\u0010$\u001a\u00020%2\u0006\u0010&\u001a\u00020'2\f\u0010-\u001a\b\u0012\u0004\u0012\u00020\u00030.H\u0002J,\u0010/\u001a\u00020,2\u0006\u0010$\u001a\u00020%2\u0006\u0010&\u001a\u00020'2\u0012\u00100\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030.0\"H\u0014R\u0011\u0010\u0004\u001a\u00020\u0003\u00a2\u0006\b\n\u0000\u001a\u0004\b\f\u0010\rR\u0011\u0010\u0005\u001a\u00020\u0003\u00a2\u0006\b\n\u0000\u001a\u0004\b\u000e\u0010\rR\u0014\u0010\u000f\u001a\b\u0012\u0004\u0012\u00020\u00030\u0010X\u0082.\u00a2\u0006\u0002\n\u0000R\u0014\u0010\u0011\u001a\b\u0012\u0004\u0012\u00020\u00030\u0012X\u0082.\u00a2\u0006\u0002\n\u0000R\u0014\u0010\u0013\u001a\b\u0012\u0004\u0012\u00020\u00030\u0010X\u0082.\u00a2\u0006\u0002\n\u0000R\u0014\u0010\u0014\u001a\b\u0012\u0004\u0012\u00020\u00030\u0012X\u0082.\u00a2\u0006\u0002\n\u0000R\u0011\u0010\u0006\u001a\u00020\u0003\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0015\u0010\rR\u0014\u0010\u0016\u001a\b\u0012\u0004\u0012\u00020\u00030\u0010X\u0082.\u00a2\u0006\u0002\n\u0000R\u0014\u0010\u0017\u001a\u00020\b8PX\u0090\u0004\u00a2\u0006\u0006\u001a\u0004\b\u0018\u0010\u0019R\u0011\u0010\u0002\u001a\u00020\u0003\u00a2\u0006\b\n\u0000\u001a\u0004\b\u001a\u0010\rR\u0014\u0010\u001b\u001a\b\u0012\u0004\u0012\u00020\u00030\u0010X\u0082.\u00a2\u0006\u0002\n\u0000R\u0014\u0010\u001c\u001a\u00020\u001d8VX\u0096\u0004\u00a2\u0006\u0006\u001a\u0004\b\u001e\u0010\u001fR\u0011\u0010\u0007\u001a\u00020\b\u00a2\u0006\b\n\u0000\u001a\u0004\b \u0010\u0019\u00a8\u00061"}, d2={"Lorg/jetbrains/kotlinx/dl/api/core/optimizer/Adam;", "Lorg/jetbrains/kotlinx/dl/api/core/optimizer/Optimizer;", "learningRate", "", "beta1", "beta2", "epsilon", "useNesterov", "", "clipGradient", "Lorg/jetbrains/kotlinx/dl/api/core/optimizer/ClipGradientAction;", "(FFFFZLorg/jetbrains/kotlinx/dl/api/core/optimizer/ClipGradientAction;)V", "getBeta1", "()F", "getBeta2", "betaOneConst", "Lorg/tensorflow/op/core/Constant;", "betaOnePower", "Lorg/tensorflow/op/core/Variable;", "betaTwoConst", "betaTwoPower", "getEpsilon", "epsilonConstant", "isRunningOnGPU", "isRunningOnGPU$tensorflow", "()Z", "getLearningRate", "learningRateConst", "optimizerName", "", "getOptimizerName", "()Ljava/lang/String;", "getUseNesterov", "applyGradients", "", "Lorg/tensorflow/Operand;", "graph", "Lorg/jetbrains/kotlinx/dl/api/core/KGraph;", "tf", "Lorg/tensorflow/op/Ops;", "weights", "gradients", "Lorg/tensorflow/op/core/Gradients;", "createAdamSlot", "", "v", "Lorg/tensorflow/Output;", "createSlots", "variables", "tensorflow"})
@SourceDebugExtension(value={"SMAP\nAdam.kt\nKotlin\n*S Kotlin\n*F\n+ 1 Adam.kt\norg/jetbrains/kotlinx/dl/api/core/optimizer/Adam\n+ 2 fake.kt\nkotlin/jvm/internal/FakeKt\n*L\n1#1,167:1\n1#2:168\n*E\n"})
public final class Adam
extends Optimizer {
    private final float learningRate;
    private final float beta1;
    private final float beta2;
    private final float epsilon;
    private final boolean useNesterov;
    private Constant<Float> epsilonConstant;
    private Constant<Float> learningRateConst;
    private Constant<Float> betaOneConst;
    private Constant<Float> betaTwoConst;
    private Variable<Float> betaOnePower;
    private Variable<Float> betaTwoPower;

    public Adam(float learningRate, float beta1, float beta2, float epsilon, boolean useNesterov, @NotNull ClipGradientAction clipGradient) {
        Intrinsics.checkNotNullParameter((Object)clipGradient, (String)"clipGradient");
        super(clipGradient);
        this.learningRate = learningRate;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.epsilon = epsilon;
        this.useNesterov = useNesterov;
        if (!(this.learningRate >= 0.0f)) {
            boolean $i$a$-require-Adam$52 = false;
            String $i$a$-require-Adam$52 = "Learning rate " + this.learningRate + " should be >= 0.0.";
            throw new IllegalArgumentException($i$a$-require-Adam$52.toString());
        }
        if (!(this.beta1 > 0.0f && this.beta1 < 1.0f)) {
            boolean $i$a$-require-Adam$62 = false;
            String $i$a$-require-Adam$62 = "Beta1 " + this.beta1 + " should be in range (0.0; 1.0).";
            throw new IllegalArgumentException($i$a$-require-Adam$62.toString());
        }
        if (!(this.beta2 > 0.0f && this.beta2 < 1.0f)) {
            boolean $i$a$-require-Adam$72 = false;
            String $i$a$-require-Adam$72 = "Beta2 " + this.beta2 + " should be in range (0.0; 1.0).";
            throw new IllegalArgumentException($i$a$-require-Adam$72.toString());
        }
        if (!(this.epsilon >= 0.0f)) {
            boolean bl = false;
            String string = "L2Strength " + this.epsilon + " should be >= 0.0.";
            throw new IllegalArgumentException(string.toString());
        }
    }

    public /* synthetic */ Adam(float f, float f2, float f3, float f4, boolean bl, ClipGradientAction clipGradientAction, int n, DefaultConstructorMarker defaultConstructorMarker) {
        if ((n & 1) != 0) {
            f = 0.001f;
        }
        if ((n & 2) != 0) {
            f2 = 0.9f;
        }
        if ((n & 4) != 0) {
            f3 = 0.999f;
        }
        if ((n & 8) != 0) {
            f4 = 1.0E-7f;
        }
        if ((n & 0x10) != 0) {
            bl = false;
        }
        if ((n & 0x20) != 0) {
            clipGradientAction = new NoClipGradient();
        }
        this(f, f2, f3, f4, bl, clipGradientAction);
    }

    public final float getLearningRate() {
        return this.learningRate;
    }

    public final float getBeta1() {
        return this.beta1;
    }

    public final float getBeta2() {
        return this.beta2;
    }

    public final float getEpsilon() {
        return this.epsilon;
    }

    public final boolean getUseNesterov() {
        return this.useNesterov;
    }

    @Override
    @NotNull
    protected List<Operand<Float>> applyGradients(@NotNull KGraph graph, @NotNull Ops tf, @NotNull List<Variable<Float>> weights, @NotNull Gradients gradients) {
        Intrinsics.checkNotNullParameter((Object)graph, (String)"graph");
        Intrinsics.checkNotNullParameter((Object)tf, (String)"tf");
        Intrinsics.checkNotNullParameter(weights, (String)"weights");
        Intrinsics.checkNotNullParameter((Object)gradients, (String)"gradients");
        List targets = new ArrayList();
        Constant constant = tf.constant((Object)Float.valueOf(this.beta1), DtypeConversionUtilKt.getDType());
        Intrinsics.checkNotNullExpressionValue((Object)constant, (String)"tf.constant(beta1, getDType())");
        this.betaOneConst = constant;
        Constant constant2 = tf.constant((Object)Float.valueOf(this.beta2), DtypeConversionUtilKt.getDType());
        Intrinsics.checkNotNullExpressionValue((Object)constant2, (String)"tf.constant(beta2, getDType())");
        this.betaTwoConst = constant2;
        Constant constant3 = tf.constant((Object)Float.valueOf(this.learningRate), DtypeConversionUtilKt.getDType());
        Intrinsics.checkNotNullExpressionValue((Object)constant3, (String)"tf.constant(learningRate, getDType())");
        this.learningRateConst = constant3;
        Constant constant4 = tf.constant((Object)Float.valueOf(this.epsilon), DtypeConversionUtilKt.getDType());
        Intrinsics.checkNotNullExpressionValue((Object)constant4, (String)"tf.constant(epsilon, getDType())");
        this.epsilonConstant = constant4;
        int n = weights.size();
        for (int i = 0; i < n; ++i) {
            Variable<Float> variable = weights.get(i);
            String varName = variable.ref().op().name();
            Intrinsics.checkNotNullExpressionValue((Object)varName, (String)"varName");
            Variable<Float> firstMomentSlot = this.getSlot(varName, "m");
            Variable<Float> secondMomentSlot = this.getSlot(varName, "v");
            TrainOps trainOps = tf.train;
            Operand operand = (Operand)variable;
            Operand operand2 = (Operand)firstMomentSlot;
            Operand operand3 = (Operand)secondMomentSlot;
            Variable<Float> variable2 = this.betaOnePower;
            if (variable2 == null) {
                Intrinsics.throwUninitializedPropertyAccessException((String)"betaOnePower");
                variable2 = null;
            }
            Operand operand4 = (Operand)variable2;
            Variable<Float> variable3 = this.betaTwoPower;
            if (variable3 == null) {
                Intrinsics.throwUninitializedPropertyAccessException((String)"betaTwoPower");
                variable3 = null;
            }
            Operand operand5 = (Operand)variable3;
            Constant<Float> constant5 = this.learningRateConst;
            if (constant5 == null) {
                Intrinsics.throwUninitializedPropertyAccessException((String)"learningRateConst");
                constant5 = null;
            }
            Operand operand6 = (Operand)constant5;
            Constant<Float> constant6 = this.betaOneConst;
            if (constant6 == null) {
                Intrinsics.throwUninitializedPropertyAccessException((String)"betaOneConst");
                constant6 = null;
            }
            Operand operand7 = (Operand)constant6;
            Constant<Float> constant7 = this.betaTwoConst;
            if (constant7 == null) {
                Intrinsics.throwUninitializedPropertyAccessException((String)"betaTwoConst");
                constant7 = null;
            }
            Operand operand8 = (Operand)constant7;
            Constant<Float> constant8 = this.epsilonConstant;
            if (constant8 == null) {
                Intrinsics.throwUninitializedPropertyAccessException((String)"epsilonConstant");
                constant8 = null;
            }
            Operand operand9 = (Operand)constant8;
            ClipGradientAction clipGradientAction = this.getClipGradient();
            Output output = gradients.dy(i);
            Intrinsics.checkNotNullExpressionValue((Object)output, (String)"gradients.dy(i)");
            ApplyAdam.Options[] optionsArray = new ApplyAdam.Options[]{ApplyAdam.useNesterov((Boolean)this.useNesterov), ApplyAdam.useLocking((Boolean)true)};
            ApplyAdam applyAdam = trainOps.applyAdam(operand, operand2, operand3, operand4, operand5, operand6, operand7, operand8, operand9, clipGradientAction.clipGradient(tf, (Operand<Float>)((Operand)output)), optionsArray);
            Intrinsics.checkNotNullExpressionValue((Object)applyAdam, (String)"tf.train.applyAdam(\n    \u2026g(true)\n                )");
            targets.add(applyAdam);
        }
        Variable<Float> variable = this.betaOnePower;
        if (variable == null) {
            Intrinsics.throwUninitializedPropertyAccessException((String)"betaOnePower");
            variable = null;
        }
        Operand operand = (Operand)variable;
        MathOps mathOps = tf.math;
        Variable<Float> variable4 = this.betaOnePower;
        if (variable4 == null) {
            Intrinsics.throwUninitializedPropertyAccessException((String)"betaOnePower");
            variable4 = null;
        }
        Operand operand10 = (Operand)variable4;
        Constant<Float> constant9 = this.betaOneConst;
        if (constant9 == null) {
            Intrinsics.throwUninitializedPropertyAccessException((String)"betaOneConst");
            constant9 = null;
        }
        Assign betaOnePowerInit1 = tf.assign(operand, (Operand)mathOps.mul(operand10, (Operand)constant9), new Assign.Options[0]);
        Variable<Float> variable5 = this.betaTwoPower;
        if (variable5 == null) {
            Intrinsics.throwUninitializedPropertyAccessException((String)"betaTwoPower");
            variable5 = null;
        }
        Operand operand11 = (Operand)variable5;
        MathOps mathOps2 = tf.math;
        Variable<Float> variable6 = this.betaTwoPower;
        if (variable6 == null) {
            Intrinsics.throwUninitializedPropertyAccessException((String)"betaTwoPower");
            variable6 = null;
        }
        Operand operand12 = (Operand)variable6;
        Constant<Float> constant10 = this.betaTwoConst;
        if (constant10 == null) {
            Intrinsics.throwUninitializedPropertyAccessException((String)"betaTwoConst");
            constant10 = null;
        }
        Assign betaTwoPowerInit2 = tf.assign(operand11, (Operand)mathOps2.mul(operand12, (Operand)constant10), new Assign.Options[0]);
        Intrinsics.checkNotNullExpressionValue((Object)betaOnePowerInit1, (String)"betaOnePowerInit1");
        graph.addOptimizerVariableInitializer(betaOnePowerInit1);
        Intrinsics.checkNotNullExpressionValue((Object)betaTwoPowerInit2, (String)"betaTwoPowerInit2");
        graph.addOptimizerVariableInitializer(betaTwoPowerInit2);
        Variable<Float> variable7 = this.betaOnePower;
        if (variable7 == null) {
            Intrinsics.throwUninitializedPropertyAccessException((String)"betaOnePower");
            variable7 = null;
        }
        graph.addOptimizerVariable(variable7);
        Variable<Float> variable8 = this.betaTwoPower;
        if (variable8 == null) {
            Intrinsics.throwUninitializedPropertyAccessException((String)"betaTwoPower");
            variable8 = null;
        }
        graph.addOptimizerVariable(variable8);
        return targets;
    }

    private final void createAdamSlot(KGraph graph, Ops tf, Output<Float> v) {
        String firstMomentInitializerName = NameConventionsKt.defaultInitializerOpName(this.createName$tensorflow(v, "m"));
        Fill firstMomentInitializer = tf.withName(firstMomentInitializerName).fill((Operand)tf.shape((Operand)v), (Operand)tf.constant((Object)Float.valueOf(0.0f), DtypeConversionUtilKt.getDType()));
        Output output = v.asOutput();
        Intrinsics.checkNotNullExpressionValue((Object)output, (String)"v.asOutput()");
        Intrinsics.checkNotNullExpressionValue((Object)firstMomentInitializer, (String)"firstMomentInitializer");
        this.createSlot(graph, tf, (Output<Float>)output, "m", (Operand<Float>)((Operand)firstMomentInitializer));
        String secondMomentInitializerName = NameConventionsKt.defaultInitializerOpName(this.createName$tensorflow(v, "v"));
        Fill secondMomentInitializer = tf.withName(secondMomentInitializerName).fill((Operand)tf.shape((Operand)v), (Operand)tf.constant((Object)Float.valueOf(0.0f), DtypeConversionUtilKt.getDType()));
        Output output2 = v.asOutput();
        Intrinsics.checkNotNullExpressionValue((Object)output2, (String)"v.asOutput()");
        Intrinsics.checkNotNullExpressionValue((Object)secondMomentInitializer, (String)"secondMomentInitializer");
        this.createSlot(graph, tf, (Output<Float>)output2, "v", (Operand<Float>)((Operand)secondMomentInitializer));
    }

    @Override
    protected void createSlots(@NotNull KGraph graph, @NotNull Ops tf, @NotNull List<Output<Float>> variables2) {
        Intrinsics.checkNotNullParameter((Object)graph, (String)"graph");
        Intrinsics.checkNotNullParameter((Object)tf, (String)"tf");
        Intrinsics.checkNotNullParameter(variables2, (String)"variables");
        for (Output<Float> v : variables2) {
            Output output = v.asOutput();
            Intrinsics.checkNotNullExpressionValue((Object)output, (String)"v.asOutput()");
            this.createAdamSlot(graph, tf, (Output<Float>)output);
        }
        Variable variable = tf.withName(AdamKt.access$getFIRST_BETA_POWER_NAME$p()).variable(Shape.scalar(), DtypeConversionUtilKt.getDType(), new Variable.Options[0]);
        Intrinsics.checkNotNullExpressionValue((Object)variable, (String)"tf.withName(FIRST_BETA_P\u2026ape.scalar(), getDType())");
        this.betaOnePower = variable;
        String betaOnePowerAssignName = NameConventionsKt.defaultAssignOpName(AdamKt.access$getFIRST_BETA_POWER_NAME$p());
        Ops ops = tf.withName(betaOnePowerAssignName);
        Variable<Float> variable2 = this.betaOnePower;
        if (variable2 == null) {
            Intrinsics.throwUninitializedPropertyAccessException((String)"betaOnePower");
            variable2 = null;
        }
        Assign assign = ops.assign((Operand)variable2, (Operand)tf.withName(NameConventionsKt.defaultInitializerOpName(AdamKt.access$getFIRST_BETA_POWER_NAME$p())).constant((Object)Float.valueOf(this.beta1), DtypeConversionUtilKt.getDType()), new Assign.Options[0]);
        Intrinsics.checkNotNullExpressionValue((Object)assign, (String)"tf.withName(betaOnePower\u2026getDType())\n            )");
        Assign betaOnePowerInit = assign;
        graph.addOptimizerVariableInitializer(betaOnePowerInit);
        Variable variable3 = tf.withName(AdamKt.access$getSECOND_BETA_POWER_NAME$p()).variable(Shape.scalar(), DtypeConversionUtilKt.getDType(), new Variable.Options[0]);
        Intrinsics.checkNotNullExpressionValue((Object)variable3, (String)"tf.withName(SECOND_BETA_\u2026ape.scalar(), getDType())");
        this.betaTwoPower = variable3;
        String betaTwoPowerAssignName = NameConventionsKt.defaultAssignOpName(AdamKt.access$getSECOND_BETA_POWER_NAME$p());
        Ops ops2 = tf.withName(betaTwoPowerAssignName);
        Variable<Float> variable4 = this.betaTwoPower;
        if (variable4 == null) {
            Intrinsics.throwUninitializedPropertyAccessException((String)"betaTwoPower");
            variable4 = null;
        }
        Assign assign2 = ops2.assign((Operand)variable4, (Operand)tf.withName(NameConventionsKt.defaultInitializerOpName(AdamKt.access$getSECOND_BETA_POWER_NAME$p())).constant((Object)Float.valueOf(this.beta2), DtypeConversionUtilKt.getDType()), new Assign.Options[0]);
        Intrinsics.checkNotNullExpressionValue((Object)assign2, (String)"tf.withName(betaTwoPower\u2026getDType())\n            )");
        Assign betaTwoPowerInit = assign2;
        graph.addOptimizerVariableInitializer(betaTwoPowerInit);
    }

    @Override
    @NotNull
    public String getOptimizerName() {
        return "Adam";
    }

    @Override
    public boolean isRunningOnGPU$tensorflow() {
        return true;
    }

    public Adam() {
        this(0.0f, 0.0f, 0.0f, 0.0f, false, null, 63, null);
    }
}

