/*
 * Decompiled with CFR 0.152.
 */
package org.jetbrains.kotlinx.dl.api.core.loss;

import kotlin.Metadata;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.kotlinx.dl.api.core.loss.LossFunction;
import org.jetbrains.kotlinx.dl.api.core.loss.LossesKt;
import org.jetbrains.kotlinx.dl.api.core.loss.ReductionType;
import org.jetbrains.kotlinx.dl.api.core.util.DtypeConversionUtilKt;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.math.Abs;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Minimum;
import org.tensorflow.op.math.Mul;
import org.tensorflow.op.math.Sub;

/**
 * Huber loss between predictions and targets, reduced according to a {@link ReductionType}.
 *
 * <p>For each element, with {@code error = yPred - yTrue} and {@code q = min(|error|, delta)},
 * the loss is {@code 0.5 * q^2 + delta * (|error| - q)}. This equals {@code 0.5 * error^2} when
 * {@code |error| <= delta}, and {@code 0.5 * delta^2 + delta * (|error| - delta)} otherwise.
 * The per-element losses are then combined by {@code LossesKt.meanOfLosses} using this loss'
 * reduction type.
 */
@Metadata(mv={1, 8, 0}, k=1, xi=48, d1={"\u0000&\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0007\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0004\u0018\u00002\u00020\u0001B\u0019\u0012\b\b\u0002\u0010\u0002\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0004\u001a\u00020\u0005\u00a2\u0006\u0002\u0010\u0006JB\u0010\t\u001a\b\u0012\u0004\u0012\u00020\u00030\n2\u0006\u0010\u000b\u001a\u00020\f2\f\u0010\r\u001a\b\u0012\u0004\u0012\u00020\u00030\n2\f\u0010\u000e\u001a\b\u0012\u0004\u0012\u00020\u00030\n2\u000e\u0010\u000f\u001a\n\u0012\u0004\u0012\u00020\u0003\u0018\u00010\nH\u0016R\u0011\u0010\u0002\u001a\u00020\u0003\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0007\u0010\b\u00a8\u0006\u0010"}, d2={"Lorg/jetbrains/kotlinx/dl/api/core/loss/Huber;", "Lorg/jetbrains/kotlinx/dl/api/core/loss/LossFunction;", "delta", "", "reductionType", "Lorg/jetbrains/kotlinx/dl/api/core/loss/ReductionType;", "(FLorg/jetbrains/kotlinx/dl/api/core/loss/ReductionType;)V", "getDelta", "()F", "apply", "Lorg/tensorflow/Operand;", "tf", "Lorg/tensorflow/op/Ops;", "yPred", "yTrue", "numberOfLosses", "tensorflow"})
public final class Huber
extends LossFunction {
    /** Threshold where the per-element loss switches from the quadratic to the linear regime. */
    private final float delta;

    /**
     * Creates a Huber loss.
     *
     * @param delta threshold between the quadratic and linear regimes
     * @param reductionType how per-element losses are reduced; null-checked before the
     *     superclass constructor runs
     */
    public Huber(float delta, @NotNull ReductionType reductionType) {
        // The null check must happen inside the super(...) argument: Java requires the
        // explicit constructor invocation to be the first statement of the constructor.
        super(Huber.requireReductionType(reductionType));
        this.delta = delta;
    }

    /**
     * Kotlin default-arguments bridge constructor.
     *
     * <p>Bit 0 of {@code n} selects the default {@code delta = 1.0f}; bit 1 selects the default
     * {@code reductionType = ReductionType.SUM_OVER_BATCH_SIZE}.
     */
    public /* synthetic */ Huber(float f, ReductionType reductionType, int n, DefaultConstructorMarker defaultConstructorMarker) {
        // Defaults are resolved inline because this(...) must be the first statement.
        this((n & 1) != 0 ? 1.0f : f,
                (n & 2) != 0 ? ReductionType.SUM_OVER_BATCH_SIZE : reductionType);
    }

    /** Null-checks {@code reductionType} (Kotlin parameter contract) and returns it unchanged. */
    private static ReductionType requireReductionType(ReductionType reductionType) {
        Intrinsics.checkNotNullParameter((Object)((Object)reductionType), (String)"reductionType");
        return reductionType;
    }

    /**
     * Returns the threshold between the quadratic and linear regimes.
     *
     * @return the configured delta
     */
    public final float getDelta() {
        return this.delta;
    }

    /**
     * Builds the Huber loss graph for {@code yPred} against {@code yTrue}.
     *
     * @param tf TensorFlow op factory used to build the graph
     * @param yPred predicted values
     * @param yTrue target values
     * @param numberOfLosses optional operand forwarded to {@code LossesKt.meanOfLosses}
     *     (presumably the divisor used by the reduction; confirm in LossesKt)
     * @return the reduced loss operand
     */
    @Override
    @NotNull
    public Operand<Float> apply(@NotNull Ops tf, @NotNull Operand<Float> yPred, @NotNull Operand<Float> yTrue, @Nullable Operand<Float> numberOfLosses) {
        Intrinsics.checkNotNullParameter((Object)tf, (String)"tf");
        Intrinsics.checkNotNullParameter(yPred, (String)"yPred");
        Intrinsics.checkNotNullParameter(yTrue, (String)"yTrue");
        Sub error = tf.math.sub(yPred, yTrue);
        // Constants are cast to the project's working dtype so every op shares one type.
        Cast cast = tf.dtypes.cast((Operand)tf.constant(this.delta), DtypeConversionUtilKt.getDType(), new Cast.Options[0]);
        Intrinsics.checkNotNullExpressionValue((Object)cast, (String)"tf.dtypes.cast(tf.constant(delta), getDType())");
        Operand deltaConst = (Operand)cast;
        Cast cast2 = tf.dtypes.cast((Operand)tf.constant(0.5), DtypeConversionUtilKt.getDType(), new Cast.Options[0]);
        Intrinsics.checkNotNullExpressionValue((Object)cast2, (String)"tf.dtypes.cast(tf.constant(0.5), getDType())");
        Operand point5 = (Operand)cast2;
        Abs abs = tf.math.abs((Operand)error);
        Intrinsics.checkNotNullExpressionValue((Object)abs, (String)"tf.math.abs(error)");
        Operand absError = (Operand)abs;
        // q = min(|error|, delta): the part of the error penalised quadratically.
        Minimum minimum = tf.math.minimum(absError, deltaConst);
        Intrinsics.checkNotNullExpressionValue((Object)minimum, (String)"tf.math.minimum(absError, deltaConst)");
        Operand quadratic = (Operand)minimum;
        // |error| - q: the excess beyond delta, penalised linearly (zero inside the delta band).
        Sub sub = tf.math.sub(absError, quadratic);
        Intrinsics.checkNotNullExpressionValue((Object)sub, (String)"tf.math.sub(absError, quadratic)");
        Operand linear = (Operand)sub;
        Mul mul = tf.math.mul(point5, (Operand)tf.math.mul(quadratic, quadratic));
        Intrinsics.checkNotNullExpressionValue((Object)mul, (String)"tf.math.mul(point5, tf.m\u2026ul(quadratic, quadratic))");
        Operand q2Point5 = (Operand)mul;
        Mul mul2 = tf.math.mul(deltaConst, linear);
        Intrinsics.checkNotNullExpressionValue((Object)mul2, (String)"tf.math.mul(deltaConst, linear)");
        Operand deltaLinear = (Operand)mul2;
        // loss = 0.5 * q^2 + delta * (|error| - q)
        Add add = tf.math.add(q2Point5, deltaLinear);
        Intrinsics.checkNotNullExpressionValue((Object)add, (String)"tf.math.add(q2Point5, deltaLinear)");
        Operand loss = (Operand)add;
        return LossesKt.meanOfLosses(tf, this.getReductionType(), (Operand<Float>)loss, numberOfLosses);
    }

    /** Creates a Huber loss with {@code delta = 1.0f} and {@code SUM_OVER_BATCH_SIZE} reduction. */
    public Huber() {
        this(0.0f, null, 3, null);
    }
}

