/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.util;

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Static helpers for time series arrays and their masks.
 *
 * <p>The 3d/2d conversions below map a rank-3 array of shape {@code [a, b, c]} to a rank-2 array
 * of shape {@code [a * c, b]} and back (see {@link #reshape3dTo2d} / {@link #reshape2dTo3d}).
 * NOTE(review): the naming in {@link #reshape2dTo3d} and {@link #reshapeVectorToTimeSeriesMask}
 * suggests {@code a} is the minibatch size and {@code c} the time series length — confirm against
 * callers.
 */
public final class TimeSeriesUtils {
    private TimeSeriesUtils() {
        // Static utility class; not instantiable.
    }

    /**
     * Computes a simple moving average with window size {@code n}, using the cumulative-sum trick:
     * after subtracting the cumulative sum shifted by {@code n}, position {@code i} holds the sum of
     * the {@code n} values ending at {@code i}; dividing by {@code n} gives the mean.
     *
     * <p>NOTE(review): indexing uses a single column interval, so this presumably expects a row
     * vector — confirm for other shapes.
     *
     * @param toAvg values to average (not modified)
     * @param n     window size; must be in {@code [1, toAvg.columns()]}
     * @return the {@code toAvg.columns() - n + 1} window averages
     * @throws IllegalArgumentException if {@code n} is outside {@code [1, toAvg.columns()]}
     */
    public static INDArray movingAverage(INDArray toAvg, int n) {
        int columns = toAvg.columns();
        if (n < 1 || n > columns) {
            throw new IllegalArgumentException("Invalid moving average window size: n = " + n
                    + " (must be in range [1, " + columns + "])");
        }
        INDArray ret = Nd4j.cumsum(toAvg);
        INDArrayIndex[] ends = new INDArrayIndex[] {NDArrayIndex.interval(n, columns)};
        INDArrayIndex[] begins = new INDArrayIndex[] {NDArrayIndex.interval(0, columns - n, false)};
        INDArrayIndex[] nMinusOne = new INDArrayIndex[] {NDArrayIndex.interval(n - 1, columns)};
        // ret[i] = cumsum[i] - cumsum[i - n] for i >= n, i.e. the sum of the window ending at i.
        // ret[n - 1] already equals the sum of the first window, so it needs no adjustment.
        ret.put(ends, ret.get(ends).sub(ret.get(begins)));
        return ret.get(nMinusOne).divi(n);
    }

    /**
     * Reshapes a rank-2 mask into a column vector of shape {@code [length, 1]}, reading elements
     * in 'f' (column-major) order.
     *
     * @param timeSeriesMask rank-2 mask
     * @return column vector view/copy of the mask in 'f' order
     * @throws IllegalArgumentException if the mask is not rank 2
     */
    public static INDArray reshapeTimeSeriesMaskToVector(INDArray timeSeriesMask) {
        if (timeSeriesMask.rank() != 2) {
            throw new IllegalArgumentException("Cannot reshape mask: rank is not 2");
        }
        // Ensure 'f' ordering before performing the 'f'-order reshape.
        if (timeSeriesMask.ordering() != 'f') {
            timeSeriesMask = timeSeriesMask.dup('f');
        }
        return timeSeriesMask.reshape('f', new int[] {timeSeriesMask.length(), 1});
    }

    /**
     * Inverse of {@link #reshapeTimeSeriesMaskToVector}: reshapes a mask vector into a
     * {@code [minibatchSize, length / minibatchSize]} array using 'f' order.
     *
     * @param timeSeriesMaskAsVector mask as a vector
     * @param minibatchSize          number of rows of the result; must be positive and divide the
     *                               vector length
     * @return the reshaped rank-2 mask
     * @throws IllegalArgumentException if the input is not a vector, {@code minibatchSize} is not
     *                                  positive, or the length is not divisible by it
     */
    public static INDArray reshapeVectorToTimeSeriesMask(INDArray timeSeriesMaskAsVector, int minibatchSize) {
        if (!timeSeriesMaskAsVector.isVector()) {
            throw new IllegalArgumentException("Cannot reshape mask: expected vector");
        }
        if (minibatchSize <= 0) {
            throw new IllegalArgumentException("Cannot reshape mask: minibatch size must be positive (got "
                    + minibatchSize + ")");
        }
        int length = timeSeriesMaskAsVector.length();
        if (length % minibatchSize != 0) {
            throw new IllegalArgumentException("Cannot reshape mask: vector length " + length
                    + " is not divisible by minibatch size " + minibatchSize);
        }
        int timeSeriesLength = length / minibatchSize;
        return timeSeriesMaskAsVector.reshape('f', new int[] {minibatchSize, timeSeriesLength});
    }

    /**
     * Reshapes a rank-3 per-output mask to rank 2 using the same mapping as
     * {@link #reshape3dTo2d}.
     *
     * @param perOutputTimeSeriesMask rank-3 mask
     * @return the rank-2 mask
     * @throws IllegalArgumentException if the mask is not rank 3
     */
    public static INDArray reshapePerOutputTimeSeriesMaskTo2d(INDArray perOutputTimeSeriesMask) {
        if (perOutputTimeSeriesMask.rank() != 3) {
            throw new IllegalArgumentException("Cannot reshape per output mask: rank is not 3 (is: "
                    + perOutputTimeSeriesMask.rank() + ", shape = "
                    + Arrays.toString(perOutputTimeSeriesMask.shape()) + ")");
        }
        return reshape3dTo2d(perOutputTimeSeriesMask);
    }

    /**
     * Reshapes a rank-3 array of shape {@code [a, b, c]} into a rank-2 array of shape
     * {@code [a * c, b]}.
     *
     * @param in rank-3 input
     * @return the rank-2 result
     * @throws IllegalArgumentException if {@code in} is not rank 3
     */
    public static INDArray reshape3dTo2d(INDArray in) {
        if (in.rank() != 3) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 3");
        }
        int[] shape = in.shape();
        if (shape[0] == 1) {
            // Special case: a single slice along dim 0; take it and transpose it to [c, b].
            return in.tensorAlongDimension(0, new int[] {1, 2}).permutei(new int[] {1, 0});
        }
        if (shape[2] == 1) {
            // Special case: trailing dimension of size 1; handled via a single tensor-along-dimension.
            return in.tensorAlongDimension(0, new int[] {1, 0});
        }
        // Move dim 2 next to dim 0 so the 'f'-order reshape groups them into the first output dim.
        INDArray permuted = in.permute(new int[] {0, 2, 1});
        return permuted.reshape('f', shape[0] * shape[2], shape[1]);
    }

    /**
     * Inverse of {@link #reshape3dTo2d}: reshapes a rank-2 array of shape {@code [rows, cols]}
     * into a rank-3 array of shape {@code [miniBatchSize, cols, rows / miniBatchSize]}.
     *
     * @param in            rank-2 input
     * @param miniBatchSize size of the first output dimension; must be positive and divide the
     *                      number of rows
     * @return the rank-3 result (a permuted view of the reshaped array)
     * @throws IllegalArgumentException if {@code in} is not rank 2, {@code miniBatchSize} is not
     *                                  positive, or the row count is not divisible by it
     */
    public static INDArray reshape2dTo3d(INDArray in, int miniBatchSize) {
        if (in.rank() != 2) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2");
        }
        if (miniBatchSize <= 0) {
            throw new IllegalArgumentException("Invalid minibatch size: " + miniBatchSize + " (must be positive)");
        }
        int[] shape = in.shape();
        if (shape[0] % miniBatchSize != 0) {
            throw new IllegalArgumentException("Invalid input: number of rows " + shape[0]
                    + " is not divisible by minibatch size " + miniBatchSize);
        }
        // The reshape below uses 'f' order, so make sure the source buffer is 'f' ordered first.
        if (in.ordering() != 'f') {
            in = Shape.toOffsetZeroCopy(in, 'f');
        }
        INDArray reshaped = in.reshape('f', new int[] {miniBatchSize, shape[0] / miniBatchSize, shape[1]});
        return reshaped.permute(new int[] {0, 2, 1});
    }
}

