/*
 Copyright 2019 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 =======================================================================
 */
package org.tensorflow.ndarray.index;

import org.tensorflow.ndarray.IllegalRankException;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;

/**
 * Helper class for instantiating {@link Index} objects.
 */
public final class Indices {

  /**
   * A coordinate that selects a specific element on a given dimension.
   *
   * <p>When this index is applied to a given dimension, the dimension is resolved as a
   * single element and therefore is excluded from the computation of the rank.
   *
   * <p>For example, given a 3D matrix on the axis [x, y, z], if
   * {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its number of elements is
   * {@code x.numElements()}
   *
   * @param coord coordinate of the element on the indexed axis
   * @return index
   */
  public static Index at(long coord) {
    return new At(coord, false);
  }

  /**
   * A coordinate that selects a specific element on a given dimension.
   *
   * <p>This is equivalent to call {@link #at(long)} but where the value of the coordinate is
   * provided by an N-dimensional array.
   *
   * @param coord scalar indicating the coordinate of the element on the indexed axis
   * @return index
   * @throws IllegalRankException if {@code coord} is not a scalar (rank 0)
   */
  public static Index at(NdArray<? extends Number> coord) {
    if (coord.rank() > 0) {
      throw new IllegalRankException("Only scalars are accepted as a value index");
    }
    return new At(coord.getObject().longValue(), false);
  }

  /**
   * A coordinate that selects a specific element on a given dimension.
   *
   * <p>When this index is applied to a given dimension, the dimension is resolved as a
   * single element and therefore, if {@code keepDim} is false, is excluded from the computation of the rank. If {@code}
   * keepDim is true, the dimension is collapsed down to one element.
   *
   * <p>For example, given a 3D matrix on the axis [x, y, z], if
   * {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its number of elements is
   * {@code x.numElements()}
   *
   * @param coord coordinate of the element on the indexed axis
   * @param keepDim whether to remove the dimension.
   * @return index
   */
  public static Index at(long coord, boolean keepDim) {
    return new At(coord, keepDim);
  }

  /**
   * A coordinate that selects a specific element on a given dimension.
   *
   * <p>This is equivalent to call {@link #at(long, boolean)} but where the value of the coordinate is
   * provided by an N-dimensional array.
   * <p>
   * If {@code} keepDim is true, the dimension is collapsed down to one element instead of being removed.
   *
   * @param coord scalar indicating the coordinate of the element on the indexed axis
   * @param keepDim whether to remove the dimension.
   * @return index
   * @throws IllegalRankException if {@code coord} is not a scalar (rank 0)
   */
  public static Index at(NdArray<? extends Number> coord, boolean keepDim) {
    if (coord.rank() > 0) {
      throw new IllegalRankException("Only scalars are accepted as a value index");
    }
    return new At(coord.getObject().longValue(), keepDim);
  }

  /**
   * An index that returns all elements of a dimension in the original order.
   *
   * <p>Applying this index to a given dimension will return the original dimension
   * directly.
   *
   * <p>For example, given a vector with {@code n} elements, {@code all()} returns
   * x<sub>0</sub>, x<sub>1</sub>, ..., x<sub>n-1</sub>
   *
   * @return index
   */
  public static Index all() {
    return All.INSTANCE;
  }

  /**
   * An index that returns only specific elements on a given dimension.
   *
   * <p>For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > 10},
   * {@code seq(8, 0, 3)} returns x<sub>8</sub>, x<sub>0</sub>, x<sub>3</sub>
   *
   * @param coords coordinates of the elements in the sequence
   * @return index
   */
  public static Index seq(long... coords) {
    if (coords == null) {
      throw new IllegalArgumentException();
    }
    return new Sequence(NdArrays.wrap(Shape.of(coords.length), DataBuffers.of(coords, true, false)));
  }

  /**
   * An index that returns only specific elements on a given dimension.
   *
   * <p>This is equivalent to {@link #seq(long...)} but where the coordinates of the elements in
   * the sequence are provided by an N-dimensional array.
   *
   * @param coords vector of coordinates of the elements in the sequence
   * @return index
   * @throws IllegalRankException if {@code coords} is not a vector (rank 1)
   */
  public static Index seq(NdArray<? extends Number> coords) {
    if (coords.rank() != 1) {
      throw new IllegalRankException("Only vectors are accepted as an element index");
    }
    return new Sequence(coords);
  }

  /**
   * An index that returns only elements found at an even position in the original dimension.
   *
   * <p>For example, given a vector with {@code n} elements on the {@code x} axis, and n is even,
   * {@code even()} returns x<sub>0</sub>, x<sub>2</sub>, ..., x<sub>n-2</sub>
   *
   * @return index
   */
  public static Index even() {
    return step(2);
  }

  /**
   * An index that returns only elements found at an odd position in the original dimension.
   *
   * <p>For example, given a vector with {@code n} elements on the {@code x} axis, and n is even,
   * {@code odd()} returns x<sub>1</sub>, x<sub>3</sub>, ..., x<sub>n-1</sub>
   *
   * @return index
   */
  public static Index odd() {
    return sliceFrom(1, 2);
  }

  /**
   * An index that skips a fixed amount of coordinates between each values returned.
   *
   * <p>For example, given a vector with {@code n} elements on the {@code x} axis,
   * {@code step(k)} returns x<sub>0</sub>, x<sub>k</sub>, x<sub>k*2</sub>, ...
   *
   * @param stride the number of elements between each steps
   * @return index
   */
  public static Index step(long stride) {
    return new Step(stride);
  }

  /**
   * An index that returns only elements on a given dimension starting at a specific coordinate.
   *
   * <p>For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k},
   * {@code from(k)} returns x<sub>k</sub>, x<sub>k+1</sub>, ..., x<sub>n-1</sub>
   *
   * @param start coordinate of the first element of the sequence
   * @return index
   */
  public static Index sliceFrom(long start) {
    return sliceFrom(start, 1);
  }

  /**
   * An index that returns only elements on a given dimension starting at a specific coordinate, using the given
   * stride.
   *
   * <p>For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k},
   * {@code from(k)} returns x<sub>k</sub>, x<sub>k+1</sub>, ..., x<sub>n-1</sub>
   *
   * @param start coordinate of the first element of the sequence
   * @param stride the stride to use
   * @return index
   * @see #slice(long, long, long)
   */
  public static Index sliceFrom(long start, long stride) {
    return new SliceFrom(start, stride);
  }

  /**
   * An index that returns only elements on a given dimension up to a specific coordinate.
   *
   * <p>For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k},
   * {@code to(k)} returns x<sub>0</sub>, x<sub>1</sub>, ..., x<sub>k</sub>
   *
   * @param end coordinate of the last element of the sequence (exclusive)
   * @return index
   */
  public static Index sliceTo(long end) {
    return sliceTo(end, 1);
  }

  /**
   * An index that returns only elements on a given dimension up to a specific coordinate, using the given stride.
   *
   * <p>For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k},
   * {@code to(k)} returns x<sub>0</sub>, x<sub>1</sub>, ..., x<sub>k</sub>
   *
   * @param end coordinate of the last element of the sequence (exclusive)
   * @param stride the stride to use
   * @return index
   * @see #slice(long, long, long)
   */
  public static Index sliceTo(long end, long stride) {
    return new SliceTo(end, stride);
  }

  /**
   * An index that returns only elements on a given dimension between two coordinates.
   *
   * <p>For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k > j},
   * {@code range(j, k)} returns x<sub>j</sub>, x<sub>j+1</sub>, ..., x<sub>k</sub>
   *
   * @param start coordinate of the first element of the sequence
   * @param end coordinate of the last element of the sequence (exclusive)
   * @return index
   */
  public static Index range(long start, long end) {
    return slice(start, end);
  }

  /**
   * An index that returns only elements on a given dimension between two coordinates.
   *
   * <p>For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k > j},
   * {@code range(j, k)} returns x<sub>j</sub>, x<sub>j+1</sub>, ..., x<sub>k</sub>
   *
   * @return index
   */
  public static Index flip() {
    return slice(null, null, -1);
  }

  /**
   * An index that returns elements according to an hyperslab defined by {@code start}, {@code stride}, {@code count},
   * {@code block}. See {@link Hyperslab}.
   *
   * @param start Starting location for the hyperslab.
   * @param stride The number of elements to separate each element or block to be selected.
   * @param count The number of elements or blocks to select along the dimension.
   * @param block The size of the block selected from the dimension.
   * @return index
   */
  public static Index hyperslab(long start, long stride, long count, long block) {
    return new Hyperslab(start, stride, count, block);
  }

  /**
   * An index that inserts a new dimension of size 1 into the resulting array.
   *
   * @return index
   */
  public static Index newAxis() {
    return NewAxis.INSTANCE;
  }

  /**
   * An index that expands to fill all available source dimensions. Works the same as Python's {@code ...}.
   *
   * @return index
   */
  public static Index ellipsis() {
    return Ellipsis.INSTANCE;
  }

  /**
   * An index that returns elements between {@code start} and {@code end}. If {@code start} or {@code end} is {@code
   * null}, starts or ends at the beginning or the end, respectively.
   * <p>
   * Analogous to Python's {@code :} slice syntax.
   *
   * @return index
   */
  public static Index slice(long start, long end) {
    return slice(start, end, 1);
  }

  /**
   * An index that returns every {@code stride}-th element between {@code start} and {@code end}. If {@code start} or
   * {@code end} is {@code null}, starts or ends at the beginning or the end, respectively.
   * <p>
   * Analogous to Python's {@code :} slice syntax.
   *
   * @return index
   */
  public static Index slice(long start, long end, long stride) {
    return new Slice(start, end, stride);
  }

  /**
   * An index that returns elements between {@code start} and {@code end}. If {@code start} or {@code end} is {@code
   * null}, starts or ends at the beginning or the end, respectively.
   * <p>
   * Analogous to Python's {@code :} slice syntax.
   *
   * @return index
   */
  public static Index slice(Long start, Long end) {
    return slice(start, end, 1);
  }

  /**
   * An index that returns every {@code stride}-th element between {@code start} and {@code end}. If {@code start} or
   * {@code end} is {@code null}, starts or ends at the beginning or the end, respectively.
   * <p>
   * Analogous to Python's {@code :} slice syntax.
   *
   * @return index
   */
  public static Index slice(Long start, Long end, long stride) {
    if (start == null && end == null) {
      if (stride == 1) {
        return Indices.all();
      } else {
        return Indices.step(stride);
      }
    } else if (start == null) {
      return Indices.sliceTo(end, stride);
    } else if (end == null) {
      return Indices.sliceFrom(start, stride);
    }

    return slice(start.longValue(), end.longValue(), stride);
  }

}
