/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;

import java.util.Arrays;
import java.util.List;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.common.base.Preconditions;

public class SRULayerOutputs {
    /** Output activations (first op output). */
    private final SDVariable h;
    /** Cell state (second op output). */
    private final SDVariable c;
    /** Lazily-created slice of {@link #h} at the last index of dimension 2; cached after first use. */
    private SDVariable lastOutput = null;
    /** Lazily-created slice of {@link #c} at the last index of dimension 2; cached after first use. */
    private SDVariable lastState = null;

    /**
     * Wraps the two outputs produced by an SRU layer op.
     *
     * @param outputs exactly two variables: {@code outputs[0]} is h (output),
     *                {@code outputs[1]} is c (state)
     * @throws IllegalArgumentException if {@code outputs} does not have length 2
     */
    public SRULayerOutputs(SDVariable[] outputs) {
        Preconditions.checkArgument(outputs.length == 2, "Must have 2 SRU cell outputs, got %s", outputs.length);
        this.h = outputs[0];
        this.c = outputs[1];
    }

    /**
     * Returns both outputs in op order.
     *
     * @return a fixed-size list {@code [h, c]}
     */
    public List<SDVariable> getAllOutputs() {
        return Arrays.asList(this.h, this.c);
    }

    /**
     * Returns h, the layer output.
     *
     * @return the output variable
     */
    public SDVariable getOutput() {
        return this.h;
    }

    /**
     * Returns c, the cell state.
     *
     * @return the state variable
     */
    public SDVariable getState() {
        return this.c;
    }

    /**
     * Returns the output at the last position of dimension 2. This is presumably the last time
     * step, assuming h is laid out as [batch, size, time]; TODO confirm the layout against the op.
     * The slice is created once and then cached.
     *
     * @return {@code h[:, :, -1]}
     */
    public SDVariable getLastOutput() {
        if (this.lastOutput == null) {
            this.lastOutput = this.getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1L));
        }
        return this.lastOutput;
    }

    /**
     * Returns the cell state at the last position of dimension 2. This assumes c shares h's
     * [batch, size, time] layout; TODO confirm against the op. The slice is created once and
     * then cached.
     *
     * @return {@code c[:, :, -1]}
     */
    public SDVariable getLastState() {
        // Previously this wrote a slice of h into lastOutput and returned the (null) lastState field.
        if (this.lastState == null) {
            this.lastState = this.getState().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1L));
        }
        return this.lastState;
    }

    /**
     * Alias for {@link #getOutput()}.
     *
     * @return h
     */
    public SDVariable getH() {
        return this.h;
    }

    /**
     * Alias for {@link #getState()}.
     *
     * @return c
     */
    public SDVariable getC() {
        return this.c;
    }
}

