/*
 * Decompiled with CFR 0.152.
 */
package org.jetbrains.kotlinx.dl.api.core.initializer;

import kotlin.Metadata;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SourceDebugExtension;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer;
import org.jetbrains.kotlinx.dl.api.core.shape.ShapeFunctionsKt;
import org.jetbrains.kotlinx.dl.api.core.util.DtypeConversionUtilKt;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Reshape;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.linalg.Qr;
import org.tensorflow.op.linalg.TensorDiagPart;
import org.tensorflow.op.linalg.Transpose;
import org.tensorflow.op.math.Mul;
import org.tensorflow.op.random.StatelessRandomNormal;

@Metadata(mv={1, 8, 0}, k=1, xi=48, d1={"\u00006\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0007\n\u0000\n\u0002\u0010\t\n\u0002\b\u0006\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u000e\n\u0002\b\u0002\u0018\u00002\u00020\u0001B\u0019\u0012\b\b\u0002\u0010\u0002\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0004\u001a\u00020\u0005\u00a2\u0006\u0002\u0010\u0006J<\u0010\u000b\u001a\b\u0012\u0004\u0012\u00020\u00030\f2\u0006\u0010\r\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020\u000e2\u0006\u0010\u0010\u001a\u00020\u00112\f\u0010\u0012\u001a\b\u0012\u0004\u0012\u00020\u000e0\f2\u0006\u0010\u0013\u001a\u00020\u0014H\u0016J\b\u0010\u0015\u001a\u00020\u0014H\u0016R\u0011\u0010\u0002\u001a\u00020\u0003\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0007\u0010\bR\u0011\u0010\u0004\u001a\u00020\u0005\u00a2\u0006\b\n\u0000\u001a\u0004\b\t\u0010\n\u00a8\u0006\u0016"}, d2={"Lorg/jetbrains/kotlinx/dl/api/core/initializer/Orthogonal;", "Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;", "gain", "", "seed", "", "(FJ)V", "getGain", "()F", "getSeed", "()J", "initialize", "Lorg/tensorflow/Operand;", "fanIn", "", "fanOut", "tf", "Lorg/tensorflow/op/Ops;", "shape", "name", "", "toString", "tensorflow"})
@SourceDebugExtension(value={"SMAP\nOrthogonal.kt\nKotlin\n*S Kotlin\n*F\n+ 1 Orthogonal.kt\norg/jetbrains/kotlinx/dl/api/core/initializer/Orthogonal\n+ 2 fake.kt\nkotlin/jvm/internal/FakeKt\n*L\n1#1,75:1\n1#2:76\n*E\n"})
public final class Orthogonal
extends Initializer {
    private final float gain;
    private final long seed;

    public Orthogonal(float gain, long seed) {
        this.gain = gain;
        this.seed = seed;
    }

    public /* synthetic */ Orthogonal(float f, long l, int n, DefaultConstructorMarker defaultConstructorMarker) {
        if ((n & 1) != 0) {
            f = 1.0f;
        }
        if ((n & 2) != 0) {
            l = 12L;
        }
        this(f, l);
    }

    public final float getGain() {
        return this.gain;
    }

    public final long getSeed() {
        return this.seed;
    }

    @Override
    @NotNull
    public Operand<Float> initialize(int fanIn, int fanOut, @NotNull Ops tf, @NotNull Operand<Integer> shape, @NotNull String name) {
        Intrinsics.checkNotNullParameter((Object)tf, (String)"tf");
        Intrinsics.checkNotNullParameter(shape, (String)"shape");
        Intrinsics.checkNotNullParameter((Object)name, (String)"name");
        long dimsShape = shape.asOutput().shape().size(0);
        if (!(dimsShape >= 2L)) {
            boolean $i$a$-require-Orthogonal$initialize$22 = false;
            String $i$a$-require-Orthogonal$initialize$22 = "The tensor to initialize must be at least two-dimensional";
            throw new IllegalArgumentException($i$a$-require-Orthogonal$initialize$22.toString());
        }
        long[] lArray = new long[]{this.seed, 0L};
        StatelessRandomNormal statelessRandomNormal = tf.random.statelessRandomNormal(shape, (Operand)tf.constant(lArray), DtypeConversionUtilKt.getDType());
        Intrinsics.checkNotNullExpressionValue((Object)statelessRandomNormal, (String)"tf.random.statelessRando\u2026L)), getDType()\n        )");
        Operand distOpND = (Operand)statelessRandomNormal;
        long numRows = 1L;
        int i = 0;
        while ((long)i < dimsShape - 1L) {
            numRows *= distOpND.asOutput().shape().size(i);
            ++i;
        }
        long numCols = distOpND.asOutput().shape().size(i - 1);
        long[] lArray2 = new long[]{Math.min(numRows, numCols)};
        Shape flatShape = Shape.make((long)Math.max(numRows, numCols), (long[])lArray2);
        Intrinsics.checkNotNullExpressionValue((Object)flatShape, (String)"flatShape");
        Reshape reshape = tf.reshape(distOpND, ShapeFunctionsKt.shapeOperand(tf, flatShape));
        Intrinsics.checkNotNullExpressionValue((Object)reshape, (String)"tf.reshape(distOpND, shapeOperand(tf, flatShape))");
        Operand distOp = (Operand)reshape;
        Qr.Options qrOptions = Qr.fullMatrices((Boolean)false);
        Qr.Options[] optionsArray = new Qr.Options[]{qrOptions};
        Qr qr = tf.linalg.qr(distOp, optionsArray);
        Intrinsics.checkNotNullExpressionValue((Object)qr, (String)"tf.linalg.qr(distOp, qrOptions)");
        Qr qrOp = qr;
        Output output = qrOp.q();
        Intrinsics.checkNotNullExpressionValue((Object)output, (String)"qrOp.q()");
        Operand qo = (Operand)output;
        Output output2 = qrOp.r();
        Intrinsics.checkNotNullExpressionValue((Object)output2, (String)"qrOp.r()");
        Operand ro = (Operand)output2;
        TensorDiagPart tensorDiagPart = tf.linalg.tensorDiagPart(ro);
        Intrinsics.checkNotNullExpressionValue((Object)tensorDiagPart, (String)"tf.linalg.tensorDiagPart(ro)");
        Operand d = (Operand)tensorDiagPart;
        Mul mul = tf.withName((String)name).math.mul(qo, (Operand)tf.math.sign(d));
        Intrinsics.checkNotNullExpressionValue((Object)mul, (String)"tf.withName(name).math.mul(qo, tf.math.sign(d))");
        Operand qop = (Operand)mul;
        if (numRows < numCols) {
            int[] nArray = new int[]{1, 0};
            Transpose transpose = tf.withName((String)name).linalg.transpose(qop, (Operand)tf.constant(nArray));
            Intrinsics.checkNotNullExpressionValue((Object)transpose, (String)"tf.withName(name).linalg\u2026nstant(intArrayOf(1, 0)))");
            qop = (Operand)transpose;
        }
        Mul mul2 = tf.math.mul((Operand)tf.reshape(qop, shape), (Operand)tf.dtypes.cast((Operand)tf.constant(this.gain), DtypeConversionUtilKt.getDType(), new Cast.Options[0]));
        Intrinsics.checkNotNullExpressionValue((Object)mul2, (String)"tf.math.mul(tf.reshape(q\u2026(this.gain), getDType()))");
        return (Operand)mul2;
    }

    @NotNull
    public String toString() {
        return "Orthogonal(gain=" + this.gain + ", seed=" + this.seed + ')';
    }

    public Orthogonal() {
        this(0.0f, 0L, 3, null);
    }
}

