/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.updater;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.updater.UpdaterBlock;
import org.deeplearning4j.nn.updater.UpdaterUtils;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.impl.accum.Norm2;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;

/**
 * Base {@link Updater} for models made of several layers.
 *
 * <p>On construction, the parameters of every layer returned by {@link #getOrderedLayers()} are walked in order,
 * and consecutive parameters whose updater configurations are equal (as decided by
 * {@link UpdaterUtils#updaterConfigurationsEquals}) are grouped into a single {@link UpdaterBlock}. Each block is
 * given views (sub-arrays) of the network's flattened parameter, gradient and updater-state arrays.
 *
 * <p>NOTE(review): {@link #layersByName} is never assigned in this class but is read in
 * {@link #update(Gradient, int, int)}; subclasses presumably populate it — confirm.
 *
 * @param <T> the model type this updater operates on
 */
public abstract class BaseMultiLayerUpdater<T extends Model> implements Updater {

    /** Name of the workspace that, when it exists, is activated around each updater-block update. */
    private static final String LOOP_FF_WORKSPACE = "LOOP_FF";

    protected final T network;
    /** Layers keyed by the layer-name prefix of gradient keys; not assigned in this class. */
    protected Map<String, Layer> layersByName;
    /** Updater blocks in parameter order; built once in the constructor. */
    protected final List<UpdaterBlock> updaterBlocks;
    /** Flattened updater state for all blocks; {@code null} if none was supplied and the total state size is 0. */
    protected INDArray updaterStateViewArray;

    /**
     * Creates an updater for {@code network}, allocating new updater state if any is required.
     *
     * @param network the model to update
     */
    public BaseMultiLayerUpdater(T network) {
        this(network, null);
    }

    /**
     * Creates an updater for {@code network}.
     *
     * @param network      the model to update
     * @param updaterState existing flattened updater state to use as the backing array, or {@code null} to
     *                     allocate a new (uninitialized) array that the blocks are then told to initialize
     */
    public BaseMultiLayerUpdater(T network, INDArray updaterState) {
        this.network = network;
        this.updaterBlocks = new ArrayList<>();

        Layer[] layers = this.getOrderedLayers();
        INDArray paramsView = network.params();
        INDArray gradientView = this.getFlattenedGradientsView();

        // Pass 1: walk every (layer, variable) pair in order; merge adjacent pairs whose updater config is equal.
        Layer lastLayer = null;
        String lastVariable = null;
        UpdaterBlock currentBlock = null;
        int paramsViewSoFar = 0;
        // Running total of updater state; also the updater-state offset of the next variable.
        int updaterStateSize = 0;
        for (Layer layer : layers) {
            Map<String, INDArray> layerParamTable = layer.paramTable();
            if (layerParamTable == null) {
                continue;
            }
            for (String var : layerParamTable.keySet()) {
                int paramSize = layerParamTable.get(var).length();
                int stateSize = (int) layer.conf().getLayer().getIUpdaterByParam(var).stateSize((long) paramSize);
                int paramEnd = paramsViewSoFar + paramSize;

                INDArray paramsViewSubset = null;
                INDArray gradientViewSubset = null;
                if (paramSize > 0) {
                    paramsViewSubset = paramsView.get(NDArrayIndex.point(0),
                            NDArrayIndex.interval(paramsViewSoFar, paramEnd));
                    gradientViewSubset = gradientView.get(NDArrayIndex.point(0),
                            NDArrayIndex.interval(paramsViewSoFar, paramEnd));
                }
                UpdaterBlock.ParamState paramState = new UpdaterBlock.ParamState(layer, var, paramsViewSoFar,
                        paramEnd, paramsViewSubset, gradientViewSubset);

                if (currentBlock == null
                        || !UpdaterUtils.updaterConfigurationsEquals(lastLayer, lastVariable, layer, var)) {
                    ArrayList<UpdaterBlock.ParamState> list = new ArrayList<>();
                    list.add(paramState);
                    currentBlock = new UpdaterBlock(paramsViewSoFar, paramEnd, updaterStateSize,
                            updaterStateSize + stateSize, list);
                    this.updaterBlocks.add(currentBlock);
                } else {
                    currentBlock.setParamOffsetEnd(currentBlock.getParamOffsetEnd() + paramSize);
                    currentBlock.setUpdaterViewOffsetEnd(currentBlock.getUpdaterViewOffsetEnd() + stateSize);
                    currentBlock.getLayersAndVariablesInBlock().add(paramState);
                }

                lastLayer = layer;
                lastVariable = var;
                paramsViewSoFar = paramEnd;
                updaterStateSize += stateSize;
            }
        }

        // Use the supplied state as-is, or allocate fresh state that the blocks must initialize.
        // When nothing is supplied and no state is needed, updaterStateViewArray stays null.
        boolean updaterRequiresInit = false;
        if (updaterState != null) {
            this.updaterStateViewArray = updaterState;
        } else if (updaterStateSize > 0) {
            this.updaterStateViewArray = Nd4j.createUninitialized(new int[] {1, updaterStateSize}, Nd4j.order());
            updaterRequiresInit = true;
        }

        // Pass 2: give each block its slice of the updater-state and gradient views, then initialize it.
        int updaterViewSoFar = 0;
        paramsViewSoFar = 0;
        for (UpdaterBlock ub : this.updaterBlocks) {
            int viewStateSize = ub.getUpdaterViewOffsetEnd() - ub.getUpdaterViewOffsetStart();
            int gradSize = ub.getParamOffsetEnd() - ub.getParamOffsetStart();
            if (viewStateSize > 0) {
                INDArray updaterViewSubset = this.updaterStateViewArray.get(NDArrayIndex.point(0),
                        NDArrayIndex.interval(updaterViewSoFar, updaterViewSoFar + viewStateSize));
                ub.setUpdaterView(updaterViewSubset);
                ub.setUpdaterViewRequiresInitialization(updaterRequiresInit);
            }
            if (gradSize > 0) {
                INDArray gradientViewSubset = gradientView.get(NDArrayIndex.point(0),
                        NDArrayIndex.interval(paramsViewSoFar, paramsViewSoFar + gradSize));
                ub.setGradientView(gradientViewSubset);
            }
            ub.init();
            updaterViewSoFar += viewStateSize;
            paramsViewSoFar += gradSize;
        }
    }

    /** Returns the model's layers in the order their parameters appear in the flattened parameter array. */
    protected abstract Layer[] getOrderedLayers();

    /**
     * Returns the model's flattened gradient array. A {@link Gradient} passed to {@link #update(Gradient, int, int)}
     * whose backing array is not this exact object is treated as external.
     */
    protected abstract INDArray getFlattenedGradientsView();

    /** Returns the parameters passed to {@link UpdaterBlock#updateExternalGradient} for external gradients. */
    protected abstract INDArray getParams();

    /** Whether gradients are divided by the batch size at the end of {@link #update(Gradient, int, int)}. */
    protected abstract boolean isMiniBatch();

    /**
     * Copies {@code viewArray} into this updater's state array (element-wise assign; the backing array is kept).
     *
     * @param viewArray the new updater state; must have the same length as the current state
     * @throws IllegalStateException if the lengths differ, or this updater has no state and
     *                               {@code viewArray} is non-empty
     */
    public void setStateViewArray(INDArray viewArray) {
        if (this.updaterStateViewArray == null) {
            // No state was allocated (total updater state size is 0): an empty input is a no-op.
            if (viewArray == null || viewArray.length() == 0) {
                return;
            }
            throw new IllegalStateException("Invalid input: this updater has no state array, got length "
                    + viewArray.length());
        }
        if (this.updaterStateViewArray.length() != viewArray.length()) {
            throw new IllegalStateException("Invalid input: view arrays differ in length. Expected length "
                    + this.updaterStateViewArray.length() + ", got length " + viewArray.length());
        }
        this.updaterStateViewArray.assign(viewArray);
    }

    /** Delegates to {@link #setStateViewArray(INDArray)}; {@code layer} and {@code initialize} are ignored. */
    @Override
    public void setStateViewArray(Layer layer, INDArray viewArray, boolean initialize) {
        this.setStateViewArray(viewArray);
    }

    @Override
    public INDArray getStateViewArray() {
        return this.updaterStateViewArray;
    }

    /** Delegates to {@link #update(Gradient, int, int)}; {@code layer} is ignored. */
    @Override
    public void update(Layer layer, Gradient gradient, int iteration, int batchSize) {
        this.update(gradient, iteration, batchSize);
    }

    /**
     * Applies gradient normalization and then every updater block to {@code gradient}.
     *
     * <p>Gradient keys are expected to be {@code <layerName>_<paramName>`; they are split at the LAST underscore
     * to group them per layer for {@link #preApply}. Exception: a single-layer updater
     * ({@link #isSingleLayerUpdater()}) with exactly one layer uses {@code gradient} as-is for that layer.
     * If {@link #isMiniBatch()} is true, the gradient is divided by {@code batchSize} at the end.
     * NOTE(review): that division happens after the updater blocks have run — confirm this ordering is intended.
     *
     * @throws IllegalStateException if a gradient key contains no {@code '_'} separator
     */
    public void update(Gradient gradient, int iteration, int batchSize) {
        // Identity (not value) comparison: any other backing array marks the gradient as external.
        boolean isExternal = gradient.gradient() != this.getFlattenedGradientsView();

        Map<String, Gradient> layerGradients = this.splitGradientByLayer(gradient);
        for (Map.Entry<String, Gradient> entry : layerGradients.entrySet()) {
            Layer layer = this.layersByName.get(entry.getKey());
            this.preApply(layer, entry.getValue(), iteration);
        }

        for (UpdaterBlock updaterBlock : this.updaterBlocks) {
            if (updaterBlock.skipDueToPretrainConfig()) {
                continue;
            }
            if (Nd4j.getWorkspaceManager().checkIfWorkspaceExists(LOOP_FF_WORKSPACE)) {
                // try-with-resources closes the workspace (if non-null) without discarding an update failure.
                try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().getAndActivateWorkspace(LOOP_FF_WORKSPACE)) {
                    this.applyUpdaterBlock(updaterBlock, isExternal, gradient, iteration);
                }
            } else {
                this.applyUpdaterBlock(updaterBlock, isExternal, gradient, iteration);
            }
        }

        if (this.isMiniBatch()) {
            INDArray toScale = isExternal ? gradient.gradient() : this.getFlattenedGradientsView();
            toScale.divi(batchSize);
        }
    }

    /** Groups the entries of {@code gradient} per layer name, as described on {@link #update(Gradient, int, int)}. */
    private Map<String, Gradient> splitGradientByLayer(Gradient gradient) {
        Map<String, Gradient> layerGradients = new HashMap<>();
        Layer[] layers = this.getOrderedLayers();
        if (layers.length == 1 && this.isSingleLayerUpdater()) {
            layerGradients.put(layers[0].conf().getLayer().getLayerName(), gradient);
            return layerGradients;
        }
        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String key = entry.getKey();
            int idx = key.lastIndexOf('_');
            if (idx == -1) {
                throw new IllegalStateException("Invalid key: Gradient key does not have layer separator: \"" + key + "\"");
            }
            String layerName = key.substring(0, idx);
            Gradient g = layerGradients.get(layerName);
            if (g == null) {
                g = new DefaultGradient();
                layerGradients.put(layerName, g);
            }
            g.setGradientFor(key.substring(idx + 1), entry.getValue());
        }
        return layerGradients;
    }

    /** Runs one updater block against either the external gradient or the block's own gradient view. */
    private void applyUpdaterBlock(UpdaterBlock updaterBlock, boolean isExternal, Gradient gradient, int iteration) {
        if (isExternal) {
            updaterBlock.updateExternalGradient(iteration, gradient.gradient(), this.getParams());
        } else {
            updaterBlock.update(iteration);
        }
    }

    /**
     * Whether this updater handles exactly one layer whose gradient keys are NOT prefixed with the layer name.
     * Defaults to {@code false}.
     */
    protected boolean isSingleLayerUpdater() {
        return false;
    }

    /**
     * Applies the layer's configured gradient normalization in place.
     *
     * <p>Does nothing if the layer configuration is not a {@link BaseLayer}, the normalization is {@code null} or
     * {@link GradientNormalization#None}, or the layer is configured for pretraining. The per-layer strategies
     * operate on {@code layer.getGradientsViewArray()} (skipped when it is {@code null}); the per-param-type
     * strategies operate on each array in {@code gradient}. All-zero gradients are left unchanged by the
     * renormalization strategies (dividing by a zero norm would produce NaN).
     *
     * @throws RuntimeException for a normalization strategy that is not handled here
     */
    public void preApply(Layer layer, Gradient gradient, int iteration) {
        if (!(layer.conf().getLayer() instanceof BaseLayer)) {
            return;
        }
        BaseLayer bLayer = (BaseLayer) layer.conf().getLayer();
        GradientNormalization normalization = bLayer.getGradientNormalization();
        if (normalization == null || normalization == GradientNormalization.None || layer.conf().isPretrain()) {
            return;
        }
        double threshold = bLayer.getGradientNormalizationThreshold();
        INDArray layerGradientView = layer.getGradientsViewArray();
        switch (normalization) {
            case RenormalizeL2PerLayer: {
                if (layerGradientView == null) {
                    break;
                }
                double l2 = layerGradientView.norm2Number().doubleValue();
                if (l2 != 0.0) {
                    layerGradientView.divi(l2);
                }
                break;
            }
            case RenormalizeL2PerParamType: {
                for (INDArray g : gradient.gradientForVariable().values()) {
                    double l2 = Nd4j.getExecutioner().execAndReturn((Accumulation) new Norm2(g))
                            .getFinalResult().doubleValue();
                    if (l2 != 0.0) {
                        g.divi(l2);
                    }
                }
                break;
            }
            case ClipElementWiseAbsoluteValue: {
                if (layerGradientView == null) {
                    break;
                }
                // Clamp every element into [-threshold, threshold].
                BooleanIndexing.replaceWhere(layerGradientView, threshold, Conditions.greaterThan(threshold));
                BooleanIndexing.replaceWhere(layerGradientView, -threshold, Conditions.lessThan(-threshold));
                break;
            }
            case ClipL2PerLayer: {
                if (layerGradientView == null) {
                    break;
                }
                double layerL2 = layerGradientView.norm2Number().doubleValue();
                if (layerL2 > threshold) {
                    // Rescale so the layer's L2 norm equals the threshold.
                    layerGradientView.muli(threshold / layerL2);
                }
                break;
            }
            case ClipL2PerParamType: {
                for (INDArray g : gradient.gradientForVariable().values()) {
                    double l2 = g.norm2Number().doubleValue();
                    if (l2 > threshold) {
                        g.divi(l2 / threshold);
                    }
                }
                break;
            }
            default: {
                throw new RuntimeException("Unknown (or not implemented) gradient normalization strategy: " + normalization);
            }
        }
    }

    /** Two updaters of the same class are equal when their updater state arrays are equal (or both null). */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        BaseMultiLayerUpdater<?> that = (BaseMultiLayerUpdater<?>) o;
        return this.updaterStateViewArray != null
                ? this.updaterStateViewArray.equals(that.updaterStateViewArray)
                : that.updaterStateViewArray == null;
    }

    /** Hashes only the field compared by {@link #equals(Object)}, so equal updaters have equal hash codes. */
    @Override
    public int hashCode() {
        return this.updaterStateViewArray != null ? this.updaterStateViewArray.hashCode() : 0;
    }

    public T getNetwork() {
        return this.network;
    }

    public Map<String, Layer> getLayersByName() {
        return this.layersByName;
    }

    public List<UpdaterBlock> getUpdaterBlocks() {
        return this.updaterBlocks;
    }

    public INDArray getUpdaterStateViewArray() {
        return this.updaterStateViewArray;
    }
}

