/*
 * Decompiled with CFR 0.152.
 */
package org.jetbrains.kotlinx.dl.api.core.loss;

import kotlin.Metadata;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SourceDebugExtension;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.kotlinx.dl.api.core.loss.ReductionType;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Range;
import org.tensorflow.op.core.ReduceSum;
import org.tensorflow.op.math.DivNoNan;
import org.tensorflow.op.math.Mean;

@Metadata(mv={1, 8, 0}, k=2, xi=48, d1={"\u0000\"\n\u0000\n\u0002\u0018\u0002\n\u0002\u0010\b\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0007\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0005\u001a$\u0010\u0000\u001a\b\u0012\u0004\u0012\u00020\u00020\u00012\u0006\u0010\u0003\u001a\u00020\u00042\f\u0010\u0005\u001a\b\u0012\u0004\u0012\u00020\u00060\u0001H\u0000\u001a<\u0010\u0007\u001a\b\u0012\u0004\u0012\u00020\u00060\u00012\u0006\u0010\u0003\u001a\u00020\u00042\u0006\u0010\b\u001a\u00020\t2\f\u0010\n\u001a\b\u0012\u0004\u0012\u00020\u00060\u00012\u000e\u0010\u000b\u001a\n\u0012\u0004\u0012\u00020\u0006\u0018\u00010\u0001H\u0000\u001a2\u0010\f\u001a\b\u0012\u0004\u0012\u00020\u00060\u00012\u0006\u0010\u0003\u001a\u00020\u00042\f\u0010\n\u001a\b\u0012\u0004\u0012\u00020\u00060\u00012\f\u0010\r\u001a\b\u0012\u0004\u0012\u00020\u00060\u0001H\u0000\u00a8\u0006\u000e"}, d2={"allAxes", "Lorg/tensorflow/Operand;", "", "tf", "Lorg/tensorflow/op/Ops;", "op", "", "meanOfLosses", "reductionType", "Lorg/jetbrains/kotlinx/dl/api/core/loss/ReductionType;", "loss", "numberOfLosses", "safeMean", "numElements", "tensorflow"})
@SourceDebugExtension(value={"SMAP\nLosses.kt\nKotlin\n*S Kotlin\n*F\n+ 1 Losses.kt\norg/jetbrains/kotlinx/dl/api/core/loss/LossesKt\n+ 2 fake.kt\nkotlin/jvm/internal/FakeKt\n*L\n1#1,437:1\n1#2:438\n*E\n"})
public final class LossesKt {
    /**
     * Reduces a per-element loss tensor to a scalar-valued operand according to {@code reductionType}.
     *
     * <p>The default path averages {@code loss} over its last axis (without keeping that dimension).
     * It then sums the result over all remaining axes.
     *
     * <p>For {@link ReductionType#SUM_OVER_BATCH_SIZE}, the result is replaced by
     * {@link #safeMean(Ops, Operand, Operand)} applied to the raw {@code loss}. That is the sum of all
     * elements of {@code loss} divided by {@code numberOfLosses}. The mean/sum ops from the default
     * path are still added to the graph in this case; they are simply not returned.
     *
     * @param tf             TensorFlow ops factory used to build the graph nodes
     * @param reductionType  how the per-element losses are reduced
     * @param loss           per-element losses; must have at least one axis, since the mean is taken over axis {@code -1}
     * @param numberOfLosses divisor for {@code SUM_OVER_BATCH_SIZE}; may be {@code null} for other
     *                       reduction types (presumably the element count of {@code loss} — confirm at call sites)
     * @return the reduced loss operand
     * @throws IllegalStateException if {@code reductionType} is {@code SUM_OVER_BATCH_SIZE} and
     *                               {@code numberOfLosses} is {@code null}
     */
    @NotNull
    public static final Operand<Float> meanOfLosses(@NotNull Ops tf, @NotNull ReductionType reductionType, @NotNull Operand<Float> loss, @Nullable Operand<Float> numberOfLosses) {
        // Kotlin non-null parameter contract, kept so null arguments fail exactly as before.
        Intrinsics.checkNotNullParameter(tf, "tf");
        Intrinsics.checkNotNullParameter(reductionType, "reductionType");
        Intrinsics.checkNotNullParameter(loss, "loss");

        // Average over the last axis, then collapse every remaining axis by summation.
        Mean<Float> meanLoss = tf.math.mean(loss, tf.constant(-1), Mean.keepDims(false));
        Operand<Float> totalLoss = tf.reduceSum(meanLoss, allAxes(tf, meanLoss), ReduceSum.keepDims(false));

        if (reductionType == ReductionType.SUM_OVER_BATCH_SIZE) {
            if (numberOfLosses == null) {
                throw new IllegalStateException("Operand numberOfLosses must be not null.");
            }
            // Intentionally uses the raw per-element loss, not meanLoss.
            totalLoss = safeMean(tf, loss, numberOfLosses);
        }
        return totalLoss;
    }

    /**
     * Sums {@code loss} over all of its axes and divides the total by {@code numElements}.
     *
     * <p>The division uses {@code tf.math.divNoNan}. According to the TensorFlow op documentation,
     * it yields 0 instead of NaN/Inf when the denominator is 0.
     *
     * @param tf          TensorFlow ops factory used to build the graph nodes
     * @param loss        tensor of losses to be summed
     * @param numElements divisor applied to the summed loss
     * @return {@code sum(loss) / numElements}, computed with {@code divNoNan}
     */
    @NotNull
    public static final Operand<Float> safeMean(@NotNull Ops tf, @NotNull Operand<Float> loss, @NotNull Operand<Float> numElements) {
        Intrinsics.checkNotNullParameter(tf, "tf");
        Intrinsics.checkNotNullParameter(loss, "loss");
        Intrinsics.checkNotNullParameter(numElements, "numElements");

        Operand<Float> totalLoss = tf.reduceSum(loss, allAxes(tf, loss));
        return tf.math.divNoNan(totalLoss, numElements);
    }

    /**
     * Builds an integer operand listing every axis of {@code op}: {@code [0, 1, ..., rank - 1]}.
     *
     * <p>When the static rank is known, the axes are emitted as a graph constant. When it is unknown
     * ({@code numDimensions() == -1}), they are computed at run time with
     * {@code range(0, rank(op), 1)}.
     *
     * @param tf TensorFlow ops factory used to build the graph nodes
     * @param op operand whose axes are enumerated
     * @return a 1-D integer operand containing all axis indices of {@code op}
     */
    @NotNull
    public static final Operand<Integer> allAxes(@NotNull Ops tf, @NotNull Operand<Float> op) {
        Intrinsics.checkNotNullParameter(tf, "tf");
        Intrinsics.checkNotNullParameter(op, "op");

        int rank = op.asOutput().shape().numDimensions();
        if (rank == -1) {
            // Static rank unknown: defer enumeration of the axes to graph execution.
            return tf.range(tf.constant(0), tf.rank(op), tf.constant(1));
        }

        int[] axes = new int[rank];
        for (int axis = 0; axis < rank; ++axis) {
            axes[axis] = axis;
        }
        return tf.constant(axes);
    }
}

