/*
 * Decompiled with CFR 0.152.
 */
package org.jetbrains.kotlinx.dl.api.core.layer.normalization;

import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer;
import org.jetbrains.kotlinx.dl.api.core.initializer.Ones;
import org.jetbrains.kotlinx.dl.api.core.initializer.Zeros;
import org.jetbrains.kotlinx.dl.api.core.layer.KVariable;
import org.jetbrains.kotlinx.dl.api.core.layer.KVariableKt;
import org.jetbrains.kotlinx.dl.api.core.layer.Layer;
import org.jetbrains.kotlinx.dl.api.core.layer.NoGradients;
import org.jetbrains.kotlinx.dl.api.core.layer.ParametrizedLayer;
import org.jetbrains.kotlinx.dl.api.core.layer.TrainableLayerKt;
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer;
import org.jetbrains.kotlinx.dl.api.core.util.NameConventionsKt;
import org.tensorflow.Operand;
import org.tensorflow.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Mul;
import org.tensorflow.op.math.Rsqrt;

@Metadata(mv={1, 8, 0}, k=1, xi=48, d1={"\u0000b\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010 \n\u0002\u0010\b\n\u0000\n\u0002\u0010\u0006\n\u0000\n\u0002\u0010\u000b\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u000e\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b \n\u0002\u0018\u0002\n\u0002\u0010\u0007\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\b\u0018\u00002\u00020\u00012\u00020\u00022\u00020\u0003B\u0087\u0001\u0012\u000e\b\u0002\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005\u0012\b\b\u0002\u0010\u0007\u001a\u00020\b\u0012\b\b\u0002\u0010\t\u001a\u00020\n\u0012\b\b\u0002\u0010\u000b\u001a\u00020\b\u0012\b\b\u0002\u0010\f\u001a\u00020\n\u0012\b\b\u0002\u0010\r\u001a\u00020\u000e\u0012\b\b\u0002\u0010\u000f\u001a\u00020\u000e\u0012\n\b\u0002\u0010\u0010\u001a\u0004\u0018\u00010\u0011\u0012\n\b\u0002\u0010\u0012\u001a\u0004\u0018\u00010\u0011\u0012\b\b\u0002\u0010\u0013\u001a\u00020\u000e\u0012\b\b\u0002\u0010\u0014\u001a\u00020\u000e\u0012\b\b\u0002\u0010\u0015\u001a\u00020\u0016\u00a2\u0006\u0002\u0010\u0017Jn\u0010;\u001a\b\u0012\u0004\u0012\u00020=0<2\u0006\u0010>\u001a\u00020?2\f\u0010@\u001a\b\u0012\u0004\u0012\u00020=0<2\u000e\u0010(\u001a\n\u0012\u0004\u0012\u00020=\u0018\u00010A2\u000e\u0010\u001a\u001a\n\u0012\u0004\u0012\u00020=\u0018\u00010<2\f\u00100\u001a\b\u0012\u0004\u0012\u00020=0<2\f\u0010B\u001a\b\u0012\u0004\u0012\u00020=0<2\f\u0010C\u001a\b\u0012\u0004\u0012\u00020=0<H\u0002JB\u0010D\u001a\b\u0012\u0004\u0012\u00020=0<2\u0006\u0010>\u001a\u00020?2\f\u0010E\u001a\b\u0012\u0004\u0012\u00020=0<2\f\u0010F\u001a\b\u0012\u0004\u0012\u00020\n0<2\u000e\u0010G\u001a\n\u0012\u0004\u0012\u00020=\u0018\u00010<H\u0016J\b\u0010H\u001a\u00020\u0016H\u0016R\u0017\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0018\u0010\u0019R\u001c\u0010\u001a\u001a\u0004\u0018\u0001
0\u001bX\u0080\u000e\u00a2\u0006\u000e\n\u0000\u001a\u0004\b\u001c\u0010\u001d\"\u0004\b\u001e\u0010\u001fR\u0011\u0010\u000f\u001a\u00020\u000e\u00a2\u0006\b\n\u0000\u001a\u0004\b \u0010!R\u0013\u0010\u0012\u001a\u0004\u0018\u00010\u0011\u00a2\u0006\b\n\u0000\u001a\u0004\b\"\u0010#R\u0011\u0010\t\u001a\u00020\n\u00a2\u0006\b\n\u0000\u001a\u0004\b$\u0010%R\u0011\u0010\u000b\u001a\u00020\b\u00a2\u0006\b\n\u0000\u001a\u0004\b&\u0010'R\u001c\u0010(\u001a\u0004\u0018\u00010\u001bX\u0080\u000e\u00a2\u0006\u000e\n\u0000\u001a\u0004\b)\u0010\u001d\"\u0004\b*\u0010\u001fR\u0011\u0010\r\u001a\u00020\u000e\u00a2\u0006\b\n\u0000\u001a\u0004\b+\u0010!R\u0013\u0010\u0010\u001a\u0004\u0018\u00010\u0011\u00a2\u0006\b\n\u0000\u001a\u0004\b,\u0010#R\u0014\u0010-\u001a\u00020\n8VX\u0096\u0004\u00a2\u0006\u0006\u001a\u0004\b.\u0010%R\u0011\u0010\u0007\u001a\u00020\b\u00a2\u0006\b\n\u0000\u001a\u0004\b/\u0010'R\u001a\u00100\u001a\u00020\u001bX\u0080.\u00a2\u0006\u000e\n\u0000\u001a\u0004\b1\u0010\u001d\"\u0004\b2\u0010\u001fR\u0011\u0010\u0013\u001a\u00020\u000e\u00a2\u0006\b\n\u0000\u001a\u0004\b3\u0010!R\u001a\u00104\u001a\u00020\u001bX\u0080.\u00a2\u0006\u000e\n\u0000\u001a\u0004\b5\u0010\u001d\"\u0004\b6\u0010\u001fR\u0011\u0010\u0014\u001a\u00020\u000e\u00a2\u0006\b\n\u0000\u001a\u0004\b7\u0010!R\u0011\u0010\f\u001a\u00020\n\u00a2\u0006\b\n\u0000\u001a\u0004\b8\u0010%R\u001a\u00109\u001a\b\u0012\u0004\u0012\u00020\u001b0\u00058VX\u0096\u0004\u00a2\u0006\u0006\u001a\u0004\b:\u0010\u0019\u00a8\u0006I"}, d2={"Lorg/jetbrains/kotlinx/dl/api/core/layer/normalization/BatchNorm;", "Lorg/jetbrains/kotlinx/dl/api/core/layer/Layer;", "Lorg/jetbrains/kotlinx/dl/api/core/layer/NoGradients;", "Lorg/jetbrains/kotlinx/dl/api/core/layer/ParametrizedLayer;", "axis", "", "", "momentum", "", "center", "", "epsilon", "scale", "gammaInitializer", "Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;", "betaInitializer", "gammaRegularizer", 
"Lorg/jetbrains/kotlinx/dl/api/core/regularizer/Regularizer;", "betaRegularizer", "movingMeanInitializer", "movingVarianceInitializer", "name", "", "(Ljava/util/List;DZDZLorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;Lorg/jetbrains/kotlinx/dl/api/core/regularizer/Regularizer;Lorg/jetbrains/kotlinx/dl/api/core/regularizer/Regularizer;Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;Ljava/lang/String;)V", "getAxis", "()Ljava/util/List;", "beta", "Lorg/jetbrains/kotlinx/dl/api/core/layer/KVariable;", "getBeta$tensorflow", "()Lorg/jetbrains/kotlinx/dl/api/core/layer/KVariable;", "setBeta$tensorflow", "(Lorg/jetbrains/kotlinx/dl/api/core/layer/KVariable;)V", "getBetaInitializer", "()Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;", "getBetaRegularizer", "()Lorg/jetbrains/kotlinx/dl/api/core/regularizer/Regularizer;", "getCenter", "()Z", "getEpsilon", "()D", "gamma", "getGamma$tensorflow", "setGamma$tensorflow", "getGammaInitializer", "getGammaRegularizer", "hasActivation", "getHasActivation", "getMomentum", "movingMean", "getMovingMean$tensorflow", "setMovingMean$tensorflow", "getMovingMeanInitializer", "movingVariance", "getMovingVariance$tensorflow", "setMovingVariance$tensorflow", "getMovingVarianceInitializer", "getScale", "variables", "getVariables", "batchNorm", "Lorg/tensorflow/Operand;", "", "tf", "Lorg/tensorflow/op/Ops;", "x", "Lorg/tensorflow/op/core/Variable;", "movingVar", "eps", "build", "input", "isTraining", "numberOfLosses", "toString", "tensorflow"})
/**
 * Batch-normalization layer (decompiled from Kotlin).
 *
 * <p>{@link #build} normalizes its input with the stored moving statistics:
 * {@code y = (x - movingMean) * rsqrt(movingVariance + epsilon) * gamma + beta}.
 * The {@code gamma} factor is applied only when {@link #getScale()} is {@code true}, and
 * {@code beta} is added only when {@link #getCenter()} is {@code true}.
 *
 * <p>Note: {@link #build} never reads its {@code isTraining} operand. {@code momentum} is stored
 * and reported by {@link #toString()}, but no computation in this class uses it. So the same
 * moving-statistics normalization is emitted regardless of mode.
 */
public final class BatchNorm
extends Layer
implements NoGradients,
ParametrizedLayer {
    @NotNull
    private final List<Integer> axis;
    private final double momentum;
    private final boolean center;
    private final double epsilon;
    private final boolean scale;
    @NotNull
    private final Initializer gammaInitializer;
    @NotNull
    private final Initializer betaInitializer;
    @Nullable
    private final Regularizer gammaRegularizer;
    @Nullable
    private final Regularizer betaRegularizer;
    @NotNull
    private final Initializer movingMeanInitializer;
    @NotNull
    private final Initializer movingVarianceInitializer;
    /** Scale variable; created by {@link #build} only when {@code scale} is true. */
    @Nullable
    private KVariable gamma;
    /** Offset variable; created by {@link #build} only when {@code center} is true. */
    @Nullable
    private KVariable beta;
    /** Kotlin lateinit backing field; {@code null} until {@link #build} runs. */
    public KVariable movingMean;
    /** Kotlin lateinit backing field; {@code null} until {@link #build} runs. */
    public KVariable movingVariance;

    /**
     * Creates a batch-normalization layer.
     *
     * @param axis axis list; only the first element is used by {@link #build}. It selects the
     *     input dimension whose size becomes the length of every per-channel variable. A
     *     negative value counts from the last dimension.
     * @param momentum stored and reported by {@link #toString()}; not used in computation here
     * @param center whether a {@code beta} offset variable is created and added
     * @param epsilon constant added to the moving variance before {@code rsqrt}
     * @param scale whether a {@code gamma} scale variable is created and multiplied in
     * @param gammaInitializer initializer for {@code gamma}
     * @param betaInitializer initializer for {@code beta}
     * @param gammaRegularizer optional regularizer for {@code gamma}
     * @param betaRegularizer optional regularizer for {@code beta}
     * @param movingMeanInitializer initializer for the moving-mean variable
     * @param movingVarianceInitializer initializer for the moving-variance variable
     * @param name layer name; {@link #build} rejects an empty name
     */
    public BatchNorm(@NotNull List<Integer> axis, double momentum, boolean center, double epsilon, boolean scale, @NotNull Initializer gammaInitializer, @NotNull Initializer betaInitializer, @Nullable Regularizer gammaRegularizer, @Nullable Regularizer betaRegularizer, @NotNull Initializer movingMeanInitializer, @NotNull Initializer movingVarianceInitializer, @NotNull String name) {
        // Java requires super(...) to be the first statement. The non-null checks that Kotlin
        // emits before the superclass call therefore run inside the argument expression.
        super(BatchNorm.checkedName(axis, gammaInitializer, betaInitializer, movingMeanInitializer, movingVarianceInitializer, name));
        this.axis = axis;
        this.momentum = momentum;
        this.center = center;
        this.epsilon = epsilon;
        this.scale = scale;
        this.gammaInitializer = gammaInitializer;
        this.betaInitializer = betaInitializer;
        this.gammaRegularizer = gammaRegularizer;
        this.betaRegularizer = betaRegularizer;
        this.movingMeanInitializer = movingMeanInitializer;
        this.movingVarianceInitializer = movingVarianceInitializer;
    }

    /**
     * Runs the primary constructor's non-null parameter checks, in declaration order.
     *
     * @return {@code name}, so the result can be forwarded to the superclass constructor
     */
    private static String checkedName(List<Integer> axis, Initializer gammaInitializer, Initializer betaInitializer, Initializer movingMeanInitializer, Initializer movingVarianceInitializer, String name) {
        Intrinsics.checkNotNullParameter(axis, (String)"axis");
        Intrinsics.checkNotNullParameter((Object)gammaInitializer, (String)"gammaInitializer");
        Intrinsics.checkNotNullParameter((Object)betaInitializer, (String)"betaInitializer");
        Intrinsics.checkNotNullParameter((Object)movingMeanInitializer, (String)"movingMeanInitializer");
        Intrinsics.checkNotNullParameter((Object)movingVarianceInitializer, (String)"movingVarianceInitializer");
        Intrinsics.checkNotNullParameter((Object)name, (String)"name");
        return name;
    }

    /**
     * Kotlin default-arguments bridge. Each set bit of {@code n} (bit 0 = {@code axis}, ...,
     * bit 11 = {@code name}) marks a parameter the caller omitted, and that parameter is
     * replaced by its default value.
     *
     * <p>The defaults are: axis {@code [3]}, momentum 0.99, center true, epsilon 0.001,
     * scale true, gamma {@link Ones}, beta {@link Zeros}, no regularizers, moving mean
     * {@link Zeros}, moving variance {@link Ones}, and name {@code ""}.
     */
    public /* synthetic */ BatchNorm(List list, double d, boolean bl, double d2, boolean bl2, Initializer initializer, Initializer initializer2, Regularizer regularizer, Regularizer regularizer2, Initializer initializer3, Initializer initializer4, String string, int n, DefaultConstructorMarker defaultConstructorMarker) {
        // Defaults are resolved inline because this(...) must be the first statement in Java.
        // Arguments are evaluated left to right, matching the original assignment order.
        this((n & 1) != 0 ? CollectionsKt.arrayListOf(new Integer[]{3}) : list,
                (n & 2) != 0 ? 0.99 : d,
                (n & 4) != 0 ? true : bl,
                (n & 8) != 0 ? 0.001 : d2,
                (n & 0x10) != 0 ? true : bl2,
                (n & 0x20) != 0 ? new Ones() : initializer,
                (n & 0x40) != 0 ? new Zeros() : initializer2,
                (n & 0x80) != 0 ? null : regularizer,
                (n & 0x100) != 0 ? null : regularizer2,
                (n & 0x200) != 0 ? new Zeros() : initializer3,
                (n & 0x400) != 0 ? new Ones() : initializer4,
                (n & 0x800) != 0 ? "" : string);
    }

    @NotNull
    public final List<Integer> getAxis() {
        return this.axis;
    }

    public final double getMomentum() {
        return this.momentum;
    }

    public final boolean getCenter() {
        return this.center;
    }

    public final double getEpsilon() {
        return this.epsilon;
    }

    public final boolean getScale() {
        return this.scale;
    }

    @NotNull
    public final Initializer getGammaInitializer() {
        return this.gammaInitializer;
    }

    @NotNull
    public final Initializer getBetaInitializer() {
        return this.betaInitializer;
    }

    @Nullable
    public final Regularizer getGammaRegularizer() {
        return this.gammaRegularizer;
    }

    @Nullable
    public final Regularizer getBetaRegularizer() {
        return this.betaRegularizer;
    }

    @NotNull
    public final Initializer getMovingMeanInitializer() {
        return this.movingMeanInitializer;
    }

    @NotNull
    public final Initializer getMovingVarianceInitializer() {
        return this.movingVarianceInitializer;
    }

    @Nullable
    public final KVariable getGamma$tensorflow() {
        return this.gamma;
    }

    public final void setGamma$tensorflow(@Nullable KVariable kVariable) {
        this.gamma = kVariable;
    }

    @Nullable
    public final KVariable getBeta$tensorflow() {
        return this.beta;
    }

    public final void setBeta$tensorflow(@Nullable KVariable kVariable) {
        this.beta = kVariable;
    }

    /** Returns the moving-mean variable; throws via Kotlin's lateinit check if not yet built. */
    @NotNull
    public final KVariable getMovingMean$tensorflow() {
        KVariable kVariable = this.movingMean;
        if (kVariable != null) {
            return kVariable;
        }
        Intrinsics.throwUninitializedPropertyAccessException((String)"movingMean");
        return null;
    }

    public final void setMovingMean$tensorflow(@NotNull KVariable kVariable) {
        Intrinsics.checkNotNullParameter((Object)kVariable, (String)"<set-?>");
        this.movingMean = kVariable;
    }

    /** Returns the moving-variance variable; throws via Kotlin's lateinit check if not yet built. */
    @NotNull
    public final KVariable getMovingVariance$tensorflow() {
        KVariable kVariable = this.movingVariance;
        if (kVariable != null) {
            return kVariable;
        }
        Intrinsics.throwUninitializedPropertyAccessException((String)"movingVariance");
        return null;
    }

    public final void setMovingVariance$tensorflow(@NotNull KVariable kVariable) {
        Intrinsics.checkNotNullParameter((Object)kVariable, (String)"<set-?>");
        this.movingVariance = kVariable;
    }

    /**
     * Returns {@code gamma} and {@code beta} (whichever were created) followed by the moving
     * mean and moving variance.
     *
     * <p>The moving statistics are read through their lateinit getters, so calling this method
     * before {@link #build} throws.
     */
    @Override
    @NotNull
    public List<KVariable> getVariables() {
        KVariable[] candidates = new KVariable[]{this.gamma, this.beta, this.getMovingMean$tensorflow(), this.getMovingVariance$tensorflow()};
        return CollectionsKt.listOfNotNull(candidates);
    }

    /**
     * Creates this layer's variables and returns the normalized output operand.
     *
     * <p>The moving-mean and moving-variance variables are always created. {@code gamma} is
     * created only when {@code scale} is true, and {@code beta} only when {@code center} is true.
     * Every variable has shape {@code [input.size(axis[0])]}.
     *
     * @param tf TensorFlow op builder
     * @param input layer input
     * @param isTraining not read by this implementation
     * @param numberOfLosses not read by this implementation
     * @return the batch-normalized output
     * @throws RuntimeException if the layer name is empty
     */
    @Override
    @NotNull
    public Operand<Float> build(@NotNull Ops tf, @NotNull Operand<Float> input, @NotNull Operand<Boolean> isTraining, @Nullable Operand<Float> numberOfLosses) {
        Intrinsics.checkNotNullParameter((Object)tf, (String)"tf");
        Intrinsics.checkNotNullParameter(input, (String)"input");
        Intrinsics.checkNotNullParameter(isTraining, (String)"isTraining");
        Shape inputShape = input.asOutput().shape();
        Shape weightShape = Shape.make(inputShape.size(this.resolveAxis(inputShape)), new long[0]);
        if (this.getName().length() == 0) {
            throw new RuntimeException("Cannot build BatchNorm layer, because of empty name");
        }
        // Fan-in/fan-out are passed as Integer.MIN_VALUE, presumably a "not applicable"
        // sentinel for createVariable. NOTE(review): confirm against KVariableKt.
        int fanIn = Integer.MIN_VALUE;
        int fanOut = Integer.MIN_VALUE;
        Intrinsics.checkNotNullExpressionValue((Object)weightShape, (String)"weightShape");
        String meanName = NameConventionsKt.batchNormMovingMeanVarName(this.getName());
        this.setMovingMean$tensorflow(KVariableKt.createVariable(tf, meanName, weightShape, fanIn, fanOut, this.movingMeanInitializer, null));
        String varianceName = NameConventionsKt.batchNormMovingVarianceVarName(this.getName());
        this.setMovingVariance$tensorflow(KVariableKt.createVariable(tf, varianceName, weightShape, fanIn, fanOut, this.movingVarianceInitializer, null));
        if (this.scale) {
            this.gamma = KVariableKt.createVariable(tf, NameConventionsKt.batchNormGammaVarName(this.getName()), weightShape, fanIn, fanOut, this.gammaInitializer, this.gammaRegularizer);
        }
        if (this.center) {
            this.beta = KVariableKt.createVariable(tf, NameConventionsKt.batchNormBetaVarName(this.getName()), weightShape, fanIn, fanOut, this.betaInitializer, this.betaRegularizer);
        }
        Ops scoped = tf.withName("BatchNorm");
        Intrinsics.checkNotNullExpressionValue((Object)scoped, (String)"tf");
        KVariable gammaHolder = this.gamma;
        KVariable betaHolder = this.beta;
        Variable<Float> gammaVariable = gammaHolder != null ? gammaHolder.getVariable() : null;
        Operand<Float> betaVariable = betaHolder != null ? betaHolder.getVariable() : null;
        Operand<Float> meanVariable = this.getMovingMean$tensorflow().getVariable();
        Operand<Float> varianceVariable = this.getMovingVariance$tensorflow().getVariable();
        Operand<Float> eps = scoped.constant((float)this.epsilon);
        Intrinsics.checkNotNullExpressionValue((Object)eps, (String)"tf.constant(epsilon.toFloat())");
        return this.batchNorm(scoped, input, gammaVariable, betaVariable, meanVariable, varianceVariable, eps);
    }

    /**
     * Maps {@code axis[0]} to a dimension index of {@code shape}. A negative value counts from
     * the last dimension, so {@code -1} selects the last one.
     */
    private int resolveAxis(Shape shape) {
        int requested = this.axis.get(0);
        return requested < 0 ? requested + shape.numDimensions() : requested;
    }

    /**
     * Builds {@code (x - movingMean) * rsqrt(movingVar + eps) [* gamma] [+ beta]}.
     *
     * @param gamma used only when {@code scale} is true; may be {@code null} otherwise
     * @param beta used only when {@code center} is true; may be {@code null} otherwise
     */
    private Operand<Float> batchNorm(Ops tf, Operand<Float> x, Variable<Float> gamma, Operand<Float> beta, Operand<Float> movingMean, Operand<Float> movingVar, Operand<Float> eps) {
        Operand<Float> inv = tf.math.rsqrt(tf.math.add(movingVar, eps));
        if (this.scale) {
            // Fold gamma into the inverse std so the input is multiplied only once.
            inv = tf.math.mul(inv, gamma);
        }
        Operand<Float> xNorm = tf.math.mul(tf.math.sub(x, movingMean), inv);
        if (this.center) {
            return tf.math.add(xNorm, beta);
        }
        return xNorm;
    }

    /**
     * Describes the configuration and variable shapes.
     *
     * <p>Safe to call before {@link #build}: shapes of variables that have not been created yet
     * are printed as {@code null}. The previous version threw
     * {@code UninitializedPropertyAccessException} in that case.
     */
    @NotNull
    public String toString() {
        KVariable gammaHolder = this.gamma;
        KVariable betaHolder = this.beta;
        // Read the lateinit backing fields directly rather than via their throwing getters.
        KVariable meanHolder = this.movingMean;
        KVariable varianceHolder = this.movingVariance;
        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.append("BatchNorm(name = ").append(this.getName()).append(", isTrainable=").append(TrainableLayerKt.isTrainable(this)).append(", axis=").append(this.axis).append(", momentum=").append(this.momentum).append(", center=").append(this.center).append(", epsilon=").append(this.epsilon).append(", scale=").append(this.scale).append(", gammaInitializer=").append(this.gammaInitializer).append(", betaInitializer=").append(this.betaInitializer).append(", gammaRegularizer=").append(this.gammaRegularizer).append(", betaRegularizer=").append(this.betaRegularizer).append(", movingMeanInitializer=");
        stringBuilder.append(this.movingMeanInitializer).append(", movingVarianceInitializer=").append(this.movingVarianceInitializer).append(", hasActivation=").append(this.getHasActivation()).append(", gammaShapeArray=").append(gammaHolder != null ? gammaHolder.getShape() : null).append(", betaShapeArray=").append(betaHolder != null ? betaHolder.getShape() : null).append(", movingMeanShapeArray=").append(meanHolder != null ? meanHolder.getShape() : null).append(", movingVarianceShapeArray=").append(varianceHolder != null ? varianceHolder.getShape() : null).append(')');
        return stringBuilder.toString();
    }

    /** This layer applies no activation function. */
    @Override
    public boolean getHasActivation() {
        return false;
    }

    /** Delegates to the {@link ParametrizedLayer} default implementation. */
    @Override
    public int getParamCount() {
        return ParametrizedLayer.DefaultImpls.getParamCount(this);
    }

    /** Creates a layer with every parameter at its default (mask 4095 = all 12 bits set). */
    public BatchNorm() {
        this(null, 0.0, false, 0.0, false, null, null, null, null, null, null, null, 4095, null);
    }
}

