/*
 * Decompiled with CFR 0.152.
 */
package org.jetbrains.kotlinx.dl.api.core.layer.activation;

import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer;
import org.jetbrains.kotlinx.dl.api.core.initializer.Zeros;
import org.jetbrains.kotlinx.dl.api.core.layer.KVariable;
import org.jetbrains.kotlinx.dl.api.core.layer.KVariableKt;
import org.jetbrains.kotlinx.dl.api.core.layer.TrainableLayer;
import org.jetbrains.kotlinx.dl.api.core.layer.activation.AbstractActivationLayer;
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer;
import org.jetbrains.kotlinx.dl.api.core.shape.ShapeFunctionsKt;
import org.tensorflow.Operand;
import org.tensorflow.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Mul;
import org.tensorflow.op.nn.Relu;

@Metadata(mv={1, 8, 0}, k=1, xi=48, d1={"\u0000R\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0015\n\u0000\n\u0002\u0010\u000e\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\t\n\u0002\u0010\u000b\n\u0002\b\u0006\n\u0002\u0010 \n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\u0010\u0007\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0003\u0018\u00002\u00020\u00012\u00020\u0002B1\u0012\b\b\u0002\u0010\u0003\u001a\u00020\u0004\u0012\n\b\u0002\u0010\u0005\u001a\u0004\u0018\u00010\u0006\u0012\n\b\u0002\u0010\u0007\u001a\u0004\u0018\u00010\b\u0012\b\b\u0002\u0010\t\u001a\u00020\n\u00a2\u0006\u0002\u0010\u000bJ\b\u0010!\u001a\u00020\nH\u0002J$\u0010\"\u001a\b\u0012\u0004\u0012\u00020$0#2\u0006\u0010%\u001a\u00020&2\f\u0010'\u001a\b\u0012\u0004\u0012\u00020$0#H\u0016J\b\u0010(\u001a\u00020\nH\u0016R\u001a\u0010\f\u001a\u00020\rX\u0080.\u00a2\u0006\u000e\n\u0000\u001a\u0004\b\u000e\u0010\u000f\"\u0004\b\u0010\u0010\u0011R\u0011\u0010\u0003\u001a\u00020\u0004\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0012\u0010\u0013R\u0013\u0010\u0005\u001a\u0004\u0018\u00010\u0006\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0014\u0010\u0015R\u001a\u0010\u0016\u001a\u00020\u0017X\u0096\u000e\u00a2\u0006\u000e\n\u0000\u001a\u0004\b\u0016\u0010\u0018\"\u0004\b\u0019\u0010\u001aR\u0013\u0010\u0007\u001a\u0004\u0018\u00010\b\u00a2\u0006\b\n\u0000\u001a\u0004\b\u001b\u0010\u001cR\u001a\u0010\u001d\u001a\b\u0012\u0004\u0012\u00020\r0\u001e8VX\u0096\u0004\u00a2\u0006\u0006\u001a\u0004\b\u001f\u0010 \u00a8\u0006)"}, d2={"Lorg/jetbrains/kotlinx/dl/api/core/layer/activation/PReLU;", "Lorg/jetbrains/kotlinx/dl/api/core/layer/activation/AbstractActivationLayer;", "Lorg/jetbrains/kotlinx/dl/api/core/layer/TrainableLayer;", "alphaInitializer", "Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;", "alphaRegularizer", "Lorg/jetbrains/kotlinx/dl/api/core/regularizer/Regularizer;", "sharedAxes", "", "name", "", "(Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;Lorg/jetbrains/kotlinx/dl/api/core/regularizer/Regularizer;[ILjava/lang/String;)V", "alpha", "Lorg/jetbrains/kotlinx/dl/api/core/layer/KVariable;", "getAlpha$tensorflow", "()Lorg/jetbrains/kotlinx/dl/api/core/layer/KVariable;", "setAlpha$tensorflow", "(Lorg/jetbrains/kotlinx/dl/api/core/layer/KVariable;)V", "getAlphaInitializer", "()Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;", "getAlphaRegularizer", "()Lorg/jetbrains/kotlinx/dl/api/core/regularizer/Regularizer;", "isTrainable", "", "()Z", "setTrainable", "(Z)V", "getSharedAxes", "()[I", "variables", "", "getVariables", "()Ljava/util/List;", "alphaVariableName", "forward", "Lorg/tensorflow/Operand;", "", "tf", "Lorg/tensorflow/op/Ops;", "input", "toString", "tensorflow"})
public final class PReLU
extends AbstractActivationLayer
implements TrainableLayer {
    @NotNull
    private final Initializer alphaInitializer;
    @Nullable
    private final Regularizer alphaRegularizer;
    @Nullable
    private final int[] sharedAxes;
    public KVariable alpha;
    private boolean isTrainable;

    public PReLU(@NotNull Initializer alphaInitializer, @Nullable Regularizer alphaRegularizer, @Nullable int[] sharedAxes, @NotNull String name) {
        Intrinsics.checkNotNullParameter((Object)alphaInitializer, (String)"alphaInitializer");
        Intrinsics.checkNotNullParameter((Object)name, (String)"name");
        super(name);
        this.alphaInitializer = alphaInitializer;
        this.alphaRegularizer = alphaRegularizer;
        this.sharedAxes = sharedAxes;
        this.isTrainable = true;
    }

    public /* synthetic */ PReLU(Initializer initializer, Regularizer regularizer, int[] nArray, String string, int n, DefaultConstructorMarker defaultConstructorMarker) {
        if ((n & 1) != 0) {
            initializer = new Zeros();
        }
        if ((n & 2) != 0) {
            regularizer = null;
        }
        if ((n & 4) != 0) {
            nArray = null;
        }
        if ((n & 8) != 0) {
            string = "";
        }
        this(initializer, regularizer, nArray, string);
    }

    @NotNull
    public final Initializer getAlphaInitializer() {
        return this.alphaInitializer;
    }

    @Nullable
    public final Regularizer getAlphaRegularizer() {
        return this.alphaRegularizer;
    }

    @Nullable
    public final int[] getSharedAxes() {
        return this.sharedAxes;
    }

    @NotNull
    public final KVariable getAlpha$tensorflow() {
        KVariable kVariable = this.alpha;
        if (kVariable != null) {
            return kVariable;
        }
        Intrinsics.throwUninitializedPropertyAccessException((String)"alpha");
        return null;
    }

    public final void setAlpha$tensorflow(@NotNull KVariable kVariable) {
        Intrinsics.checkNotNullParameter((Object)kVariable, (String)"<set-?>");
        this.alpha = kVariable;
    }

    private final String alphaVariableName() {
        return ((CharSequence)this.getName()).length() > 0 ? this.getName() + "_alpha" : "alpha";
    }

    @Override
    @NotNull
    public List<KVariable> getVariables() {
        return CollectionsKt.listOf((Object)this.getAlpha$tensorflow());
    }

    @Override
    public boolean isTrainable() {
        return this.isTrainable;
    }

    @Override
    public void setTrainable(boolean bl) {
        this.isTrainable = bl;
    }

    @Override
    @NotNull
    public Operand<Float> forward(@NotNull Ops tf, @NotNull Operand<Float> input) {
        int fanIn;
        Intrinsics.checkNotNullParameter((Object)tf, (String)"tf");
        Intrinsics.checkNotNullParameter(input, (String)"input");
        Shape inputShape = input.asOutput().shape();
        Intrinsics.checkNotNullExpressionValue((Object)inputShape, (String)"inputShape");
        long[] alphaShapeArray = CollectionsKt.toLongArray((Collection)ArraysKt.drop((long[])ShapeFunctionsKt.toLongArray(inputShape), (int)1));
        if (this.sharedAxes != null) {
            for (int axis : this.sharedAxes) {
                alphaShapeArray[axis - 1] = 1L;
            }
        }
        int fanOut = fanIn = (int)inputShape.size(inputShape.numDimensions() - 1);
        long[] axis = CollectionsKt.toLongArray((Collection)ArraysKt.drop((long[])alphaShapeArray, (int)1));
        Shape alphaShape = Shape.make((long)alphaShapeArray[0], (long[])Arrays.copyOf(axis, axis.length));
        String string = this.alphaVariableName();
        Intrinsics.checkNotNullExpressionValue((Object)alphaShape, (String)"alphaShape");
        this.setAlpha$tensorflow(KVariableKt.createVariable(tf, string, alphaShape, fanIn, fanOut, this.alphaInitializer, this.alphaRegularizer));
        Relu positive = tf.nn.relu(input);
        Mul negative = tf.math.mul((Operand)tf.math.neg((Operand)this.getAlpha$tensorflow().getVariable()), (Operand)tf.nn.relu((Operand)tf.math.neg(input)));
        Add add = tf.math.add((Operand)positive, (Operand)negative);
        Intrinsics.checkNotNullExpressionValue((Object)add, (String)"tf.math.add(positive, negative)");
        return (Operand)add;
    }

    @NotNull
    public String toString() {
        String string;
        StringBuilder stringBuilder = new StringBuilder().append("PReLU(name = ").append(this.getName()).append(", isTrainable=").append(this.isTrainable()).append(", alphaInitializer=").append(this.alphaInitializer).append(", alphaRegularizer=").append(this.alphaRegularizer).append(", sharedAxes=");
        if (this.sharedAxes != null) {
            String string2 = Arrays.toString(this.sharedAxes);
            string = string2;
            Intrinsics.checkNotNullExpressionValue((Object)string2, (String)"toString(this)");
        } else {
            string = null;
        }
        return stringBuilder.append(string).append(')').toString();
    }

    @Override
    public int getParamCount() {
        return TrainableLayer.DefaultImpls.getParamCount(this);
    }

    public PReLU() {
        this(null, null, null, null, 15, null);
    }
}

