/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.lossfunctions;

import java.util.HashMap;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossUtil;

/**
 * Base class for loss functions whose per-example score is defined as a SameDiff graph.
 *
 * <p>Subclasses implement {@link #defineLoss(SameDiff, SDVariable, SDVariable)}. The graph is built
 * lazily on the first score/gradient call, using the data type of that call's {@code preOutput}.
 * Because {@link #sd} and {@link #scorePerExampleVariable} are {@code transient}, they are rebuilt
 * the same lazy way after deserialization.
 */
public abstract class SameDiffLoss implements ILossFunction {
    /** Placeholder name for the (activated) layer output fed into the loss graph. */
    private static final String LAYER_INPUT = "layerInput";
    /** Placeholder name for the labels fed into the loss graph. */
    private static final String LABELS = "labels";

    /** Loss graph; {@code null} until {@link #createSameDiffInstance(DataType)} completes. */
    protected transient SameDiff sd;
    /** Variable returned by {@link #defineLoss}; read as the per-example score array. */
    protected transient SDVariable scorePerExampleVariable;

    protected SameDiffLoss() {
    }

    /**
     * Defines the loss inside {@code sameDiff}.
     *
     * <p>The returned variable is used as the per-example score: {@link #computeScore} sums it and,
     * when averaging, divides by its size along dimension 0.
     *
     * @param sameDiff   the graph to define the loss in
     * @param layerInput placeholder holding the activation-function output of the layer
     * @param labels     placeholder holding the labels
     * @return the per-example score variable (must not be {@code null})
     */
    public abstract SDVariable defineLoss(SameDiff sameDiff, SDVariable layerInput, SDVariable labels);

    /**
     * Builds the loss graph for {@code dataType} and prepares its gradient with respect to the layer
     * input placeholder.
     *
     * <p>The fields are published only after the whole graph has been built. If building fails, the
     * instance stays uninitialized and the next call retries, rather than being left half-built.
     *
     * @param dataType data type of both placeholders
     */
    protected void createSameDiffInstance(DataType dataType) {
        SameDiff graph = SameDiff.create();
        // NOTE(review): placeholders are declared with shape {-1} (rank 1), while computeScoreArray
        // reads size(1) of the arrays fed in, i.e. they are at least rank 2. Confirm that SameDiff
        // does not enforce placeholder rank here.
        SDVariable layerInput = graph.placeHolder(LAYER_INPUT, dataType, -1L);
        SDVariable labels = graph.placeHolder(LABELS, dataType, -1L);
        SDVariable score = this.defineLoss(graph, layerInput, labels);
        Preconditions.checkState(score != null, "defineLoss returned null in %s", this.getClass().getName());
        score.markAsLoss();
        graph.createGradFunction(LAYER_INPUT);
        this.sd = graph;
        this.scorePerExampleVariable = score;
    }

    /**
     * Computes the total (or per-example averaged) score.
     *
     * @return the sum of {@link #computeScoreArray}; when {@code average} is true, that sum is
     *         divided by the score array's size along dimension 0
     */
    @Override
    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        this.ensureInitialized(preOutput);
        INDArray scoreArr = this.computeScoreArray(labels, preOutput, activationFn, mask);
        double score = scoreArr.sumNumber().doubleValue();
        if (average) {
            score /= (double) scoreArr.size(0);
        }
        return score;
    }

    /**
     * Evaluates the loss graph on the activated output.
     *
     * <p>{@code preOutput} is duplicated before the activation is applied, so the caller's array is
     * not modified. If {@code mask} is non-null, it is applied to the score array via
     * {@link LossUtil#applyMask}.
     *
     * @throws IllegalArgumentException if {@code labels.size(1) != preOutput.size(1)}
     */
    @Override
    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        this.ensureInitialized(preOutput);
        checkLabelsMatchOutput(labels, preOutput);
        INDArray output = activationFn.getActivation(preOutput.dup(), true);
        INDArray scoreArr = this.sd.outputSingle(placeholderValues(labels, output), this.scorePerExampleVariable.name());
        if (mask != null) {
            LossUtil.applyMask(scoreArr, mask);
        }
        return scoreArr;
    }

    /**
     * Computes the gradient of the loss with respect to {@code preOutput}.
     *
     * <p>SameDiff supplies the gradient at the activation output. That gradient is then passed back
     * through {@code activationFn}, and {@code mask} is applied if it is non-null.
     *
     * @throws IllegalArgumentException if {@code labels.size(1) != preOutput.size(1)}
     */
    @Override
    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        this.ensureInitialized(preOutput);
        // Same validation as computeScoreArray, so both entry points reject mismatched labels alike.
        checkLabelsMatchOutput(labels, preOutput);
        INDArray output = activationFn.getActivation(preOutput.dup(), true);
        Map<String, INDArray> grads = this.sd.calculateGradients(placeholderValues(labels, output), LAYER_INPUT);
        INDArray gradAtActivationOutput = grads.get(LAYER_INPUT);
        INDArray gradAtInput = activationFn.backprop(preOutput.dup(), gradAtActivationOutput).getFirst();
        if (mask != null) {
            LossUtil.applyMask(gradAtInput, mask);
        }
        return gradAtInput;
    }

    /**
     * Computes the score and the gradient. This is done as two separate passes: the score first,
     * then the gradient.
     *
     * @return a pair of (score, gradient with respect to {@code preOutput})
     */
    @Override
    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        double score = this.computeScore(labels, preOutput, activationFn, mask, average);
        INDArray gradient = this.computeGradient(labels, preOutput, activationFn, mask);
        Pair<Double, INDArray> gradientAndScore = new Pair<Double, INDArray>();
        gradientAndScore.setFirst(score);
        gradientAndScore.setSecond(gradient);
        return gradientAndScore;
    }

    /** @return the simple class name of the concrete loss */
    @Override
    public String name() {
        return this.getClass().getSimpleName();
    }

    /** Builds the loss graph on first use, typed after {@code preOutput}. */
    private void ensureInitialized(INDArray preOutput) {
        if (this.sd == null) {
            this.createSameDiffInstance(preOutput.dataType());
        }
    }

    /** Rejects labels whose column count differs from the layer's number of outputs. */
    private static void checkLabelsMatchOutput(INDArray labels, INDArray preOutput) {
        Preconditions.checkArgument(labels.size(1) == preOutput.size(1), "Labels array numColumns (size(1) = %s) does not match output layer number of outputs (nOut = %s)", labels.size(1), preOutput.size(1));
    }

    /** Builds the placeholder feed map for one evaluation of the loss graph. */
    private static Map<String, INDArray> placeholderValues(INDArray labels, INDArray activatedOutput) {
        Map<String, INDArray> feed = new HashMap<String, INDArray>();
        feed.put(LABELS, labels);
        feed.put(LAYER_INPUT, activatedOutput);
        return feed;
    }
}

