/*
 * Decompiled with CFR 0.152.
 */
package org.jetbrains.kotlinx.dl.api.core.activation;

import kotlin.Metadata;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SourceDebugExtension;
import kotlin.ranges.IntRange;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.kotlinx.dl.api.core.activation.Activation;
import org.tensorflow.Operand;
import org.tensorflow.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.EnsureShape;
import org.tensorflow.op.core.GatherNd;
import org.tensorflow.op.core.Range;
import org.tensorflow.op.core.ReduceSum;
import org.tensorflow.op.core.Reshape;
import org.tensorflow.op.core.Stack;
import org.tensorflow.op.core.Where3;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.linalg.Transpose;
import org.tensorflow.op.math.Cumsum;
import org.tensorflow.op.math.Div;
import org.tensorflow.op.math.Equal;
import org.tensorflow.op.math.Greater;
import org.tensorflow.op.math.Maximum;
import org.tensorflow.op.nn.TopK;

@Metadata(mv={1, 8, 0}, k=1, xi=48, d1={"\u0000$\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0010\u0007\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0005\u0018\u00002\u00020\u0001B\u000f\u0012\b\b\u0002\u0010\u0002\u001a\u00020\u0003\u00a2\u0006\u0002\u0010\u0004J$\u0010\u0005\u001a\b\u0012\u0004\u0012\u00020\u00070\u00062\u0006\u0010\b\u001a\u00020\t2\f\u0010\n\u001a\b\u0012\u0004\u0012\u00020\u00070\u0006H\u0016J$\u0010\u000b\u001a\b\u0012\u0004\u0012\u00020\u00070\u00062\u0006\u0010\b\u001a\u00020\t2\f\u0010\n\u001a\b\u0012\u0004\u0012\u00020\u00070\u0006H\u0002J4\u0010\f\u001a\b\u0012\u0004\u0012\u00020\u00070\u00062\u0006\u0010\b\u001a\u00020\t2\f\u0010\n\u001a\b\u0012\u0004\u0012\u00020\u00070\u00062\u0006\u0010\u0002\u001a\u00020\u00032\u0006\u0010\r\u001a\u00020\u0003H\u0002R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004\u00a2\u0006\u0002\n\u0000\u00a8\u0006\u000e"}, d2={"Lorg/jetbrains/kotlinx/dl/api/core/activation/SparsemaxActivation;", "Lorg/jetbrains/kotlinx/dl/api/core/activation/Activation;", "axis", "", "(I)V", "apply", "Lorg/tensorflow/Operand;", "", "tf", "Lorg/tensorflow/op/Ops;", "features", "compute2DSparsemax", "swapAxis", "lastIndex", "tensorflow"})
@SourceDebugExtension(value={"SMAP\nActivations.kt\nKotlin\n*S Kotlin\n*F\n+ 1 Activations.kt\norg/jetbrains/kotlinx/dl/api/core/activation/SparsemaxActivation\n+ 2 _Arrays.kt\nkotlin/collections/ArraysKt___ArraysKt\n*L\n1#1,711:1\n19361#2,7:712\n*S KotlinDebug\n*F\n+ 1 Activations.kt\norg/jetbrains/kotlinx/dl/api/core/activation/SparsemaxActivation\n*L\n653#1:712,7\n*E\n"})
public final class SparsemaxActivation
implements Activation {
    private final int axis;

    /**
     * Creates a sparsemax activation that operates along the given axis.
     *
     * @param axis index of the dimension the activation is applied along; {@code -1} is treated
     *             as the last dimension by {@code apply}
     */
    public SparsemaxActivation(int axis) {
        this.axis = axis;
    }

    /**
     * Compiler-generated bridge for the Kotlin default argument of the primary constructor.
     *
     * @param n the requested axis; ignored when bit 0 of {@code n2} is set
     * @param n2 default-argument bit mask; when bit 0 is set the axis falls back to {@code -1}
     * @param defaultConstructorMarker marker that only distinguishes this overload; its value is not read
     */
    public /* synthetic */ SparsemaxActivation(int n, int n2, DefaultConstructorMarker defaultConstructorMarker) {
        // The explicit constructor call must be the first statement, so the default
        // selection is folded into the argument expression.
        this((n2 & 1) != 0 ? -1 : n);
    }

    /**
     * Builds the sparsemax activation of {@code features} along the configured axis.
     *
     * <p>When the axis is {@code -1} or already the last dimension, the input is passed straight
     * to {@code compute2DSparsemax}. Otherwise the axis is swapped with the last dimension, the
     * activation is computed, and the two dimensions are swapped back. In both cases the result
     * is wrapped in an {@code ensureShape} op that carries the input's static shape.
     *
     * <p>The axis is normalised with {@link Math#floorMod(int, int)}: the plain remainder
     * operator keeps the sign of a negative axis (for example {@code -2 % 3 == -2}), which would
     * be handed on as a negative dimension index.
     *
     * @param tf the ops builder used to create the graph operations
     * @param features the input operand
     * @return an operand with the same static shape as {@code features}
     */
    @Override
    @NotNull
    public Operand<Float> apply(@NotNull Ops tf, @NotNull Operand<Float> features) {
        Intrinsics.checkNotNullParameter((Object)tf, (String)"tf");
        Intrinsics.checkNotNullParameter(features, (String)"features");
        Shape shape = features.asOutput().shape();
        int rank = shape.numDimensions();
        int lastIndex = rank - 1;
        if (this.axis == -1 || this.axis == lastIndex) {
            Operand<Float> output = this.compute2DSparsemax(tf, features);
            EnsureShape ensureShape = tf.ensureShape(output, shape);
            Intrinsics.checkNotNullExpressionValue((Object)ensureShape, (String)"tf.ensureShape(output, shape)");
            return (Operand)ensureShape;
        }
        // floorMod maps negative axes (e.g. -2) onto their non-negative equivalent.
        int axisNorm = Math.floorMod(this.axis, rank);
        Operand<Float> logits = this.swapAxis(tf, features, axisNorm, lastIndex);
        Operand<Float> output = this.compute2DSparsemax(tf, logits);
        EnsureShape ensureShape = tf.ensureShape(this.swapAxis(tf, output, axisNorm, lastIndex), shape);
        Intrinsics.checkNotNullExpressionValue((Object)ensureShape, (String)"tf.ensureShape(swapAxis(\u2026isNorm, rank - 1), shape)");
        return (Operand)ensureShape;
    }

    /**
     * Transposes {@code features} so that dimension {@code axis} and dimension {@code lastIndex}
     * trade places.
     *
     * <p>The permutation starts as the range {@code [0, lastIndex]}; a scatter update then writes
     * {@code lastIndex} at position {@code axis} and {@code axis} at position {@code lastIndex}.
     *
     * @param tf the ops builder used to create the graph operations
     * @param features the operand to transpose
     * @param axis index of the dimension to exchange with the last one
     * @param lastIndex index of the last dimension
     * @return the transposed operand
     */
    private final Operand<Float> swapAxis(Ops tf, Operand<Float> features, int axis, int lastIndex) {
        Range range = tf.range((Operand)tf.constant(0), (Operand)tf.constant(lastIndex + 1), (Operand)tf.constant(1));
        // Positions in the permutation vector to overwrite, one index per row.
        int[][] swapPositions = new int[][]{{axis}, {lastIndex}};
        Operand positionsOp = (Operand)tf.constant(swapPositions);
        // Values written at those positions: each slot receives the other dimension.
        int[] swappedValues = new int[]{lastIndex, axis};
        Operand permutation = (Operand)tf.tensorScatterUpdate((Operand)range, positionsOp, (Operand)tf.constant(swappedValues));
        Transpose transpose = tf.linalg.transpose(features, permutation);
        Intrinsics.checkNotNullExpressionValue((Object)transpose, (String)"tf.linalg.transpose(\n   \u2026)\n            )\n        )");
        return (Operand)transpose;
    }

    /*
     * WARNING - void declaration
     */
    private final Operand<Float> compute2DSparsemax(Ops tf, Operand<Float> features) {
        long[] shape = features.asOutput().tensor().shape();
        Intrinsics.checkNotNullExpressionValue((Object)shape, (String)"shape");
        long dims = shape[ArraysKt.getLastIndex((long[])shape)];
        Constant dimsOp = tf.constant((int)dims);
        long[] $this$reduce$iv = shape;
        boolean $i$f$reduce = false;
        if ($this$reduce$iv.length == 0) {
            throw new UnsupportedOperationException("Empty array can't be reduced.");
        }
        long accumulator$iv22 = $this$reduce$iv[0];
        IntIterator intIterator = new IntRange(1, ArraysKt.getLastIndex((long[])$this$reduce$iv)).iterator();
        while (intIterator.hasNext()) {
            void l;
            int index$iv = intIterator.nextInt();
            long l2 = $this$reduce$iv[index$iv];
            long acc = accumulator$iv22;
            boolean bl = false;
            accumulator$iv22 = acc * l;
        }
        long obs = accumulator$iv22 / dims;
        Constant one = tf.constant(1.0f);
        long[] accumulator$iv22 = new long[]{obs, dims};
        Reshape z = tf.reshape(features, (Operand)tf.constant(accumulator$iv22));
        TopK zSorted = tf.nn.topK((Operand)z, (Operand)dimsOp, new TopK.Options[0]);
        Cumsum zCumSum = tf.math.cumsum((Operand)zSorted.values(), (Operand)tf.constant(-1), new Cumsum.Options[0]);
        Range k = tf.range((Operand)one, (Operand)tf.math.add((Operand)tf.dtypes.cast((Operand)dimsOp, Float.class, new Cast.Options[0]), (Operand)one), (Operand)one);
        Greater zCheck = tf.math.greater((Operand)tf.math.add((Operand)one, (Operand)tf.math.mul((Operand)k, (Operand)zSorted.values())), (Operand)zCumSum);
        ReduceSum kz = tf.reduceSum((Operand)tf.dtypes.cast((Operand)zCheck, Integer.class, new Cast.Options[0]), (Operand)tf.constant(-1), new ReduceSum.Options[0]);
        Maximum kzSafe = tf.math.maximum((Operand)kz, (Operand)tf.constant(1));
        Object[] objectArray = new Object[2];
        objectArray[0] = tf.range((Operand)tf.constant(0), (Operand)tf.constant((int)obs), (Operand)tf.constant(1));
        int[] bl = new int[]{-1};
        objectArray[1] = tf.math.sub((Operand)tf.reshape((Operand)kzSafe, (Operand)tf.constant(bl)), (Operand)tf.constant(1));
        Iterable iterable = CollectionsKt.listOf((Object[])objectArray);
        objectArray = new Stack.Options[]{Stack.axis((Long)1L)};
        Stack indices = tf.stack(iterable, (Stack.Options[])objectArray);
        GatherNd tauSum = tf.gatherNd((Operand)zCumSum, (Operand)indices);
        Div tauZ = tf.math.div((Operand)tf.math.sub((Operand)tauSum, (Operand)one), (Operand)tf.dtypes.cast((Operand)kz, Float.class, new Cast.Options[0]));
        Maximum p = tf.math.maximum((Operand)tf.constant(0.0f), (Operand)tf.math.sub((Operand)z, (Operand)tf.expandDims((Operand)tauZ, (Operand)tf.constant(-1))));
        Object[] objectArray2 = new Object[2];
        objectArray2[0] = tf.range((Operand)tf.constant(0), (Operand)tf.constant((int)obs), (Operand)tf.constant(1));
        long[] lArray = new long[]{obs};
        objectArray2[1] = tf.fill((Operand)tf.constant(lArray), (Operand)tf.math.sub((Operand)dimsOp, (Operand)tf.constant(1)));
        Iterable iterable2 = CollectionsKt.listOf((Object[])objectArray2);
        objectArray2 = new Stack.Options[]{Stack.axis((Long)1L)};
        Stack zCumsumLastIndex = tf.stack(iterable2, (Stack.Options[])objectArray2);
        lArray = new long[]{obs, dims};
        Where3 pSafe = tf.where3((Operand)tf.math.logicalOr((Operand)tf.math.equal((Operand)kz, (Operand)tf.constant(0), new Equal.Options[0]), (Operand)tf.math.isNan((Operand)tf.gatherNd((Operand)zCumSum, (Operand)zCumsumLastIndex))), (Operand)tf.fill((Operand)tf.constant(lArray), (Operand)tf.constant(Float.NaN)), (Operand)p);
        Reshape reshape = tf.reshape((Operand)pSafe, (Operand)tf.constant(shape));
        Intrinsics.checkNotNullExpressionValue((Object)reshape, (String)"tf.reshape(pSafe, tf.constant(shape))");
        return (Operand)reshape;
    }

    /**
     * Named variant of {@code apply}; forwards unchanged to the interface's default
     * implementation in {@code Activation.DefaultImpls}.
     *
     * @param tf the ops builder used to create the graph operations
     * @param features the input operand
     * @param name the name passed through to the default implementation
     * @return whatever the default implementation returns
     */
    @Override
    @NotNull
    public Operand<Float> apply(@NotNull Ops tf, @NotNull Operand<Float> features, @NotNull String name) {
        Operand<Float> result = Activation.DefaultImpls.apply(this, tf, features, name);
        return result;
    }

    /**
     * Creates a sparsemax activation with the default axis {@code -1}, which {@code apply}
     * treats as the last dimension.
     */
    public SparsemaxActivation() {
        this(-1);
    }
}

