/*
 * Decompiled with CFR 0.152.
 */
package org.jetbrains.kotlinx.dl.api.core.optimizer;

import java.util.ArrayList;
import java.util.List;
import kotlin.Metadata;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SourceDebugExtension;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.kotlinx.dl.api.core.KGraph;
import org.jetbrains.kotlinx.dl.api.core.optimizer.ClipGradientAction;
import org.jetbrains.kotlinx.dl.api.core.optimizer.NoClipGradient;
import org.jetbrains.kotlinx.dl.api.core.optimizer.Optimizer;
import org.jetbrains.kotlinx.dl.api.core.util.DtypeConversionUtilKt;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Ops;
import org.tensorflow.op.TrainOps;
import org.tensorflow.op.core.Gradients;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyGradientDescent;

@Metadata(mv={1, 8, 0}, k=1, xi=48, d1={"\u0000J\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0007\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u000b\n\u0002\b\u0007\n\u0002\u0010\u000e\n\u0002\b\u0003\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\u0018\u00002\u00020\u0001B\u0019\u0012\b\b\u0002\u0010\u0002\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0004\u001a\u00020\u0005\u00a2\u0006\u0002\u0010\u0006J@\u0010\u0013\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030\u00150\u00142\u0006\u0010\u0016\u001a\u00020\u00172\u0006\u0010\u0018\u001a\u00020\u00192\u0012\u0010\u001a\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030\u001b0\u00142\u0006\u0010\u001c\u001a\u00020\u001dH\u0014R\u0014\u0010\u0007\u001a\u00020\b8PX\u0090\u0004\u00a2\u0006\u0006\u001a\u0004\b\t\u0010\nR\u001a\u0010\u0002\u001a\u00020\u0003X\u0086\u000e\u00a2\u0006\u000e\n\u0000\u001a\u0004\b\u000b\u0010\f\"\u0004\b\r\u0010\u000eR\u0014\u0010\u000f\u001a\u00020\u00108VX\u0096\u0004\u00a2\u0006\u0006\u001a\u0004\b\u0011\u0010\u0012\u00a8\u0006\u001e"}, d2={"Lorg/jetbrains/kotlinx/dl/api/core/optimizer/SGD;", "Lorg/jetbrains/kotlinx/dl/api/core/optimizer/Optimizer;", "learningRate", "", "clipGradient", "Lorg/jetbrains/kotlinx/dl/api/core/optimizer/ClipGradientAction;", "(FLorg/jetbrains/kotlinx/dl/api/core/optimizer/ClipGradientAction;)V", "isRunningOnGPU", "", "isRunningOnGPU$tensorflow", "()Z", "getLearningRate", "()F", "setLearningRate", "(F)V", "optimizerName", "", "getOptimizerName", "()Ljava/lang/String;", "applyGradients", "", "Lorg/tensorflow/Operand;", "graph", "Lorg/jetbrains/kotlinx/dl/api/core/KGraph;", "tf", "Lorg/tensorflow/op/Ops;", "weights", "Lorg/tensorflow/op/core/Variable;", "gradients", "Lorg/tensorflow/op/core/Gradients;", "tensorflow"})
@SourceDebugExtension(value={"SMAP\nSGD.kt\nKotlin\n*S Kotlin\n*F\n+ 1 SGD.kt\norg/jetbrains/kotlinx/dl/api/core/optimizer/SGD\n+ 2 fake.kt\nkotlin/jvm/internal/FakeKt\n*L\n1#1,57:1\n1#2:58\n*E\n"})
/**
 * Stochastic gradient descent optimizer.
 *
 * <p>For every trainable weight, {@link #applyGradients} emits one TensorFlow
 * {@code ApplyGradientDescent} op (with {@code useLocking(true)}) whose step size is the current
 * {@link #getLearningRate() learning rate} and whose delta is the weight's gradient after it has
 * been passed through the configured {@link ClipGradientAction}.
 *
 * <p>The learning rate must be {@code >= 0.0} (NaN is rejected). The check runs both at
 * construction and in {@link #setLearningRate(float)}.
 */
public final class SGD
extends Optimizer {
    private float learningRate;

    /**
     * Creates an SGD optimizer.
     *
     * @param learningRate step size; must be {@code >= 0.0}
     * @param clipGradient gradient clipping strategy applied to each gradient before the update
     * @throws IllegalArgumentException if {@code learningRate} is negative or NaN
     * @throws NullPointerException if {@code clipGradient} is {@code null}
     */
    public SGD(float learningRate, @NotNull ClipGradientAction clipGradient) {
        // The null check must happen before the superclass sees the argument, and Java requires
        // super(...) to be the first statement, so the check is routed through a static helper.
        super(SGD.checkClipGradient(clipGradient));
        this.learningRate = SGD.requireValidLearningRate(learningRate);
    }

    /**
     * Kotlin default-arguments bridge. Bit 0 of {@code mask} substitutes the default learning rate
     * {@code 0.2f}. Bit 1 substitutes a new {@link NoClipGradient}. The marker parameter only
     * disambiguates the signature and is ignored.
     */
    public /* synthetic */ SGD(float learningRate, ClipGradientAction clipGradient, int mask, DefaultConstructorMarker marker) {
        this((mask & 1) != 0 ? 0.2f : learningRate,
             (mask & 2) != 0 ? new NoClipGradient() : clipGradient);
    }

    /** Creates an SGD optimizer with learning rate {@code 0.2f} and no gradient clipping. */
    public SGD() {
        this(0.0f, null, 3, null);
    }

    /** Null-checks the clip action so the check can run inside the {@code super(...)} call. */
    private static ClipGradientAction checkClipGradient(ClipGradientAction clipGradient) {
        Intrinsics.checkNotNullParameter((Object)clipGradient, (String)"clipGradient");
        return clipGradient;
    }

    /** Returns {@code learningRate} unchanged if it is valid, otherwise throws. */
    private static float requireValidLearningRate(float learningRate) {
        // Written as !(x >= 0) rather than x < 0 so that NaN is rejected too.
        if (!(learningRate >= 0.0f)) {
            throw new IllegalArgumentException("Learning rate " + learningRate + " should be >= 0.0.");
        }
        return learningRate;
    }

    /** Returns the current learning rate. */
    public final float getLearningRate() {
        return this.learningRate;
    }

    /**
     * Updates the learning rate. This only affects ops built by later {@link #applyGradients}
     * calls, because the value is captured as a graph constant when the ops are built.
     *
     * @param learningRate new step size; must be {@code >= 0.0}
     * @throws IllegalArgumentException if {@code learningRate} is negative or NaN
     */
    public final void setLearningRate(float learningRate) {
        this.learningRate = SGD.requireValidLearningRate(learningRate);
    }

    /**
     * Builds one gradient-descent update op per weight.
     *
     * @param graph the model graph (only null-checked here)
     * @param tf op factory used to create the constant, clipping and update ops
     * @param weights trainable variables; weight {@code i} is paired with {@code gradients.dy(i)}
     * @param gradients gradients of the loss with respect to {@code weights}
     * @return the update ops, in the same order as {@code weights}
     */
    @Override
    @NotNull
    protected List<Operand<Float>> applyGradients(@NotNull KGraph graph, @NotNull Ops tf, @NotNull List<Variable<Float>> weights, @NotNull Gradients gradients) {
        Intrinsics.checkNotNullParameter((Object)graph, (String)"graph");
        Intrinsics.checkNotNullParameter((Object)tf, (String)"tf");
        Intrinsics.checkNotNullParameter(weights, (String)"weights");
        Intrinsics.checkNotNullParameter((Object)gradients, (String)"gradients");
        int weightCount = weights.size();
        List<Operand<Float>> targets = new ArrayList<>(weightCount);
        if (weightCount == 0) {
            // Nothing to update; avoid adding an unused constant to the graph.
            return targets;
        }
        // Loop-invariant: one shared learning-rate constant, clip action and option set for all
        // weights. The current learningRate value is captured here.
        Operand<Float> alpha = tf.constant(Float.valueOf(this.learningRate), DtypeConversionUtilKt.getDType());
        ClipGradientAction clipAction = this.getClipGradient();
        ApplyGradientDescent.Options[] options = {ApplyGradientDescent.useLocking((Boolean)true)};
        for (int i = 0; i < weightCount; ++i) {
            Output<Float> gradient = gradients.dy(i);
            Intrinsics.checkNotNullExpressionValue((Object)gradient, (String)"gradients.dy(i)");
            ApplyGradientDescent<Float> update = tf.train.applyGradientDescent(
                    weights.get(i), alpha, clipAction.clipGradient(tf, gradient), options);
            Intrinsics.checkNotNullExpressionValue((Object)update, (String)"tf.train.applyGradientDe\u2026g(true)\n                )");
            targets.add(update);
        }
        return targets;
    }

    /** Returns the display name of this optimizer. */
    @Override
    @NotNull
    public String getOptimizerName() {
        return "SGD";
    }

    /**
     * Reports that this optimizer supports GPU execution. The {@code $tensorflow} suffix comes
     * from Kotlin's mangling of module-internal member names.
     */
    @Override
    public boolean isRunningOnGPU$tensorflow() {
        return true;
    }
}

