/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for {@link ConvolutionLayer} (including its {@link Convolution1DLayer} subtype).
 *
 * <p>Flattened parameter layout used by {@link #init} and {@link #getGradientsFromFlattened}: when the layer has a
 * bias, the first {@code nOut} entries are the bias and the remaining entries are the weights; without a bias the
 * whole view is the weights. Weights are viewed in 'c' order with shape {@code [nOut, nIn, kernel[0], kernel[1]]},
 * or {@code [nOut, nIn, kernel[0], 1]} for {@link Convolution1DLayer}.
 */
public class ConvolutionParamInitializer implements ParamInitializer {
    private static final ConvolutionParamInitializer INSTANCE = new ConvolutionParamInitializer();
    /** Parameter key for the convolution weights. */
    public static final String WEIGHT_KEY = "W";
    /** Parameter key for the bias vector. */
    public static final String BIAS_KEY = "b";

    /**
     * Returns the shared instance of this initializer.
     *
     * @return the singleton {@code ConvolutionParamInitializer}
     */
    public static ConvolutionParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Counts the parameters of the layer held by {@code conf}.
     *
     * @param conf configuration whose layer must be a {@link ConvolutionLayer}
     * @return total number of parameters (weights plus optional bias)
     */
    @Override
    public long numParams(NeuralNetConfiguration conf) {
        return this.numParams(conf.getLayer());
    }

    /**
     * Counts the parameters of a convolution layer: {@code nIn * nOut * kernel[0] * kernel[1]} weights
     * ({@code kernel[1]} is not read for {@link Convolution1DLayer}), plus {@code nOut} biases if the layer has a bias.
     *
     * @param l layer configuration; must be a {@link ConvolutionLayer} (otherwise a {@link ClassCastException} is thrown)
     * @return total number of parameters
     */
    @Override
    public long numParams(Layer l) {
        ConvolutionLayer layerConf = (ConvolutionLayer) l;
        int[] kernel = layerConf.getKernelSize();
        long nIn = layerConf.getNIn();
        long nOut = layerConf.getNOut();
        long biasParams = layerConf.hasBias() ? nOut : 0L;
        if (layerConf instanceof Convolution1DLayer) {
            return nIn * nOut * kernel[0] + biasParams;
        }
        return nIn * nOut * kernel[0] * kernel[1] + biasParams;
    }

    /**
     * Returns all parameter keys: {@code [W, b]} when the layer has a bias, otherwise {@code [W]}.
     *
     * @param layer layer configuration; must be a {@link ConvolutionLayer}
     * @return list of parameter keys
     */
    @Override
    public List<String> paramKeys(Layer layer) {
        ConvolutionLayer layerConf = (ConvolutionLayer) layer;
        if (layerConf.hasBias()) {
            return Arrays.asList(WEIGHT_KEY, BIAS_KEY);
        }
        return this.weightKeys(layer);
    }

    /**
     * Returns the weight keys, which are always {@code [W]} regardless of the layer.
     *
     * @param layer ignored
     * @return singleton list containing {@link #WEIGHT_KEY}
     */
    @Override
    public List<String> weightKeys(Layer layer) {
        return Collections.singletonList(WEIGHT_KEY);
    }

    /**
     * Returns the bias keys: {@code [b]} when the layer has a bias, otherwise an empty list.
     *
     * @param layer layer configuration; must be a {@link ConvolutionLayer}
     * @return list of bias keys
     */
    @Override
    public List<String> biasKeys(Layer layer) {
        ConvolutionLayer layerConf = (ConvolutionLayer) layer;
        if (layerConf.hasBias()) {
            return Collections.singletonList(BIAS_KEY);
        }
        return Collections.emptyList();
    }

    /**
     * Checks whether {@code key} names the weight parameter. The layer argument is not consulted.
     *
     * @param layer ignored
     * @param key parameter key
     * @return {@code true} iff {@code key} equals {@link #WEIGHT_KEY}
     */
    @Override
    public boolean isWeightParam(Layer layer, String key) {
        return WEIGHT_KEY.equals(key);
    }

    /**
     * Checks whether {@code key} names the bias parameter. The layer argument is not consulted, so this returns
     * {@code true} for {@code "b"} even when the layer has no bias.
     *
     * @param layer ignored
     * @param key parameter key
     * @return {@code true} iff {@code key} equals {@link #BIAS_KEY}
     */
    @Override
    public boolean isBiasParam(Layer layer, String key) {
        return BIAS_KEY.equals(key);
    }

    /**
     * Splits {@code paramsView} into bias and weight views, optionally initializes them, and registers the
     * parameter names as variables on {@code conf}.
     *
     * @param conf configuration whose layer must be a {@link ConvolutionLayer}
     * @param paramsView flattened view over all of this layer's parameters
     * @param initializeParams if {@code true}, write initial values into the views; otherwise only wrap them
     * @return map of parameter key to view; the bias entry (if any) is inserted before the weight entry
     * @throws IllegalArgumentException if the layer's kernel size array does not have exactly two entries
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        ConvolutionLayer layerConf = (ConvolutionLayer) conf.getLayer();
        if (layerConf.getKernelSize().length != 2) {
            throw new IllegalArgumentException("Filter size must be == 2");
        }
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        if (layerConf.hasBias()) {
            long nOut = layerConf.getNOut();
            // Flatten to rank 1 so the bias/weight split can be expressed as two simple intervals.
            INDArray paramsViewReshape = paramsView.reshape(new long[]{paramsView.length()});
            INDArray biasView = paramsViewReshape.get(NDArrayIndex.interval(0L, nOut));
            INDArray weightView = paramsViewReshape.get(NDArrayIndex.interval(nOut, this.numParams(conf)));
            params.put(BIAS_KEY, this.createBias(conf, biasView, initializeParams));
            params.put(WEIGHT_KEY, this.createWeightMatrix(conf, weightView, initializeParams));
            conf.addVariable(WEIGHT_KEY);
            // Registered once; the previous implementation added BIAS_KEY twice.
            conf.addVariable(BIAS_KEY);
        } else {
            params.put(WEIGHT_KEY, this.createWeightMatrix(conf, paramsView, initializeParams));
            conf.addVariable(WEIGHT_KEY);
        }
        return params;
    }

    /**
     * Splits a flattened gradient view into per-parameter gradient views, using the same layout as {@link #init}.
     *
     * <p>NOTE(review): for {@link Convolution1DLayer} the weight gradient is shaped 3D
     * {@code [nOut, nIn, kernel[0]]}, whereas the weight parameter is created 4D {@code [nOut, nIn, kernel[0], 1]}
     * by {@link #createWeightMatrix}. Kept unchanged; confirm whether the 1D layer implementation relies on this.
     *
     * @param conf configuration whose layer must be a {@link ConvolutionLayer}
     * @param gradientView flattened view over all of this layer's gradients
     * @return map of parameter key to gradient view; the bias entry (if any) is inserted before the weight entry
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        ConvolutionLayer layerConf = (ConvolutionLayer) conf.getLayer();
        long[] weightGradientShape = weightGradientShape(layerConf);
        Map<String, INDArray> out = new LinkedHashMap<String, INDArray>();
        if (layerConf.hasBias()) {
            long nOut = layerConf.getNOut();
            INDArray gradientViewReshape = gradientView.reshape(new long[]{gradientView.length()});
            INDArray biasGradientView = gradientViewReshape.get(NDArrayIndex.interval(0L, nOut));
            INDArray weightGradientView = gradientViewReshape
                    .get(NDArrayIndex.interval(nOut, this.numParams(conf)))
                    .reshape('c', weightGradientShape);
            out.put(BIAS_KEY, biasGradientView);
            out.put(WEIGHT_KEY, weightGradientView);
        } else {
            out.put(WEIGHT_KEY, gradientView.reshape('c', weightGradientShape));
        }
        return out;
    }

    /**
     * Wraps the bias view, filling it with the layer's configured bias-init value when requested.
     *
     * @param conf configuration whose layer must be a {@link ConvolutionLayer}
     * @param biasView view over the bias region of the parameters
     * @param initializeParams if {@code true}, assign {@code getBiasInit()} to every element of {@code biasView}
     * @return {@code biasView} itself
     */
    protected INDArray createBias(NeuralNetConfiguration conf, INDArray biasView, boolean initializeParams) {
        ConvolutionLayer layerConf = (ConvolutionLayer) conf.getLayer();
        if (initializeParams) {
            biasView.assign((Number) layerConf.getBiasInit());
        }
        return biasView;
    }

    /**
     * Produces the 4D 'c'-order weight array over {@code weightView}.
     *
     * <p>When {@code initializeParams} is {@code true}, the layer's weight-init function is applied with
     * {@code fanIn = nIn * kernel[0] * kernel[1]} and
     * {@code fanOut = nOut * kernel[0] * kernel[1] / (stride[0] * stride[1])}; otherwise the existing values are only
     * reshaped via {@link WeightInitUtil#reshapeWeights}.
     *
     * @param conf configuration whose layer must be a {@link ConvolutionLayer}
     * @param weightView view over the weight region of the parameters
     * @param initializeParams whether to initialize the weight values
     * @return weight array of shape {@code [nOut, nIn, kernel[0], kernel[1]]}
     *         ({@code [nOut, nIn, kernel[0], 1]} for {@link Convolution1DLayer})
     */
    protected INDArray createWeightMatrix(NeuralNetConfiguration conf, INDArray weightView, boolean initializeParams) {
        ConvolutionLayer layerConf = (ConvolutionLayer) conf.getLayer();
        long[] weightsShape = weightShape(layerConf);
        if (!initializeParams) {
            return WeightInitUtil.reshapeWeights(weightsShape, weightView, 'c');
        }
        int[] kernel = layerConf.getKernelSize();
        int[] stride = layerConf.getStride();
        long inputDepth = layerConf.getNIn();
        long outputDepth = layerConf.getNOut();
        // kernel[1] is used here even for 1D layers; init() requires a 2-element kernel array.
        double fanIn = inputDepth * kernel[0] * kernel[1];
        double fanOut = (double) (outputDepth * kernel[0] * kernel[1]) / ((double) stride[0] * (double) stride[1]);
        return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c', weightView);
    }

    /** Builds the 4D weight shape {@code [nOut, nIn, kernel[0], kernel[1] or 1 for 1D layers]}. */
    private static long[] weightShape(ConvolutionLayer layerConf) {
        int[] kernel = layerConf.getKernelSize();
        long kernelWidth = layerConf instanceof Convolution1DLayer ? 1L : kernel[1];
        return new long[]{layerConf.getNOut(), layerConf.getNIn(), kernel[0], kernelWidth};
    }

    /** Builds the weight-gradient shape: 3D for 1D layers, otherwise 4D matching {@link #weightShape}. */
    private static long[] weightGradientShape(ConvolutionLayer layerConf) {
        int[] kernel = layerConf.getKernelSize();
        long nOut = layerConf.getNOut();
        long nIn = layerConf.getNIn();
        if (layerConf instanceof Convolution1DLayer) {
            return new long[]{nOut, nIn, kernel[0]};
        }
        return new long[]{nOut, nIn, kernel[0], kernel[1]};
    }
}

