/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.preprocessors;

import java.util.Arrays;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@JsonIgnoreProperties(value={"miniBatchSize", "staticTargetShape"})
public class ReshapePreprocessor
extends BaseInputPreProcessor {
    private static final Logger log = LoggerFactory.getLogger(ReshapePreprocessor.class);
    private final long[] inputShape;
    private final long[] targetShape;
    private boolean hasMiniBatchDimension;

    @Deprecated
    public ReshapePreprocessor(long[] inputShape, long[] targetShape) {
        this(inputShape, targetShape, false);
    }

    public ReshapePreprocessor(@JsonProperty(value="inputShape") long[] inputShape, @JsonProperty(value="targetShape") long[] targetShape, @JsonProperty(value="hasMiniBatchDimension") boolean hasMiniBatchDimension) {
        this.inputShape = inputShape;
        this.targetShape = targetShape;
        this.hasMiniBatchDimension = hasMiniBatchDimension;
    }

    private long[] getShape(long[] originalShape, long minibatch) {
        long[] newShape;
        long[] lArray = newShape = this.hasMiniBatchDimension ? originalShape : ReshapePreprocessor.prependMiniBatchSize(originalShape, minibatch);
        if (newShape[0] != minibatch) {
            newShape = (long[])newShape.clone();
            newShape[0] = minibatch;
        }
        return newShape;
    }

    private static long[] prependMiniBatchSize(long[] shape, long miniBatchSize) {
        int shapeLength = shape.length;
        long[] miniBatchShape = new long[shapeLength + 1];
        miniBatchShape[0] = miniBatchSize;
        for (int i = 1; i < miniBatchShape.length; ++i) {
            miniBatchShape[i] = shape[i - 1];
        }
        return miniBatchShape;
    }

    public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        long[] targetShape = this.getShape(this.targetShape, miniBatchSize);
        long[] inputShape = this.getShape(this.inputShape, miniBatchSize);
        if (ArrayUtil.prodLong((long[])input.shape()) == ArrayUtil.prodLong((long[])targetShape)) {
            if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape((INDArray)input)) {
                input = workspaceMgr.dup((Enum)ArrayType.ACTIVATIONS, input, 'c');
            }
            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape));
        }
        throw new IllegalStateException("Input shape " + Arrays.toString(input.shape()) + " and output shape" + Arrays.toString(inputShape) + " do not match");
    }

    public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        long[] targetShape = this.getShape(this.targetShape, miniBatchSize);
        long[] inputShape = this.getShape(this.inputShape, miniBatchSize);
        if (!Arrays.equals(targetShape, output.shape())) {
            throw new IllegalStateException("Unexpected output shape" + Arrays.toString(output.shape()) + " (expected to be " + Arrays.toString(targetShape) + ")");
        }
        if (ArrayUtil.prodLong((long[])output.shape()) == ArrayUtil.prodLong((long[])targetShape)) {
            if (output.ordering() != 'c' || !Shape.hasDefaultStridesForShape((INDArray)output)) {
                output = workspaceMgr.dup((Enum)ArrayType.ACTIVATIONS, output, 'c');
            }
            return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.reshape(inputShape));
        }
        throw new IllegalStateException("Output shape" + Arrays.toString(output.shape()) + " and input shape" + Arrays.toString(targetShape) + " do not match");
    }

    public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
        InputType ret;
        long[] shape = this.getShape(this.targetShape, 0L);
        switch (shape.length) {
            case 2: {
                ret = InputType.feedForward((long)shape[1]);
                break;
            }
            case 3: {
                ret = InputType.recurrent((long)shape[2], (long)shape[1]);
                break;
            }
            case 4: {
                if (this.inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) {
                    ret = InputType.convolutional((long)shape[1], (long)shape[2], (long)shape[3]);
                    break;
                }
                ret = InputType.convolutional((long)shape[2], (long)shape[3], (long)shape[1]);
                break;
            }
            default: {
                throw new UnsupportedOperationException("Cannot infer input type for reshape array " + Arrays.toString(shape));
            }
        }
        return ret;
    }

    public long[] getInputShape() {
        return this.inputShape;
    }

    public long[] getTargetShape() {
        return this.targetShape;
    }

    public boolean isHasMiniBatchDimension() {
        return this.hasMiniBatchDimension;
    }

    public void setHasMiniBatchDimension(boolean hasMiniBatchDimension) {
        this.hasMiniBatchDimension = hasMiniBatchDimension;
    }

    public String toString() {
        return "ReshapePreprocessor(inputShape=" + Arrays.toString(this.getInputShape()) + ", targetShape=" + Arrays.toString(this.getTargetShape()) + ", hasMiniBatchDimension=" + this.isHasMiniBatchDimension() + ")";
    }

    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ReshapePreprocessor)) {
            return false;
        }
        ReshapePreprocessor other = (ReshapePreprocessor)((Object)o);
        if (!other.canEqual((Object)this)) {
            return false;
        }
        if (!Arrays.equals(this.getInputShape(), other.getInputShape())) {
            return false;
        }
        if (!Arrays.equals(this.getTargetShape(), other.getTargetShape())) {
            return false;
        }
        return this.isHasMiniBatchDimension() == other.isHasMiniBatchDimension();
    }

    protected boolean canEqual(Object other) {
        return other instanceof ReshapePreprocessor;
    }

    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        result = result * 59 + Arrays.hashCode(this.getInputShape());
        result = result * 59 + Arrays.hashCode(this.getTargetShape());
        result = result * 59 + (this.isHasMiniBatchDimension() ? 79 : 97);
        return result;
    }
}

