/*
 * Decompiled with CFR 0.152.
 */
package org.jetbrains.kotlinx.dl.api.core.metric;

import kotlin.Metadata;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.kotlinx.dl.api.core.loss.ReductionType;
import org.jetbrains.kotlinx.dl.api.core.metric.Metric;
import org.jetbrains.kotlinx.dl.api.core.util.DtypeConversionUtilKt;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.math.ArgMax;
import org.tensorflow.op.math.Equal;
import org.tensorflow.op.math.Mean;

@Metadata(mv={1, 8, 0}, k=1, xi=48, d1={"\u0000\u001e\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0010\u0007\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0004\u0018\u00002\u00020\u0001B\u0005\u00a2\u0006\u0002\u0010\u0002JB\u0010\u0003\u001a\b\u0012\u0004\u0012\u00020\u00050\u00042\u0006\u0010\u0006\u001a\u00020\u00072\f\u0010\b\u001a\b\u0012\u0004\u0012\u00020\u00050\u00042\f\u0010\t\u001a\b\u0012\u0004\u0012\u00020\u00050\u00042\u000e\u0010\n\u001a\n\u0012\u0004\u0012\u00020\u0005\u0018\u00010\u0004H\u0016\u00a8\u0006\u000b"}, d2={"Lorg/jetbrains/kotlinx/dl/api/core/metric/Accuracy;", "Lorg/jetbrains/kotlinx/dl/api/core/metric/Metric;", "()V", "apply", "Lorg/tensorflow/Operand;", "", "tf", "Lorg/tensorflow/op/Ops;", "yPred", "yTrue", "numberOfLabels", "tensorflow"})
/*
 * Classification accuracy metric.
 *
 * For each row it takes the arg-max index of the predictions and of the targets, marks whether
 * they agree, and averages those matches. The base Metric is constructed with
 * ReductionType.SUM_OVER_BATCH_SIZE.
 */
public final class Accuracy
extends Metric {

    /** Creates the metric using the {@code SUM_OVER_BATCH_SIZE} reduction. */
    public Accuracy() {
        super(ReductionType.SUM_OVER_BATCH_SIZE);
    }

    /**
     * Builds the graph ops that compute the fraction of rows where the arg-max index of
     * {@code yPred} equals the arg-max index of {@code yTrue}.
     *
     * <p>NOTE(review): assumes {@code yPred}/{@code yTrue} are (batch, classes), i.e. the class axis
     * is axis 1 — TODO confirm against callers.
     *
     * @param tf             op builder used to create the graph nodes
     * @param yPred          predicted scores
     * @param yTrue          target values
     * @param numberOfLabels unused by this metric
     * @return mean over axis 0 of the per-row match indicator, cast to
     *         {@code DtypeConversionUtilKt.getDType()}
     */
    @Override
    @NotNull
    public Operand<Float> apply(@NotNull Ops tf, @NotNull Operand<Float> yPred, @NotNull Operand<Float> yTrue, @Nullable Operand<Float> numberOfLabels) {
        // Kotlin-generated non-null guards for the non-nullable parameters.
        Intrinsics.checkNotNullParameter((Object)tf, (String)"tf");
        Intrinsics.checkNotNullParameter(yPred, (String)"yPred");
        Intrinsics.checkNotNullParameter(yTrue, (String)"yTrue");

        // Index of the highest predicted score in each row (reduction along axis 1).
        // Ops are created in the same order as before so graph node naming is unchanged.
        Operand predAxis = (Operand)tf.constant(1);
        ArgMax predIndex = tf.math.argMax(yPred, predAxis);
        Intrinsics.checkNotNullExpressionValue((Object)predIndex, (String)"tf.math.argMax(yPred, tf.constant(1))");

        // Index of the highest target value in each row (reduction along axis 1).
        Operand trueAxis = (Operand)tf.constant(1);
        ArgMax trueIndex = tf.math.argMax(yTrue, trueAxis);
        Intrinsics.checkNotNullExpressionValue((Object)trueIndex, (String)"tf.math.argMax(yTrue, tf.constant(1))");

        // Element-wise equality of the two index tensors, converted to the project's numeric dtype.
        Equal hits = tf.math.equal((Operand)predIndex, (Operand)trueIndex, new Equal.Options[0]);
        Cast hitsAsNumber = tf.dtypes.cast((Operand)hits, DtypeConversionUtilKt.getDType(), new Cast.Options[0]);

        // Averaging the 0/1 matches along axis 0 yields the accuracy.
        Operand batchAxis = (Operand)tf.constant(0);
        Mean accuracy = tf.math.mean((Operand)hitsAsNumber, batchAxis, new Mean.Options[0]);
        Intrinsics.checkNotNullExpressionValue((Object)accuracy, (String)"tf.math.mean(tf.dtypes.c\u2026DType()), tf.constant(0))");
        return (Operand)accuracy;
    }
}

