/*
 * Decompiled with CFR 0.152.
 */
package org.jetbrains.kotlinx.dl.api.core.initializer;

import kotlin.Metadata;
import kotlin.NoWhenBranchMatchedException;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SourceDebugExtension;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.kotlinx.dl.api.core.initializer.Distribution;
import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer;
import org.jetbrains.kotlinx.dl.api.core.initializer.Mode;
import org.jetbrains.kotlinx.dl.api.core.util.DtypeConversionUtilKt;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.math.Mul;
import org.tensorflow.op.random.StatelessRandomNormal;
import org.tensorflow.op.random.StatelessRandomUniform;
import org.tensorflow.op.random.StatelessTruncatedNormal;

@Metadata(mv={1, 8, 0}, k=1, xi=48, d1={"\u0000F\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0006\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\t\n\u0002\b\n\n\u0002\u0018\u0002\n\u0002\u0010\u0007\n\u0000\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u000e\n\u0002\b\u0002\b\u0016\u0018\u00002\u00020\u0001B-\u0012\b\b\u0002\u0010\u0002\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0004\u001a\u00020\u0005\u0012\b\b\u0002\u0010\u0006\u001a\u00020\u0007\u0012\b\b\u0002\u0010\b\u001a\u00020\t\u00a2\u0006\u0002\u0010\nJ<\u0010\u0013\u001a\b\u0012\u0004\u0012\u00020\u00150\u00142\u0006\u0010\u0016\u001a\u00020\u00172\u0006\u0010\u0018\u001a\u00020\u00172\u0006\u0010\u0019\u001a\u00020\u001a2\f\u0010\u001b\u001a\b\u0012\u0004\u0012\u00020\u00170\u00142\u0006\u0010\u001c\u001a\u00020\u001dH\u0016J\b\u0010\u001e\u001a\u00020\u001dH\u0016R\u0011\u0010\u0006\u001a\u00020\u0007\u00a2\u0006\b\n\u0000\u001a\u0004\b\u000b\u0010\fR\u0011\u0010\u0004\u001a\u00020\u0005\u00a2\u0006\b\n\u0000\u001a\u0004\b\r\u0010\u000eR\u0011\u0010\u0002\u001a\u00020\u0003\u00a2\u0006\b\n\u0000\u001a\u0004\b\u000f\u0010\u0010R\u0011\u0010\b\u001a\u00020\t\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0011\u0010\u0012\u00a8\u0006\u001f"}, d2={"Lorg/jetbrains/kotlinx/dl/api/core/initializer/VarianceScaling;", "Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;", "scale", "", "mode", "Lorg/jetbrains/kotlinx/dl/api/core/initializer/Mode;", "distribution", "Lorg/jetbrains/kotlinx/dl/api/core/initializer/Distribution;", "seed", "", "(DLorg/jetbrains/kotlinx/dl/api/core/initializer/Mode;Lorg/jetbrains/kotlinx/dl/api/core/initializer/Distribution;J)V", "getDistribution", "()Lorg/jetbrains/kotlinx/dl/api/core/initializer/Distribution;", "getMode", "()Lorg/jetbrains/kotlinx/dl/api/core/initializer/Mode;", "getScale", "()D", "getSeed", "()J", "initialize", "Lorg/tensorflow/Operand;", "", "fanIn", "", "fanOut", "tf", 
"Lorg/tensorflow/op/Ops;", "shape", "name", "", "toString", "tensorflow"})
@SourceDebugExtension(value={"SMAP\nVarianceScaling.kt\nKotlin\n*S Kotlin\n*F\n+ 1 VarianceScaling.kt\norg/jetbrains/kotlinx/dl/api/core/initializer/VarianceScaling\n+ 2 fake.kt\nkotlin/jvm/internal/FakeKt\n*L\n1#1,239:1\n1#2:240\n*E\n"})
/*
 * Weight initializer whose sample variance adapts to the fan-in / fan-out of the tensor.
 *
 * Raw samples are drawn from a stateless TensorFlow random op seeded with {seed, 0} and are
 * scaled so that their variance is scale / n, where n is selected by `mode`:
 *   FAN_IN  -> max(1, fanIn)
 *   FAN_OUT -> max(1, fanOut)
 *   FAN_AVG -> max(1, (fanIn + fanOut) / 2)
 * and the sampling distribution is selected by `distribution`:
 *   TRUNCATED_NORMAL   -> truncated normal divided by the truncated distribution's own stddev
 *   UNTRUNCATED_NORMAL -> normal with stddev sqrt(scale / n)
 *   UNIFORM            -> uniform on [-limit, limit) with limit = sqrt(3 * scale / n)
 *
 * Defaults (used by the no-arg and the Kotlin default-argument constructors):
 * scale = 1.0, mode = FAN_IN, distribution = TRUNCATED_NORMAL, seed = 12.
 */
public class VarianceScaling
extends Initializer {
    /**
     * Standard deviation of a unit normal truncated to [-2, 2]. This is the range kept by
     * TensorFlow's truncated-normal ops. It equals scipy.stats.truncnorm.std(a=-2, b=2).
     */
    private static final double TRUNCATED_NORMAL_STDDEV = 0.8796256610342398;

    /** Multiplier for the target variance; must be positive (validated in {@link #initialize}). */
    private final double scale;
    /** Selects which fan value the variance is divided by. */
    @NotNull
    private final Mode mode;
    /** Distribution the raw samples are drawn from. */
    @NotNull
    private final Distribution distribution;
    /** First element of the two-element seed passed to the stateless random op. */
    private final long seed;

    /**
     * Creates a variance-scaling initializer.
     *
     * @param scale positive multiplier for the target variance (checked when {@link #initialize} runs)
     * @param mode which fan value(s) the variance is divided by
     * @param distribution distribution the raw samples are drawn from
     * @param seed seed for the stateless random op
     * @throws NullPointerException if {@code mode} or {@code distribution} is null
     */
    public VarianceScaling(double scale, @NotNull Mode mode, @NotNull Distribution distribution, long seed) {
        Intrinsics.checkNotNullParameter((Object) mode, "mode");
        Intrinsics.checkNotNullParameter((Object) distribution, "distribution");
        this.scale = scale;
        this.mode = mode;
        this.distribution = distribution;
        this.seed = seed;
    }

    /**
     * Kotlin default-argument bridge: each set bit of {@code mask} replaces the matching argument
     * with its default (bit 0: scale=1.0, bit 1: mode=FAN_IN, bit 2: distribution=TRUNCATED_NORMAL,
     * bit 3: seed=12). Kept for binary compatibility with Kotlin callers that omit arguments.
     */
    public /* synthetic */ VarianceScaling(double scale, Mode mode, Distribution distribution, long seed, int mask, DefaultConstructorMarker marker) {
        this((mask & 1) != 0 ? 1.0 : scale,
             (mask & 2) != 0 ? Mode.FAN_IN : mode,
             (mask & 4) != 0 ? Distribution.TRUNCATED_NORMAL : distribution,
             (mask & 8) != 0 ? 12L : seed);
    }

    /** Creates an initializer with all default parameters. */
    public VarianceScaling() {
        this(0.0, null, null, 0L, 15, null);
    }

    /** @return the variance multiplier */
    public final double getScale() {
        return this.scale;
    }

    /** @return the fan mode */
    @NotNull
    public final Mode getMode() {
        return this.mode;
    }

    /** @return the sampling distribution */
    @NotNull
    public final Distribution getDistribution() {
        return this.distribution;
    }

    /** @return the seed used for the stateless random op */
    public final long getSeed() {
        return this.seed;
    }

    /**
     * Builds the graph ops that produce the initial weight values.
     *
     * @param fanIn number of input units of the layer
     * @param fanOut number of output units of the layer
     * @param tf op factory used to create the graph ops
     * @param shape shape of the tensor to initialize
     * @param name name given to the final op of the returned subgraph
     * @return operand producing scaled random values of type {@code DtypeConversionUtilKt.getDType()}
     * @throws IllegalArgumentException if {@code scale} is not strictly positive (including NaN)
     */
    @Override
    @NotNull
    public Operand<Float> initialize(int fanIn, int fanOut, @NotNull Ops tf, @NotNull Operand<Integer> shape, @NotNull String name) {
        Intrinsics.checkNotNullParameter((Object) tf, "tf");
        Intrinsics.checkNotNullParameter(shape, "shape");
        Intrinsics.checkNotNullParameter((Object) name, "name");
        // Negated comparison so NaN is rejected along with zero and negative values.
        if (!(this.scale > 0.0)) {
            throw new IllegalArgumentException("The 'scale' parameter value must be more than 0.0.");
        }
        double variance = this.scale / this.effectiveFan(fanIn, fanOut);
        Class<Float> dtype = DtypeConversionUtilKt.getDType();
        long[] seeds = new long[]{this.seed, 0L};
        switch (this.distribution) {
            case TRUNCATED_NORMAL: {
                Operand<Float> samples = tf.random.statelessTruncatedNormal(shape, tf.constant(seeds), dtype);
                // Dividing by the truncated distribution's own stddev restores the target variance.
                return scaleBy(tf, name, samples, Math.sqrt(variance) / TRUNCATED_NORMAL_STDDEV, dtype);
            }
            case UNTRUNCATED_NORMAL: {
                Operand<Float> samples = tf.random.statelessRandomNormal(shape, tf.constant(seeds), dtype);
                return scaleBy(tf, name, samples, Math.sqrt(variance), dtype);
            }
            case UNIFORM: {
                // StatelessRandomUniform samples from [0, 1). A uniform on [-limit, limit) has
                // variance limit^2 / 3, so limit = sqrt(3 * variance); map u -> u * 2 * limit - limit
                // to centre the samples on zero.
                double limit = Math.sqrt(3.0 * variance);
                Operand<Float> unit = tf.random.statelessRandomUniform(shape, tf.constant(seeds), dtype);
                Operand<Float> stretched = tf.math.mul(unit, tf.dtypes.cast(tf.constant(2.0 * limit), dtype));
                // The final op carries `name`, as the Mul op does in the other branches.
                return tf.withName(name).math.sub(stretched, tf.dtypes.cast(tf.constant(limit), dtype));
            }
            default: {
                throw new NoWhenBranchMatchedException();
            }
        }
    }

    /**
     * Returns the divisor n of the target variance for the configured mode, clamped to at least 1.
     * The FAN_AVG sum is computed in double arithmetic so large fans cannot overflow int.
     */
    private double effectiveFan(int fanIn, int fanOut) {
        switch (this.mode) {
            case FAN_IN:
                return Math.max(1.0, (double) fanIn);
            case FAN_OUT:
                return Math.max(1.0, (double) fanOut);
            case FAN_AVG:
                return Math.max(1.0, ((double) fanIn + (double) fanOut) / 2.0);
            default:
                throw new NoWhenBranchMatchedException();
        }
    }

    /**
     * Multiplies {@code samples} by the constant {@code factor}, cast to {@code dtype}. The
     * resulting Mul op is created under {@code name}.
     */
    private static Operand<Float> scaleBy(Ops tf, String name, Operand<Float> samples, double factor, Class<Float> dtype) {
        return tf.withName(name).math.mul(samples, tf.dtypes.cast(tf.constant(factor), dtype));
    }

    @Override
    @NotNull
    public String toString() {
        return "VarianceScaling(scale=" + this.scale + ", mode=" + (Object) this.mode + ", distribution=" + (Object) this.distribution + ", seed=" + this.seed + ')';
    }
}

