/*
 * Decompiled with CFR 0.152.
 */
package org.jetbrains.kotlinx.dl.api.core.loss;

import kotlin.Metadata;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.kotlinx.dl.api.core.loss.LossFunction;
import org.jetbrains.kotlinx.dl.api.core.loss.LossesKt;
import org.jetbrains.kotlinx.dl.api.core.loss.ReductionType;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Minimum;
import org.tensorflow.op.math.Mul;
import org.tensorflow.op.math.Sub;

/**
 * Binary cross-entropy loss.
 *
 * <p>Predictions are first clipped to {@code [1e-7, 1 - 1e-7]}; with {@code p} the clipped
 * prediction, the element-wise loss is
 * {@code -(yTrue * log(p + 1e-7) + (1 - yTrue) * log(1 - p + 1e-7))}. The element-wise losses are
 * then passed to {@link LossesKt#meanOfLosses} together with this instance's {@link ReductionType}.
 *
 * <p>NOTE(review): the formula assumes {@code yPred} holds probabilities (e.g. sigmoid outputs)
 * rather than logits — confirm against callers.
 */
@Metadata(mv={1, 8, 0}, k=1, xi=48, d1={"\u0000$\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0010\u0007\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0004\u0018\u00002\u00020\u0001B\u000f\u0012\b\b\u0002\u0010\u0002\u001a\u00020\u0003\u00a2\u0006\u0002\u0010\u0004JB\u0010\u0005\u001a\b\u0012\u0004\u0012\u00020\u00070\u00062\u0006\u0010\b\u001a\u00020\t2\f\u0010\n\u001a\b\u0012\u0004\u0012\u00020\u00070\u00062\f\u0010\u000b\u001a\b\u0012\u0004\u0012\u00020\u00070\u00062\u000e\u0010\f\u001a\n\u0012\u0004\u0012\u00020\u0007\u0018\u00010\u0006H\u0016\u00a8\u0006\r"}, d2={"Lorg/jetbrains/kotlinx/dl/api/core/loss/BinaryCrossEntropy;", "Lorg/jetbrains/kotlinx/dl/api/core/loss/LossFunction;", "reductionType", "Lorg/jetbrains/kotlinx/dl/api/core/loss/ReductionType;", "(Lorg/jetbrains/kotlinx/dl/api/core/loss/ReductionType;)V", "apply", "Lorg/tensorflow/Operand;", "", "tf", "Lorg/tensorflow/op/Ops;", "yPred", "yTrue", "numberOfLosses", "tensorflow"})
public final class BinaryCrossEntropy
extends LossFunction {
    /** Clipping bound for predictions and additive guard inside both logarithms. */
    private static final float EPSILON = 1.0E-7f;

    /**
     * Creates the loss with an explicit reduction strategy.
     *
     * @param reductionType how element-wise losses are reduced; must not be {@code null}
     * @throws NullPointerException (raised by Kotlin's {@code Intrinsics}) if {@code reductionType} is {@code null}
     */
    public BinaryCrossEntropy(@NotNull ReductionType reductionType) {
        // The null check must run inside the super(...) argument: Java requires super(...) to be
        // the first statement of a constructor.
        super(BinaryCrossEntropy.requireReductionType(reductionType));
    }

    /**
     * Kotlin default-argument bridge. When bit 0 of {@code n} is set, {@code reductionType} is
     * ignored and {@link ReductionType#SUM_OVER_BATCH_SIZE} is used instead.
     *
     * @param reductionType explicit reduction type, used only when bit 0 of {@code n} is clear
     * @param n bit mask of parameters that should take their default value
     * @param defaultConstructorMarker unused marker that disambiguates this synthetic overload
     */
    public /* synthetic */ BinaryCrossEntropy(ReductionType reductionType, int n, DefaultConstructorMarker defaultConstructorMarker) {
        this((n & 1) != 0 ? ReductionType.SUM_OVER_BATCH_SIZE : reductionType);
    }

    /** Creates the loss with the default {@link ReductionType#SUM_OVER_BATCH_SIZE} reduction. */
    public BinaryCrossEntropy() {
        this(null, 1, null);
    }

    /**
     * Builds the binary cross-entropy graph and reduces it via {@link LossesKt#meanOfLosses}.
     *
     * @param tf TensorFlow op builder used to create the graph nodes
     * @param yPred predicted values; clipped to {@code [1e-7, 1 - 1e-7]} before the logarithms
     * @param yTrue target values
     * @param numberOfLosses optional operand forwarded unchanged to {@link LossesKt#meanOfLosses}
     * @return the reduced loss operand
     */
    @Override
    @NotNull
    public Operand<Float> apply(@NotNull Ops tf, @NotNull Operand<Float> yPred, @NotNull Operand<Float> yTrue, @Nullable Operand<Float> numberOfLosses) {
        Intrinsics.checkNotNullParameter((Object)tf, (String)"tf");
        Intrinsics.checkNotNullParameter(yPred, (String)"yPred");
        Intrinsics.checkNotNullParameter(yTrue, (String)"yTrue");

        // Constants are created in the same order as before so graph node naming is unchanged.
        Operand<Float> one = tf.constant(1.0f);
        Operand<Float> minusOne = tf.constant(-1.0f);
        Operand<Float> epsilon = tf.constant(EPSILON);

        // Clip predictions into [eps, 1 - eps] so neither logarithm sees 0.
        Operand<Float> clippedYPred = tf.math.minimum(tf.math.sub(one, epsilon), tf.math.maximum(epsilon, yPred));

        // yTrue * log(p + eps)
        Operand<Float> positiveTerm = tf.math.mul(yTrue, tf.math.log(tf.math.add(clippedYPred, epsilon)));
        // log(1 - p + eps) * (1 - yTrue)
        Operand<Float> negativeTerm = tf.math.mul(
                tf.math.log(tf.math.add(tf.math.sub(one, clippedYPred), epsilon)),
                tf.math.sub(one, yTrue));

        Operand<Float> loss = tf.math.mul(minusOne, tf.math.add(positiveTerm, negativeTerm));
        return LossesKt.meanOfLosses(tf, this.getReductionType(), loss, numberOfLosses);
    }

    /** Applies Kotlin's non-null parameter check and passes the value through for use in {@code super(...)}. */
    private static ReductionType requireReductionType(ReductionType reductionType) {
        Intrinsics.checkNotNullParameter((Object)((Object)reductionType), (String)"reductionType");
        return reductionType;
    }
}

