/*
 * Decompiled with CFR 0.152.
 */
package org.jetbrains.kotlinx.dl.api.core.initializer;

import kotlin.Metadata;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.kotlinx.dl.api.core.exception.IdentityDimensionalityException;
import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer;
import org.jetbrains.kotlinx.dl.api.core.util.DtypeConversionUtilKt;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.MatrixSetDiagV2;
import org.tensorflow.op.core.ReduceMin;
import org.tensorflow.op.core.Reshape;
import org.tensorflow.op.core.Tile;
import org.tensorflow.op.core.Zeros;

@Metadata(mv={1, 8, 0}, k=1, xi=48, d1={"\u00000\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0007\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u000e\n\u0002\b\u0002\u0018\u00002\u00020\u0001B\u000f\u0012\b\b\u0002\u0010\u0002\u001a\u00020\u0003\u00a2\u0006\u0002\u0010\u0004J<\u0010\u0007\u001a\b\u0012\u0004\u0012\u00020\u00030\b2\u0006\u0010\t\u001a\u00020\n2\u0006\u0010\u000b\u001a\u00020\n2\u0006\u0010\f\u001a\u00020\r2\f\u0010\u000e\u001a\b\u0012\u0004\u0012\u00020\n0\b2\u0006\u0010\u000f\u001a\u00020\u0010H\u0016J\b\u0010\u0011\u001a\u00020\u0010H\u0016R\u0011\u0010\u0002\u001a\u00020\u0003\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0005\u0010\u0006\u00a8\u0006\u0012"}, d2={"Lorg/jetbrains/kotlinx/dl/api/core/initializer/Identity;", "Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;", "gain", "", "(F)V", "getGain", "()F", "initialize", "Lorg/tensorflow/Operand;", "fanIn", "", "fanOut", "tf", "Lorg/tensorflow/op/Ops;", "shape", "name", "", "toString", "tensorflow"})
/**
 * Initializer that produces a 2-D tensor of zeros whose main diagonal is filled with {@link #getGain()}.
 *
 * <p>Only rank-2 shapes are accepted; any other rank is rejected with
 * {@link IdentityDimensionalityException}.
 */
public final class Identity
extends Initializer {
    /** Value written to every element of the main diagonal. */
    private final float gain;

    /**
     * Creates an identity initializer.
     *
     * @param gain value placed on each element of the main diagonal
     */
    public Identity(float gain) {
        this.gain = gain;
    }

    /**
     * Kotlin default-argument bridge constructor. When bit 0 of {@code n} is set, the caller omitted
     * {@code gain} and the default {@code 1.0f} is used instead of {@code f}.
     *
     * @param f explicit gain (ignored when the default bit is set)
     * @param n bitmask of omitted arguments
     * @param defaultConstructorMarker unused marker that disambiguates this synthetic signature
     */
    public /* synthetic */ Identity(float f, int n, DefaultConstructorMarker defaultConstructorMarker) {
        // The delegating constructor call must be the first statement (JLS 8.8.7), so the
        // default is chosen inline instead of reassigning the parameter beforehand.
        this((n & 1) != 0 ? 1.0f : f);
    }

    /**
     * Returns the value used on the main diagonal.
     *
     * @return the diagonal gain
     */
    public final float getGain() {
        return this.gain;
    }

    /**
     * Builds a zeros tensor of the requested {@code shape} and sets its main diagonal to {@code gain}.
     *
     * @param fanIn not used by this initializer
     * @param fanOut not used by this initializer
     * @param tf TensorFlow op builder
     * @param shape 1-D operand holding the target dimensions; its static length must be 2
     * @param name name applied (via {@code tf.withName}) to the intermediate zeros op
     * @return operand containing the zeros tensor with {@code gain} on its main diagonal
     * @throws IdentityDimensionalityException if {@code shape} does not describe a rank-2 tensor
     */
    @Override
    @NotNull
    public Operand<Float> initialize(int fanIn, int fanOut, @NotNull Ops tf, @NotNull Operand<Integer> shape, @NotNull String name) {
        Intrinsics.checkNotNullParameter((Object)tf, (String)"tf");
        Intrinsics.checkNotNullParameter(shape, (String)"shape");
        Intrinsics.checkNotNullParameter((Object)name, (String)"name");
        // The static length of the shape vector is the rank of the tensor being initialized.
        long dimensions = shape.asOutput().shape().size(0);
        if (dimensions != 2L) {
            throw new IdentityDimensionalityException(dimensions);
        }
        // Diagonal length = min(rows, cols), reshaped to a 1-element vector so it can act as
        // the multiples argument of tile.
        ReduceMin<Integer> minSize = tf.reduceMin(shape, tf.constant(0));
        Reshape<Integer> diagLength = tf.reshape(minSize, tf.constant(new int[]{1}));
        Tile<Float> diag = tf.tile(tf.constant(new float[]{this.gain}), diagLength);
        // NOTE(review): the requested name is attached to the zeros op, not to the returned
        // MatrixSetDiagV2 op; kept as-is because graph lookups may depend on it.
        Zeros<Float> zeros = tf.withName(name).zeros(shape, DtypeConversionUtilKt.getDType());
        // k = 0 selects the main diagonal.
        MatrixSetDiagV2<Float> result = tf.matrixSetDiagV2(zeros, diag, tf.constant(0));
        Intrinsics.checkNotNullExpressionValue((Object)result, (String)"tf.matrixSetDiagV2(zeros, diag, tf.constant(0))");
        return result;
    }

    /**
     * Returns a short description of this initializer.
     *
     * @return description including the diagonal gain
     */
    @NotNull
    public String toString() {
        // NOTE(review): the label reads "scale" although the property is named gain; kept for
        // output compatibility.
        return "Identity(scale=" + this.gain + ')';
    }

    /** Creates an identity initializer with the default gain of {@code 1.0f}. */
    public Identity() {
        this(0.0f, 1, null);
    }
}

