/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class GravesLSTMParamInitializer
implements ParamInitializer {
    private static final GravesLSTMParamInitializer INSTANCE = new GravesLSTMParamInitializer();
    public static final String RECURRENT_WEIGHT_KEY = "RW";
    public static final String BIAS_KEY = "b";
    public static final String INPUT_WEIGHT_KEY = "W";

    public static GravesLSTMParamInitializer getInstance() {
        return INSTANCE;
    }

    @Override
    public long numParams(NeuralNetConfiguration conf) {
        return this.numParams(conf.getLayer());
    }

    @Override
    public long numParams(Layer l) {
        GravesLSTM layerConf = (GravesLSTM)l;
        long nL = layerConf.getNOut();
        long nLast = layerConf.getNIn();
        long nParams = nLast * (4L * nL) + nL * (4L * nL + 3L) + 4L * nL;
        return nParams;
    }

    @Override
    public List<String> paramKeys(Layer layer) {
        return Arrays.asList(INPUT_WEIGHT_KEY, RECURRENT_WEIGHT_KEY, BIAS_KEY);
    }

    @Override
    public List<String> weightKeys(Layer layer) {
        return Arrays.asList(INPUT_WEIGHT_KEY, RECURRENT_WEIGHT_KEY);
    }

    @Override
    public List<String> biasKeys(Layer layer) {
        return Collections.singletonList(BIAS_KEY);
    }

    @Override
    public boolean isWeightParam(Layer layer, String key) {
        return RECURRENT_WEIGHT_KEY.equals(key) || INPUT_WEIGHT_KEY.equals(key);
    }

    @Override
    public boolean isBiasParam(Layer layer, String key) {
        return BIAS_KEY.equals(key);
    }

    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap());
        GravesLSTM layerConf = (GravesLSTM)conf.getLayer();
        double forgetGateInit = layerConf.getForgetGateBiasInit();
        long nL = layerConf.getNOut();
        long nLast = layerConf.getNIn();
        conf.addVariable(INPUT_WEIGHT_KEY);
        conf.addVariable(RECURRENT_WEIGHT_KEY);
        conf.addVariable(BIAS_KEY);
        long length = this.numParams(conf);
        if (paramsView.length() != length) {
            throw new IllegalStateException("Expected params view of length " + length + ", got length " + paramsView.length());
        }
        long nParamsIn = nLast * (4L * nL);
        long nParamsRecurrent = nL * (4L * nL + 3L);
        long nBias = 4L * nL;
        INDArray paramsViewReshape = paramsView.reshape(new long[]{paramsView.length()});
        INDArray inputWeightView = paramsViewReshape.get(new INDArrayIndex[]{NDArrayIndex.interval((long)0L, (long)nParamsIn)});
        INDArray recurrentWeightView = paramsViewReshape.get(new INDArrayIndex[]{NDArrayIndex.interval((long)nParamsIn, (long)(nParamsIn + nParamsRecurrent))});
        INDArray biasView = paramsViewReshape.get(new INDArrayIndex[]{NDArrayIndex.interval((long)(nParamsIn + nParamsRecurrent), (long)(nParamsIn + nParamsRecurrent + nBias))});
        if (initializeParams) {
            long fanIn = nL;
            long fanOut = nLast + nL;
            long[] inputWShape = new long[]{nLast, 4L * nL};
            long[] recurrentWShape = new long[]{nL, 4L * nL + 3L};
            IWeightInit rwInit = layerConf.getWeightInitFnRecurrent() != null ? layerConf.getWeightInitFnRecurrent() : layerConf.getWeightInitFn();
            params.put(INPUT_WEIGHT_KEY, layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape, 'f', inputWeightView));
            params.put(RECURRENT_WEIGHT_KEY, rwInit.init(fanIn, fanOut, recurrentWShape, 'f', recurrentWeightView));
            biasView.put(new INDArrayIndex[]{NDArrayIndex.interval((long)nL, (long)(2L * nL))}, Nd4j.valueArrayOf((long[])new long[]{1L, nL}, (double)forgetGateInit));
            params.put(BIAS_KEY, biasView);
        } else {
            params.put(INPUT_WEIGHT_KEY, WeightInitUtil.reshapeWeights(new long[]{nLast, 4L * nL}, inputWeightView));
            params.put(RECURRENT_WEIGHT_KEY, WeightInitUtil.reshapeWeights(new long[]{nL, 4L * nL + 3L}, recurrentWeightView));
            params.put(BIAS_KEY, biasView);
        }
        return params;
    }

    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        GravesLSTM layerConf = (GravesLSTM)conf.getLayer();
        long nL = layerConf.getNOut();
        long nLast = layerConf.getNIn();
        long length = this.numParams(conf);
        if (gradientView.length() != length) {
            throw new IllegalStateException("Expected gradient view of length " + length + ", got length " + gradientView.length());
        }
        INDArray gradientViewReshape = gradientView.reshape(new long[]{gradientView.length()});
        long nParamsIn = nLast * (4L * nL);
        long nParamsRecurrent = nL * (4L * nL + 3L);
        long nBias = 4L * nL;
        INDArray inputWeightGradView = gradientViewReshape.get(new INDArrayIndex[]{NDArrayIndex.interval((long)0L, (long)nParamsIn)}).reshape('f', new long[]{nLast, 4L * nL});
        INDArray recurrentWeightGradView = gradientViewReshape.get(new INDArrayIndex[]{NDArrayIndex.interval((long)nParamsIn, (long)(nParamsIn + nParamsRecurrent))}).reshape('f', new long[]{nL, 4L * nL + 3L});
        INDArray biasGradView = gradientViewReshape.get(new INDArrayIndex[]{NDArrayIndex.interval((long)(nParamsIn + nParamsRecurrent), (long)(nParamsIn + nParamsRecurrent + nBias))});
        LinkedHashMap<String, INDArray> out = new LinkedHashMap<String, INDArray>();
        out.put(INPUT_WEIGHT_KEY, inputWeightGradView);
        out.put(RECURRENT_WEIGHT_KEY, recurrentWeightGradView);
        out.put(BIAS_KEY, biasGradView);
        return out;
    }
}

