/*
 * Decompiled with CFR 0.152.
 */
package org.jetbrains.kotlinx.dl.api.core.layer.activation;

import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.kotlinx.dl.api.core.layer.activation.AbstractActivationLayer;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.ReduceMax;
import org.tensorflow.op.core.ReduceSum;
import org.tensorflow.op.core.Shape;
import org.tensorflow.op.core.Size;
import org.tensorflow.op.math.Div;
import org.tensorflow.op.math.Exp;

@Metadata(mv={1, 8, 0}, k=1, xi=48, d1={"\u0000.\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010 \n\u0002\u0010\b\n\u0000\n\u0002\u0010\u000e\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\u0010\u0007\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0003\u0018\u00002\u00020\u0001B\u001f\u0012\u000e\b\u0002\u0010\u0002\u001a\b\u0012\u0004\u0012\u00020\u00040\u0003\u0012\b\b\u0002\u0010\u0005\u001a\u00020\u0006\u00a2\u0006\u0002\u0010\u0007J$\u0010\n\u001a\b\u0012\u0004\u0012\u00020\f0\u000b2\u0006\u0010\r\u001a\u00020\u000e2\f\u0010\u000f\u001a\b\u0012\u0004\u0012\u00020\f0\u000bH\u0016J\b\u0010\u0010\u001a\u00020\u0006H\u0016R\u0017\u0010\u0002\u001a\b\u0012\u0004\u0012\u00020\u00040\u0003\u00a2\u0006\b\n\u0000\u001a\u0004\b\b\u0010\t\u00a8\u0006\u0011"}, d2={"Lorg/jetbrains/kotlinx/dl/api/core/layer/activation/Softmax;", "Lorg/jetbrains/kotlinx/dl/api/core/layer/activation/AbstractActivationLayer;", "axis", "", "", "name", "", "(Ljava/util/List;Ljava/lang/String;)V", "getAxis", "()Ljava/util/List;", "forward", "Lorg/tensorflow/Operand;", "", "tf", "Lorg/tensorflow/op/Ops;", "input", "toString", "tensorflow"})
/**
 * Softmax activation layer.
 *
 * <p>Computes {@code exp(x - max(x)) / sum(exp(x - max(x)))} along a single configured axis,
 * keeping the reduced dimension so the division broadcasts back over the input. Subtracting the
 * per-slice maximum before exponentiation prevents {@code exp} overflow without changing the
 * mathematical result.
 */
public final class Softmax
extends AbstractActivationLayer {
    /** Axes to normalize over; the constructor enforces that it holds exactly one entry. */
    @NotNull
    private final List<Integer> axis;

    /**
     * Creates a softmax layer.
     *
     * @param axis single-element list holding the axis to normalize over (the value is handed to
     *             the TensorFlow reduce ops as-is; negative values presumably count from the end —
     *             confirm against the TF version in use)
     * @param name layer name, forwarded to {@link AbstractActivationLayer}
     * @throws IllegalArgumentException if {@code axis} does not contain exactly one element
     */
    public Softmax(@NotNull List<Integer> axis, @NotNull String name) {
        super(checkedName(axis, name));
        if (axis.size() != 1) {
            throw new IllegalArgumentException(
                    "Softmax requires exactly one axis; multiple axes are not supported. Got: " + axis);
        }
        this.axis = axis;
    }

    /**
     * Kotlin default-arguments bridge.
     *
     * <p>Bit 0 of {@code mask} selects the default axis {@code [-1]}; bit 1 selects the default
     * name {@code ""}.
     */
    public /* synthetic */ Softmax(List list, String string, int mask, DefaultConstructorMarker marker) {
        this((mask & 1) != 0 ? CollectionsKt.listOf(-1) : list, (mask & 2) != 0 ? "" : string);
    }

    /**
     * Performs the Kotlin non-null parameter checks of the primary constructor.
     *
     * <p>This runs before {@code super(...)}, which must be the first statement of a Java
     * constructor.
     */
    private static String checkedName(List<Integer> axis, String name) {
        Intrinsics.checkNotNullParameter(axis, "axis");
        Intrinsics.checkNotNullParameter(name, "name");
        return name;
    }

    /** Returns the configured axis list (always a single element). */
    @NotNull
    public final List<Integer> getAxis() {
        return this.axis;
    }

    /**
     * Builds the softmax graph for {@code input} along the configured axis.
     *
     * <p>The explicit formulation is used for every rank. The previous "rank == 2" shortcut
     * compared two distinct graph operands by object identity, so it never fired. It would also
     * have ignored a non-last axis, because {@code tf.nn.softmax} normalizes over the last
     * dimension.
     *
     * @param tf    TensorFlow op builder
     * @param input tensor to normalize
     * @return {@code exp(input - max) / sum(exp(input - max))} reduced over the configured axis
     */
    @Override
    @NotNull
    public Operand<Float> forward(@NotNull Ops tf, @NotNull Operand<Float> input) {
        Intrinsics.checkNotNullParameter(tf, "tf");
        Intrinsics.checkNotNullParameter(input, "input");
        // One constant node shared by both reductions; axis size is validated in the constructor.
        Operand<Integer> reductionAxis = tf.constant(CollectionsKt.first(this.axis).intValue());
        // keepDims(true) keeps the reduced dimension so the subtraction/division broadcast.
        Operand<Float> maxAlongAxis = tf.reduceMax(input, reductionAxis, ReduceMax.keepDims(true));
        Operand<Float> e = tf.math.exp(tf.math.sub(input, maxAlongAxis));
        Operand<Float> s = tf.reduceSum(e, reductionAxis, ReduceSum.keepDims(true));
        return tf.math.div(e, s);
    }

    @NotNull
    public String toString() {
        return "Softmax(name = " + this.getName() + ", axis=" + this.axis + ')';
    }

    /** Creates a softmax layer over the last axis ({@code -1}) with an empty name. */
    public Softmax() {
        this(null, null, 3, null);
    }
}

