/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.preprocessors;

import java.util.Arrays;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Input preprocessor that reshapes activations to a configured target shape on the
 * forward pass and reshapes the incoming gradient back to the configured input shape
 * on the backward pass.
 *
 * <p>The configured shapes may be stored either with or without a leading mini-batch
 * dimension (see {@code hasMiniBatchDimension}). In both cases the leading dimension
 * of the shape actually used is set to the mini-batch size supplied at call time.
 */
@JsonIgnoreProperties(value={"miniBatchSize", "staticTargetShape"})
public class ReshapePreprocessor
extends BaseInputPreProcessor {
    private static final Logger log = LoggerFactory.getLogger(ReshapePreprocessor.class);
    private final long[] inputShape;
    private final long[] targetShape;
    private boolean hasMiniBatchDimension;
    private DataFormat format;

    /**
     * Creates a preprocessor without an explicit data format.
     *
     * @param inputShape            shape of the incoming activations
     * @param targetShape           shape the activations are reshaped to
     * @param hasMiniBatchDimension whether both shapes already contain a leading mini-batch dimension
     */
    public ReshapePreprocessor(long[] inputShape, long[] targetShape, boolean hasMiniBatchDimension) {
        this(inputShape, targetShape, hasMiniBatchDimension, null);
    }

    /**
     * Creates a preprocessor. The given arrays are stored as-is (no defensive copy).
     *
     * @param inputShape            shape of the incoming activations
     * @param targetShape           shape the activations are reshaped to
     * @param hasMiniBatchDimension whether both shapes already contain a leading mini-batch dimension
     * @param dataFormat            data format consulted by {@link #getOutputType(InputType)}; may be {@code null}
     */
    public ReshapePreprocessor(@JsonProperty(value="inputShape") long[] inputShape, @JsonProperty(value="targetShape") long[] targetShape, @JsonProperty(value="hasMiniBatchDimension") boolean hasMiniBatchDimension, @JsonProperty(value="dataFormat") DataFormat dataFormat) {
        this.inputShape = inputShape;
        this.targetShape = targetShape;
        this.hasMiniBatchDimension = hasMiniBatchDimension;
        this.format = dataFormat;
    }

    /**
     * Returns {@code originalShape} with its leading dimension equal to {@code minibatch}.
     *
     * <p>If the configured shapes have no mini-batch dimension, one is prepended. Otherwise
     * the first entry is replaced on a copy, so the configured array is never modified; when
     * the first entry already matches, the original array instance is returned.
     */
    private long[] getShape(long[] originalShape, long minibatch) {
        if (!this.hasMiniBatchDimension) {
            return ReshapePreprocessor.prependMiniBatchSize(originalShape, minibatch);
        }
        if (originalShape[0] == minibatch) {
            return originalShape;
        }
        long[] adjusted = originalShape.clone();
        adjusted[0] = minibatch;
        return adjusted;
    }

    /**
     * Builds a new array consisting of {@code miniBatchSize} followed by the entries of {@code shape}.
     */
    private static long[] prependMiniBatchSize(long[] shape, long miniBatchSize) {
        long[] miniBatchShape = new long[shape.length + 1];
        miniBatchShape[0] = miniBatchSize;
        System.arraycopy(shape, 0, miniBatchShape, 1, shape.length);
        return miniBatchShape;
    }

    /**
     * Reshapes {@code input} to the target shape, with the leading dimension set to
     * {@code miniBatchSize}.
     *
     * @param input         activations to reshape
     * @param miniBatchSize mini-batch size used as the leading dimension of the target shape
     * @param workspaceMgr  workspace manager used to duplicate and leverage arrays
     * @return the reshaped array, leveraged to {@link ArrayType#ACTIVATIONS}
     * @throws IllegalStateException if the element counts of the input and the target shape differ
     */
    @Override
    public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        long[] targetShape = this.getShape(this.targetShape, miniBatchSize);
        if (ArrayUtil.prodLong((long[])input.shape()) != ArrayUtil.prodLong((long[])targetShape)) {
            throw new IllegalStateException("Input shape " + Arrays.toString(input.shape()) + " and target shape" + Arrays.toString(targetShape) + " do not match");
        }
        // Work on a 'c'-ordered copy with default strides before reshaping.
        if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape((INDArray)input)) {
            input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape));
    }

    /**
     * Reshapes the gradient {@code output} from the target shape back to the input shape,
     * with the leading dimension of both set to {@code miniBatchSize}.
     *
     * @param output        gradient array; its shape must equal the target shape exactly
     * @param miniBatchSize mini-batch size used as the leading dimension of both shapes
     * @param workspaceMgr  workspace manager used to duplicate and leverage arrays
     * @return the reshaped array, leveraged to {@link ArrayType#ACTIVATION_GRAD}
     * @throws IllegalStateException if the shape of {@code output} differs from the target shape
     */
    @Override
    public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        long[] targetShape = this.getShape(this.targetShape, miniBatchSize);
        long[] inputShape = this.getShape(this.inputShape, miniBatchSize);
        // Exact shape equality also implies equal element counts, so no separate
        // product comparison is needed after this check.
        if (!Arrays.equals(targetShape, output.shape())) {
            throw new IllegalStateException("Unexpected output shape" + Arrays.toString(output.shape()) + " (expected to be " + Arrays.toString(targetShape) + ")");
        }
        if (output.ordering() != 'c' || !Shape.hasDefaultStridesForShape((INDArray)output)) {
            // NOTE(review): the copy is requested as ACTIVATIONS although the result is
            // leveraged to ACTIVATION_GRAD below - confirm this is intended.
            output = workspaceMgr.dup(ArrayType.ACTIVATIONS, output, 'c');
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.reshape(inputShape));
    }

    /**
     * Derives the output {@link InputType} from the rank of the target shape
     * (including the mini-batch dimension):
     * <ul>
     *   <li>rank 2: feed-forward, built from {@code shape[1]};</li>
     *   <li>rank 3: recurrent, built from {@code shape[2]} and {@code shape[1]}, using the
     *       configured {@link RNNFormat} or {@code NWC} when none is configured;</li>
     *   <li>rank 4: 2D convolutional; {@code NHWC} when the input shape has length 1 or the
     *       incoming type is RNN, otherwise the configured {@link CNN2DFormat}
     *       ({@code NCHW} when none is configured);</li>
     *   <li>rank 5: 3D convolutional when the input shape has length 1 or the incoming type
     *       is RNN, otherwise the same 2D convolutional handling as rank 4.</li>
     * </ul>
     *
     * @param inputType type of the incoming activations; read for rank 4 and 5 target shapes
     *                  when the input shape does not have length 1
     * @return the inferred output type
     * @throws UnsupportedOperationException if the target shape has any other rank
     * @throws IllegalArgumentException      if a rank-5 target shape requires a 3D format but
     *                                       the configured format is not a
     *                                       {@link Convolution3D.DataFormat}
     */
    @Override
    public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
        long[] shape = this.getShape(this.targetShape, 0L);
        switch (shape.length) {
            case 2: {
                return InputType.feedForward(shape[1]);
            }
            case 3: {
                RNNFormat rnnFormat = this.format instanceof RNNFormat ? (RNNFormat)this.format : RNNFormat.NWC;
                return InputType.recurrent(shape[2], shape[1], rnnFormat);
            }
            case 4: {
                if (this.inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) {
                    return InputType.convolutional(shape[1], shape[2], shape[3], CNN2DFormat.NHWC);
                }
                return this.convolutional2DType(shape);
            }
            case 5: {
                if (this.inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) {
                    return this.convolutional3DType(shape);
                }
                // NOTE(review): a rank-5 shape is mapped to a 2D convolutional type here and
                // shape[4] is ignored - confirm this is intended.
                return this.convolutional2DType(shape);
            }
            default: {
                throw new UnsupportedOperationException("Cannot infer input type for reshape array " + Arrays.toString(shape));
            }
        }
    }

    /**
     * Builds a 2D convolutional type using the configured {@link CNN2DFormat}, or
     * {@code NCHW} when the configured format is absent or of another kind. For
     * {@code NCHW} the arguments are taken from axes (2, 3, 1); otherwise from axes (1, 2, 3).
     */
    private InputType convolutional2DType(long[] shape) {
        CNN2DFormat cnnFormat = this.format instanceof CNN2DFormat ? (CNN2DFormat)this.format : CNN2DFormat.NCHW;
        if (cnnFormat == CNN2DFormat.NCHW) {
            return InputType.convolutional(shape[2], shape[3], shape[1], cnnFormat);
        }
        return InputType.convolutional(shape[1], shape[2], shape[3], cnnFormat);
    }

    /**
     * Builds a 3D convolutional type from a rank-5 shape. A {@code null} format is passed
     * through unchanged and uses the same axis order as {@code NDHWC}.
     *
     * @throws IllegalArgumentException if the configured format is not a {@link Convolution3D.DataFormat}
     */
    private InputType convolutional3DType(long[] shape) {
        // Check the type before casting so that an incompatible format is reported with
        // the explicit error below instead of a raw ClassCastException.
        if (this.format != null && !(this.format instanceof Convolution3D.DataFormat)) {
            throw new IllegalArgumentException("Illegal format found " + this.format);
        }
        Convolution3D.DataFormat dataFormat = (Convolution3D.DataFormat)this.format;
        if (dataFormat == Convolution3D.DataFormat.NCDHW) {
            return InputType.convolutional3D(dataFormat, shape[2], shape[3], shape[4], shape[1]);
        }
        if (dataFormat == Convolution3D.DataFormat.NDHWC || dataFormat == null) {
            return InputType.convolutional3D(dataFormat, shape[1], shape[2], shape[3], shape[4]);
        }
        throw new IllegalArgumentException("Illegal format found " + dataFormat);
    }

    /** Returns the configured input shape (the internal array, not a copy). */
    public long[] getInputShape() {
        return this.inputShape;
    }

    /** Returns the configured target shape (the internal array, not a copy). */
    public long[] getTargetShape() {
        return this.targetShape;
    }

    /** Returns whether the configured shapes already contain a leading mini-batch dimension. */
    public boolean isHasMiniBatchDimension() {
        return this.hasMiniBatchDimension;
    }

    /** Returns the configured data format, or {@code null} if none was set. */
    public DataFormat getFormat() {
        return this.format;
    }

    /** Sets whether the configured shapes already contain a leading mini-batch dimension. */
    public void setHasMiniBatchDimension(boolean hasMiniBatchDimension) {
        this.hasMiniBatchDimension = hasMiniBatchDimension;
    }

    /** Sets the data format consulted by {@link #getOutputType(InputType)}. */
    public void setFormat(DataFormat format) {
        this.format = format;
    }

    @Override
    public String toString() {
        return "ReshapePreprocessor(inputShape=" + Arrays.toString(this.getInputShape()) + ", targetShape=" + Arrays.toString(this.getTargetShape()) + ", hasMiniBatchDimension=" + this.isHasMiniBatchDimension() + ", format=" + this.getFormat() + ")";
    }

    /**
     * Two preprocessors are equal when the mini-batch flag, both shapes (compared
     * element-wise) and the data format are equal.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ReshapePreprocessor)) {
            return false;
        }
        ReshapePreprocessor other = (ReshapePreprocessor)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (this.isHasMiniBatchDimension() != other.isHasMiniBatchDimension()) {
            return false;
        }
        if (!Arrays.equals(this.getInputShape(), other.getInputShape())) {
            return false;
        }
        if (!Arrays.equals(this.getTargetShape(), other.getTargetShape())) {
            return false;
        }
        DataFormat thisFormat = this.getFormat();
        DataFormat otherFormat = other.getFormat();
        return thisFormat == null ? otherFormat == null : thisFormat.equals(otherFormat);
    }

    /** Hook that lets subclasses restrict which instances they may be considered equal to. */
    protected boolean canEqual(Object other) {
        return other instanceof ReshapePreprocessor;
    }

    /** Hash code over the same fields as {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int PRIME = 59;
        int result = 1;
        result = result * PRIME + (this.isHasMiniBatchDimension() ? 79 : 97);
        result = result * PRIME + Arrays.hashCode(this.getInputShape());
        result = result * PRIME + Arrays.hashCode(this.getTargetShape());
        DataFormat currentFormat = this.getFormat();
        result = result * PRIME + (currentFormat == null ? 43 : currentFormat.hashCode());
        return result;
    }
}

