package ai.djl.nn.transformer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.training.loss.Loss;

/**
 * Loss for the BERT next-sentence-prediction (NSP) task, registered under the name {@code
 * "BertNSLoss"}.
 *
 * <p>The loss is the negative batch mean of {@code sum(oneHot(label, 2) * prediction, axis=1)}.
 * NOTE(review): this is the cross-entropy only if the prediction array holds log-probabilities
 * (e.g. a log-softmax output) of shape (batch, 2) — confirm against the block producing it.
 */
public class BertNextSentenceLoss extends Loss {

    /** Number of classes in the (binary) next-sentence-prediction task. */
    private static final int NUM_CLASSES = 2;

    /** Position of the NSP label array inside the label {@link NDList}. */
    private final int labelIdx;

    /** Position of the NSP prediction array inside the prediction {@link NDList}. */
    private final int nextSentencePredictionIdx;

    /**
     * Creates a next-sentence-prediction loss.
     *
     * @param labelIdx index of the NSP label array within the label {@link NDList}
     * @param nextSentencePredictionIdx index of the NSP prediction array within the prediction
     *     {@link NDList}
     */
    public BertNextSentenceLoss(int labelIdx, int nextSentencePredictionIdx) {
        super("BertNSLoss");
        this.labelIdx = labelIdx;
        this.nextSentencePredictionIdx = nextSentencePredictionIdx;
    }

    /**
     * Computes the NSP loss for one batch.
     *
     * <p>Labels are cast to FLOAT32 and one-hot encoded over {@link #NUM_CLASSES} classes. They
     * are multiplied element-wise with the prediction array and summed over axis 1. The result is
     * negated and averaged.
     *
     * @param labels list containing the NSP labels at {@code labelIdx}
     * @param predictions list containing the NSP predictions at {@code nextSentencePredictionIdx}
     * @return a scalar array holding the mean loss, returned from the scoped manager via {@code
     *     ret}
     */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        try (NDManager scope = NDManager.subManagerOf(labels)) {
            // Attach the inputs to the scoped manager. Intermediates derived from them below are
            // released when the scope closes.
            scope.tempAttachAll(labels, predictions);
            NDArray oneHotLabels =
                    labels.get(labelIdx).toType(DataType.FLOAT32, false).oneHot(NUM_CLASSES);
            NDArray nspPredictions = predictions.get(nextSentencePredictionIdx);
            NDArray loss = oneHotLabels.mul(nspPredictions).sum(new int[] {1}).mul(-1f).mean();
            // ret(...) hands the result back so it outlives the scoped manager.
            return scope.ret(loss);
        }
    }

    /**
     * Computes the NSP accuracy for one batch.
     *
     * <p>The predicted class is the arg-max over axis 1 of the prediction array, cast to INT32.
     * The result is the number of positions where it equals the label, divided by the label's
     * total element count. NOTE(review): an empty label array makes this 0/0 (presumably NaN).
     * The comparison also assumes the label dtype is compatible with INT32 in {@code eq}. Confirm
     * both with callers.
     *
     * @param labels list containing the NSP labels at {@code labelIdx}
     * @param predictions list containing the NSP predictions at {@code nextSentencePredictionIdx}
     * @return a scalar FLOAT32 array holding the fraction of correct predictions
     */
    public NDArray accuracy(NDList labels, NDList predictions) {
        try (NDManager scope = NDManager.subManagerOf(labels)) {
            scope.tempAttachAll(labels, predictions);
            NDArray label = labels.get(labelIdx);
            NDArray predictedClass =
                    predictions
                            .get(nextSentencePredictionIdx)
                            .argMax(1)
                            .toType(DataType.INT32, false);
            NDArray correctCount =
                    label.eq(predictedClass).sum().toType(DataType.FLOAT32, false);
            return scope.ret(correctCount.div(label.getShape().size()));
        }
    }
}
