/*
 * Decompiled with CFR 0.152.
 */
package org.jetbrains.kotlinx.dl.api.core.optimizer;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import kotlin.Metadata;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.kotlinx.dl.api.core.KGraph;
import org.jetbrains.kotlinx.dl.api.core.optimizer.ClipGradientAction;
import org.jetbrains.kotlinx.dl.api.core.optimizer.Optimizer;
import org.jetbrains.kotlinx.dl.api.core.util.DtypeConversionUtilKt;
import org.jetbrains.kotlinx.dl.api.core.util.NameConventionsKt;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Gradients;
import org.tensorflow.op.core.Variable;

@Metadata(mv={1, 8, 0}, k=1, xi=48, d1={"\u0000`\n\u0002\u0018\u0002\n\u0002\u0010\u0000\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u000b\n\u0002\b\u0003\n\u0002\u0010\u000e\n\u0002\b\u0003\n\u0002\u0010%\n\u0002\u0018\u0002\n\u0002\u0010\u0007\n\u0000\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0002\n\u0002\b\t\b&\u0018\u00002\u00020\u0001B\r\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u00a2\u0006\u0002\u0010\u0004J@\u0010\u0013\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120\u00150\u00142\u0006\u0010\u0016\u001a\u00020\u00172\u0006\u0010\u0018\u001a\u00020\u00192\u0012\u0010\u001a\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120\u00110\u00142\u0006\u0010\u001b\u001a\u00020\u001cH$J2\u0010\u001d\u001a\u00020\u001c2\u0006\u0010\u0018\u001a\u00020\u00192\f\u0010\u001e\u001a\b\u0012\u0004\u0012\u00020\u00120\u00152\u0012\u0010\u001a\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120\u00110\u0014H\u0002J#\u0010\u001f\u001a\u00020\f2\f\u0010 \u001a\b\u0012\u0004\u0012\u00020\u00120!2\u0006\u0010\"\u001a\u00020\fH\u0010\u00a2\u0006\u0002\b#J<\u0010$\u001a\u00020%2\u0006\u0010\u0016\u001a\u00020\u00172\u0006\u0010\u0018\u001a\u00020\u00192\f\u0010 \u001a\b\u0012\u0004\u0012\u00020\u00120!2\u0006\u0010\"\u001a\u00020\f2\f\u0010&\u001a\b\u0012\u0004\u0012\u00020\u00120\u0015H\u0014J,\u0010'\u001a\u00020%2\u0006\u0010\u0016\u001a\u00020\u00172\u0006\u0010\u0018\u001a\u00020\u00192\u0012\u0010(\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120!0\u0014H\u0014J\u001e\u0010)\u001a\b\u0012\u0004\u0012\u00020\u00120\u00112\u0006\u0010*\u001a\u00020\f2\u0006\u0010\"\u001a\u00020\fH\u0004JK\u0010+\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120\u00150\u00142\u0006\u0010\u0016\u001a\u00020\u00172\u0012\u0010\u001a\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120\u00110\u00142\u0006\u0010\u0018\u001a\u00020\u00192\f\u0010\u001e\u001a\b\u0012\u0004\u0012\u00020\u00120\u0015H\u0000\u00a2\u0006\u0002\b,J(\u0010-\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120!0\u00142\u0012\u0010(\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120\u00110\u0014H\u0002R\u0011\u0010\u0002\u001a\u00020\u0003\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0005\u0010\u0006R\u0012\u0010\u0007\u001a\u00020\bX\u00a0\u0004\u00a2\u0006\u0006\u001a\u0004\b\t\u0010\nR\u0012\u0010\u000b\u001a\u00020\fX\u00a6\u0004\u00a2\u0006\u0006\u001a\u0004\b\r\u0010\u000eR,\u0010\u000f\u001a \u0012\u0004\u0012\u00020\f\u0012\u0016\u0012\u0014\u0012\u0004\u0012\u00020\f\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120\u00110\u00100\u0010X\u0082.\u00a2\u0006\u0002\n\u0000\u00a8\u0006."}, d2={"Lorg/jetbrains/kotlinx/dl/api/core/optimizer/Optimizer;", "", "clipGradient", "Lorg/jetbrains/kotlinx/dl/api/core/optimizer/ClipGradientAction;", "(Lorg/jetbrains/kotlinx/dl/api/core/optimizer/ClipGradientAction;)V", "getClipGradient", "()Lorg/jetbrains/kotlinx/dl/api/core/optimizer/ClipGradientAction;", "isRunningOnGPU", "", "isRunningOnGPU$tensorflow", "()Z", "optimizerName", "", "getOptimizerName", "()Ljava/lang/String;", "slots", "", "Lorg/tensorflow/op/core/Variable;", "", "applyGradients", "", "Lorg/tensorflow/Operand;", "graph", "Lorg/jetbrains/kotlinx/dl/api/core/KGraph;", "tf", "Lorg/tensorflow/op/Ops;", "weights", "gradients", "Lorg/tensorflow/op/core/Gradients;", "computeGradients", "loss", "createName", "variable", "Lorg/tensorflow/Output;", "slotName", "createName$tensorflow", "createSlot", "", "initializer", "createSlots", "variables", "getSlot", "varName", "prepareTargets", "prepareTargets$tensorflow", "variablesToOutputs", "tensorflow"})
public abstract class Optimizer {
    @NotNull
    private final ClipGradientAction clipGradient;
    private Map<String, Map<String, Variable<Float>>> slots;

    public Optimizer(@NotNull ClipGradientAction clipGradient) {
        Intrinsics.checkNotNullParameter((Object)clipGradient, (String)"clipGradient");
        this.clipGradient = clipGradient;
    }

    @NotNull
    public final ClipGradientAction getClipGradient() {
        return this.clipGradient;
    }

    @NotNull
    public final List<Operand<Float>> prepareTargets$tensorflow(@NotNull KGraph graph, @NotNull List<Variable<Float>> weights, @NotNull Ops tf, @NotNull Operand<Float> loss) {
        Intrinsics.checkNotNullParameter((Object)graph, (String)"graph");
        Intrinsics.checkNotNullParameter(weights, (String)"weights");
        Intrinsics.checkNotNullParameter((Object)tf, (String)"tf");
        Intrinsics.checkNotNullParameter(loss, (String)"loss");
        this.slots = new LinkedHashMap();
        Gradients gradients = this.computeGradients(tf, loss, weights);
        List<Output<Float>> variableOutputs = this.variablesToOutputs(weights);
        this.createSlots(graph, tf, variableOutputs);
        return this.applyGradients(graph, tf, weights, gradients);
    }

    private final List<Output<Float>> variablesToOutputs(List<Variable<Float>> variables2) {
        List variableOutputs = new ArrayList();
        int n = variables2.size();
        for (int i = 0; i < n; ++i) {
            Output output = variables2.get(i).asOutput();
            Intrinsics.checkNotNullExpressionValue((Object)output, (String)"variables[i].asOutput()");
            variableOutputs.add(i, output);
        }
        return variableOutputs;
    }

    @NotNull
    protected abstract List<Operand<Float>> applyGradients(@NotNull KGraph var1, @NotNull Ops var2, @NotNull List<Variable<Float>> var3, @NotNull Gradients var4);

    private final Gradients computeGradients(Ops tf, Operand<Float> loss, List<Variable<Float>> weights) {
        Gradients gradients = tf.gradients(loss, (Iterable)weights, new Gradients.Options[0]);
        Intrinsics.checkNotNullExpressionValue((Object)gradients, (String)"tf.gradients(loss, weights)");
        return gradients;
    }

    protected void createSlots(@NotNull KGraph graph, @NotNull Ops tf, @NotNull List<Output<Float>> variables2) {
        Intrinsics.checkNotNullParameter((Object)graph, (String)"graph");
        Intrinsics.checkNotNullParameter((Object)tf, (String)"tf");
        Intrinsics.checkNotNullParameter(variables2, (String)"variables");
    }

    @NotNull
    public abstract String getOptimizerName();

    protected void createSlot(@NotNull KGraph graph, @NotNull Ops tf, @NotNull Output<Float> variable, @NotNull String slotName, @NotNull Operand<Float> initializer) {
        Intrinsics.checkNotNullParameter((Object)graph, (String)"graph");
        Intrinsics.checkNotNullParameter((Object)tf, (String)"tf");
        Intrinsics.checkNotNullParameter(variable, (String)"variable");
        Intrinsics.checkNotNullParameter((Object)slotName, (String)"slotName");
        Intrinsics.checkNotNullParameter(initializer, (String)"initializer");
        String createName = this.createName$tensorflow(variable, slotName);
        Variable variable2 = tf.withName(createName).variable(variable.shape(), DtypeConversionUtilKt.getDType(), new Variable.Options[0]);
        Intrinsics.checkNotNullExpressionValue((Object)variable2, (String)"tf.withName(createName).\u2026able.shape(), getDType())");
        Variable slot = variable2;
        String assignName = NameConventionsKt.defaultAssignOpName(this.createName$tensorflow(variable, slotName));
        Assign assign = tf.withName(assignName).assign((Operand)slot, initializer, new Assign.Options[0]);
        Intrinsics.checkNotNullExpressionValue((Object)assign, (String)"tf.withName(assignName).assign(slot, initializer)");
        Assign slotInit = assign;
        graph.addOptimizerVariableInitializer(slotInit);
        graph.addOptimizerVariable((Variable<Float>)slot);
        String varName = variable.op().name();
        Map<String, Map<String, Variable<Float>>> map = this.slots;
        if (map == null) {
            Intrinsics.throwUninitializedPropertyAccessException((String)"slots");
            map = null;
        }
        Map map2 = map.computeIfAbsent(slotName, arg_0 -> Optimizer.createSlot$lambda$0(createSlot.variables.1.INSTANCE, arg_0));
        Intrinsics.checkNotNullExpressionValue((Object)map2, (String)"slots.computeIfAbsent(slotName) { mutableMapOf() }");
        Map variables2 = map2;
        Intrinsics.checkNotNullExpressionValue((Object)varName, (String)"varName");
        variables2.put(varName, slot);
    }

    @NotNull
    protected final Variable<Float> getSlot(@NotNull String varName, @NotNull String slotName) {
        Intrinsics.checkNotNullParameter((Object)varName, (String)"varName");
        Intrinsics.checkNotNullParameter((Object)slotName, (String)"slotName");
        Map<String, Map<String, Variable<Float>>> map = this.slots;
        if (map == null) {
            Intrinsics.throwUninitializedPropertyAccessException((String)"slots");
            map = null;
        }
        Map<String, Variable<Float>> map2 = map.get(slotName);
        Intrinsics.checkNotNull(map2);
        Map<String, Variable<Float>> variables2 = map2;
        Variable<Float> variable = variables2.get(varName);
        Intrinsics.checkNotNull(variable);
        return variable;
    }

    @NotNull
    public String createName$tensorflow(@NotNull Output<Float> variable, @NotNull String slotName) {
        Intrinsics.checkNotNullParameter(variable, (String)"variable");
        Intrinsics.checkNotNullParameter((Object)slotName, (String)"slotName");
        return NameConventionsKt.defaultOptimizerVariableName(variable.op().name() + '-' + slotName);
    }

    public abstract boolean isRunningOnGPU$tensorflow();

    private static final Map createSlot$lambda$0(Function1 $tmp0, Object p0) {
        Intrinsics.checkNotNullParameter((Object)$tmp0, (String)"$tmp0");
        return (Map)$tmp0.invoke(p0);
    }
}

