/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.Solver;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;

/**
 * Common base for output layers whose score and gradients are derived from the configured loss
 * function ({@code getLossFn()}) and activation function ({@code getActivationFn()}), applied to the
 * layer's pre-output and to the labels supplied via {@link #setLabels(INDArray)}.
 *
 * <p>Subclasses supply {@link #getLabels2d(LayerWorkspaceMgr, ArrayType)} to present the labels as a
 * 2d array, and may override {@link #preOutput2d(boolean, LayerWorkspaceMgr)}.
 *
 * <p>The various {@code fit(...)} entry points that take data directly are not supported here and
 * throw {@link UnsupportedOperationException}.
 *
 * @param <LayerConfT> the layer configuration type
 */
public abstract class BaseOutputLayer<LayerConfT extends org.deeplearning4j.nn.conf.layers.BaseOutputLayer>
        extends BaseLayer<LayerConfT>
        implements Serializable, IOutputLayer {

    /** Labels for the current minibatch, set via {@link #setLabels(INDArray)}; reset by {@link #clear()}. */
    protected INDArray labels;
    // NOTE(review): only ever reset to null in this class; not otherwise read here.
    private transient Solver solver;
    /** Regularization term from the most recent computeScore call; reused by computeGradientAndScore. */
    private double fullNetRegTerm;
    protected INDArray inputMaskArray;
    protected MaskState inputMaskArrayState;

    public BaseOutputLayer(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
    }

    /**
     * Computes the loss-based score for the current input and labels.
     *
     * <p>The raw loss is divided by the input minibatch size when the configuration is in minibatch
     * mode, and {@code fullNetRegTerm} is then added. The result is cached in {@code score}, and the
     * regularization term is remembered for {@link #computeGradientAndScore(LayerWorkspaceMgr)}.
     *
     * @param fullNetRegTerm regularization term added to the loss
     * @param training       whether the forward pass is in training mode
     * @param workspaceMgr   workspace manager used for intermediate arrays
     * @return the computed score
     * @throws IllegalStateException if input or labels have not been set
     */
    @Override
    public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) {
        if (input == null || labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());
        }
        this.fullNetRegTerm = fullNetRegTerm;
        INDArray preOut = preOutput2d(training, workspaceMgr);
        ILossFunction lossFunction = layerConf().getLossFn();
        INDArray labels2d = getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM);
        double newScore = lossFunction.computeScore(labels2d, preOut, layerConf().getActivationFn(), maskArray, false);
        if (conf().isMiniBatch()) {
            newScore /= getInputMiniBatchSize();
        }
        newScore += fullNetRegTerm;
        this.score = newScore;
        return newScore;
    }

    /** Output layers always require labels. */
    @Override
    public boolean needsLabels() {
        return true;
    }

    /**
     * Computes the per-example loss for the current input and labels.
     *
     * <p>The forward pass runs in inference mode. When {@code fullNetRegTerm} is non-zero, it is added
     * to every element of the score array.
     *
     * @param fullNetRegTerm regularization term added to each example's score
     * @param workspaceMgr   workspace manager; the result is leveraged to the ACTIVATIONS workspace
     * @return the per-example score array
     * @throws IllegalStateException if input or labels have not been set
     */
    @Override
    public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr workspaceMgr) {
        if (input == null || labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());
        }
        INDArray preOut = preOutput2d(false, workspaceMgr);
        ILossFunction lossFunction = layerConf().getLossFn();
        INDArray labels2d = getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM);
        INDArray scoreArray = lossFunction.computeScoreArray(labels2d, preOut, layerConf().getActivationFn(), maskArray);
        if (fullNetRegTerm != 0.0) {
            scoreArray.addi(fullNetRegTerm);
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, scoreArray);
    }

    /**
     * Computes the gradient and score for the current input and labels.
     *
     * <p>This is a no-op when input or labels are missing. It reuses the regularization term from the
     * most recent {@link #computeScore} call.
     *
     * @param workspaceMgr workspace manager used for intermediate arrays
     */
    @Override
    public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
        if (input == null || labels == null) {
            return;
        }
        INDArray preOut = preOutput2d(true, workspaceMgr);
        Pair<Gradient, INDArray> gradAndDelta = getGradientsAndDelta(preOut, workspaceMgr);
        this.gradient = gradAndDelta.getFirst();
        // computeScore performs its own forward pass and also caches the score field.
        this.score = computeScore(fullNetRegTerm, true, workspaceMgr);
    }

    /** Not supported for output layers; the score comes from the loss function instead. */
    @Override
    protected void setScoreWithZ(INDArray z) {
        throw new RuntimeException("Not supported - " + layerId());
    }

    /** Returns the cached gradient together with the cached score. */
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<>(gradient(), score());
    }

    /**
     * Backpropagates from the loss. The incoming {@code epsilon} is not used, because the error signal
     * is computed directly from the loss function and the labels.
     *
     * @param epsilon      ignored by this implementation
     * @param workspaceMgr workspace manager; epsilonNext is allocated in ACTIVATION_GRAD
     * @return the parameter gradient and the epsilon for the layer below, computed as
     *         {@code (W * delta^T)^T} with dropout backprop applied if present
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(true);
        Pair<Gradient, INDArray> gradAndDelta = getGradientsAndDelta(preOutput2d(true, workspaceMgr), workspaceMgr);
        INDArray delta = gradAndDelta.getSecond();
        INDArray w = getParamWithNoise("W", true, workspaceMgr);
        // Allocated in 'f' order as [w.size(0), delta.size(0)], then transposed after the multiply.
        INDArray epsilonNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, delta.dataType(),
                new long[]{w.size(0), delta.size(0)}, 'f');
        epsilonNext = w.mmuli(delta.transpose(), epsilonNext).transpose();
        epsilonNext = backpropDropOutIfPresent(epsilonNext);
        return new Pair<>(gradAndDelta.getFirst(), epsilonNext);
    }

    /** Returns the most recently computed gradient. */
    @Override
    public Gradient gradient() {
        return gradient;
    }

    /**
     * Computes the loss gradient w.r.t. the pre-output ("delta") and the parameter gradients.
     *
     * <p>The results are written into the existing gradient views:
     * <ul>
     *   <li>"W" receives {@code input^T * delta}, computed via gemm with transposeA.</li>
     *   <li>"b", when the layer has a bias, receives the sum of delta over dimension 0.</li>
     * </ul>
     */
    private Pair<Gradient, INDArray> getGradientsAndDelta(INDArray preOut, LayerWorkspaceMgr workspaceMgr) {
        ILossFunction lossFunction = layerConf().getLossFn();
        INDArray labels2d = getLabels2d(workspaceMgr, ArrayType.BP_WORKING_MEM);
        INDArray delta = lossFunction.computeGradient(labels2d, preOut, layerConf().getActivationFn(), maskArray);
        DefaultGradient layerGradient = new DefaultGradient();
        INDArray weightGradView = gradientViews.get("W");
        // Cast input to the gradient view's dtype so gemm operates on matching types.
        Nd4j.gemm(input.castTo(weightGradView.dataType()), delta, weightGradView, true, false, 1.0, 0.0);
        layerGradient.gradientForVariable().put("W", weightGradView);
        if (hasBias()) {
            INDArray biasGradView = gradientViews.get("b");
            delta.sum(biasGradView, new int[]{0});
            layerGradient.gradientForVariable().put("b", biasGradView);
        }
        delta = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, delta);
        return new Pair<>(layerGradient, delta);
    }

    /** Sets the given input and then runs the forward pass. */
    @Override
    public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) {
        setInput(input, workspaceMgr);
        return activate(training, workspaceMgr);
    }

    /** Computes the F1 score on the features and labels of {@code data}. */
    @Override
    public double f1Score(DataSet data) {
        return f1Score(data.getFeatures(), data.getLabels());
    }

    /**
     * Computes the F1 score of this layer's inference-mode output on {@code examples} against
     * {@code labels}, using an {@link Evaluation}.
     */
    @Override
    public double f1Score(INDArray examples, INDArray labels) {
        Evaluation eval = new Evaluation();
        eval.eval(labels, activate(examples, false, LayerWorkspaceMgr.noWorkspacesImmutable()));
        return eval.f1();
    }

    /**
     * Returns the size of dimension 1 of the current labels.
     *
     * @throws IllegalStateException if labels have not been set
     */
    @Override
    public int numLabels() {
        if (labels == null) {
            throw new IllegalStateException("Cannot determine number of labels: labels have not been set " + layerId());
        }
        return (int) labels.size(1);
    }

    /**
     * Calls {@link #fit(DataSet)} for every element of {@code iter}. Note that fit(DataSet) is not
     * supported by this class, so this throws for any non-empty iterator unless a subclass overrides it.
     */
    @Override
    public void fit(DataSetIterator iter) {
        while (iter.hasNext()) {
            DataSet next = iter.next();
            fit(next);
        }
    }

    /**
     * Returns the argmax along dimension 1 of the inference-mode output for each example.
     *
     * @throws IllegalStateException if the output is not rank 2
     */
    @Override
    public int[] predict(INDArray input) {
        INDArray output = activate(input, false, LayerWorkspaceMgr.noWorkspacesImmutable());
        Preconditions.checkState(output.rank() == 2,
                "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
        return output.argMax(1).toIntVector();
    }

    /**
     * Predicts each example's class index and maps it to its label name via
     * {@code dataSet.getLabelName}. The returned list is in example order.
     */
    @Override
    public List<String> predict(DataSet dataSet) {
        int[] predictedIndices = predict(dataSet.getFeatures());
        List<String> ret = new ArrayList<>(predictedIndices.length);
        for (int classIdx : predictedIndices) {
            // Append in example order. The predicted class index is a label lookup key,
            // not a list position (using it as a position threw IndexOutOfBoundsException).
            ret.add(dataSet.getLabelName(classIdx));
        }
        return ret;
    }

    @Override
    public void fit(INDArray input, INDArray labels) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(DataSet data) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(INDArray examples, int[] labels) {
        throw new UnsupportedOperationException("Not supported");
    }

    /** Clears superclass state plus labels, solver, input mask state and the cached regularization term. */
    @Override
    public void clear() {
        super.clear();
        labels = null;
        solver = null;
        inputMaskArrayState = null;
        inputMaskArray = null;
        fullNetRegTerm = 0.0;
    }

    @Override
    public void fit(INDArray data, LayerWorkspaceMgr workspaceMgr) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public INDArray getLabels() {
        return labels;
    }

    @Override
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    /** Returns the pre-output used for loss computations; by default this is just preOutput. */
    protected INDArray preOutput2d(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return preOutput(training, workspaceMgr);
    }

    /**
     * Multiplies {@code to} in place by the mask array.
     *
     * <p>A column-vector (or scalar) mask is applied per row. A mask with the same shape as
     * {@code to} is applied elementwise. Any other mask shape is rejected.
     *
     * @throws IllegalStateException if the mask shape is neither of the supported forms
     */
    @Override
    protected void applyMask(INDArray to) {
        if (maskArray.isColumnVectorOrScalar()) {
            to.muliColumnVector(maskArray.castTo(to.dataType()));
        } else if (Arrays.equals(to.shape(), maskArray.shape())) {
            to.muli(maskArray.castTo(to.dataType()));
        } else {
            throw new IllegalStateException("Invalid mask array: per-example masking should be a column vector, per output masking arrays should be the same shape as the output/labels arrays. Mask shape: " + Arrays.toString(maskArray.shape()) + ", output shape: " + Arrays.toString(to.shape()) + layerId());
        }
    }

    /**
     * Returns the labels as a 2d array in the given workspace array type. Implemented by subclasses.
     */
    protected abstract INDArray getLabels2d(LayerWorkspaceMgr var1, ArrayType var2);

    /** Output layers are never pretrain layers. */
    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /** Delegates to the layer configuration's bias setting. */
    @Override
    public boolean hasBias() {
        return layerConf().hasBias();
    }
}

