/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.recurrent;

import java.util.Arrays;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer;
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;

public class MaskZeroLayer
extends BaseWrapperLayer {
    private static final long serialVersionUID = -7369482676002469854L;
    private double maskingValue;

    public MaskZeroLayer(@NonNull Layer underlying, double maskingValue) {
        super(underlying);
        if (underlying == null) {
            throw new NullPointerException("underlying is marked non-null but is null");
        }
        this.maskingValue = maskingValue;
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        return this.underlying.backpropGradient(epsilon, workspaceMgr);
    }

    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        INDArray input = this.input();
        this.setMaskFromInput(input);
        return this.underlying.activate(training, workspaceMgr);
    }

    @Override
    public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) {
        this.setMaskFromInput(input);
        return this.underlying.activate(input, training, workspaceMgr);
    }

    private void setMaskFromInput(INDArray input) {
        if (input.rank() != 3) {
            throw new IllegalArgumentException("Expected input of shape [batch_size, timestep_input_size, timestep], got shape " + Arrays.toString(input.shape()) + " instead");
        }
        if (this.underlying instanceof BaseRecurrentLayer && ((BaseRecurrentLayer)this.underlying).getDataFormat() == RNNFormat.NWC) {
            input = input.permute(new int[]{0, 2, 1});
        }
        INDArray mask = input.eq((Number)this.maskingValue).castTo(input.dataType()).sum(new int[]{1}).neq((Number)input.shape()[1]).castTo(input.dataType());
        this.underlying.setMaskArray(mask.detach());
    }

    @Override
    public long numParams() {
        return this.underlying.numParams();
    }

    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        this.underlying.feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
        return new Pair(null, (Object)currentMaskState);
    }
}

