_manipulation_functions.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. from __future__ import annotations
  2. from ._array_object import Array
  3. from ._data_type_functions import result_type
  4. from typing import List, Optional, Tuple, Union
  5. import numpy as np
  6. # Note: the function name is different here
  7. def concat(
  8. arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0
  9. ) -> Array:
  10. """
  11. Array API compatible wrapper for :py:func:`np.concatenate <numpy.concatenate>`.
  12. See its docstring for more information.
  13. """
  14. # Note: Casting rules here are different from the np.concatenate default
  15. # (no for scalars with axis=None, no cross-kind casting)
  16. dtype = result_type(*arrays)
  17. arrays = tuple(a._array for a in arrays)
  18. return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype))
  19. def expand_dims(x: Array, /, *, axis: int) -> Array:
  20. """
  21. Array API compatible wrapper for :py:func:`np.expand_dims <numpy.expand_dims>`.
  22. See its docstring for more information.
  23. """
  24. return Array._new(np.expand_dims(x._array, axis))
  25. def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
  26. """
  27. Array API compatible wrapper for :py:func:`np.flip <numpy.flip>`.
  28. See its docstring for more information.
  29. """
  30. return Array._new(np.flip(x._array, axis=axis))
  31. # Note: The function name is different here (see also matrix_transpose).
  32. # Unlike transpose(), the axes argument is required.
  33. def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
  34. """
  35. Array API compatible wrapper for :py:func:`np.transpose <numpy.transpose>`.
  36. See its docstring for more information.
  37. """
  38. return Array._new(np.transpose(x._array, axes))
  39. # Note: the optional argument is called 'shape', not 'newshape'
  40. def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array:
  41. """
  42. Array API compatible wrapper for :py:func:`np.reshape <numpy.reshape>`.
  43. See its docstring for more information.
  44. """
  45. return Array._new(np.reshape(x._array, shape))
  46. def roll(
  47. x: Array,
  48. /,
  49. shift: Union[int, Tuple[int, ...]],
  50. *,
  51. axis: Optional[Union[int, Tuple[int, ...]]] = None,
  52. ) -> Array:
  53. """
  54. Array API compatible wrapper for :py:func:`np.roll <numpy.roll>`.
  55. See its docstring for more information.
  56. """
  57. return Array._new(np.roll(x._array, shift, axis=axis))
  58. def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
  59. """
  60. Array API compatible wrapper for :py:func:`np.squeeze <numpy.squeeze>`.
  61. See its docstring for more information.
  62. """
  63. return Array._new(np.squeeze(x._array, axis=axis))
  64. def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array:
  65. """
  66. Array API compatible wrapper for :py:func:`np.stack <numpy.stack>`.
  67. See its docstring for more information.
  68. """
  69. # Call result type here just to raise on disallowed type combinations
  70. result_type(*arrays)
  71. arrays = tuple(a._array for a in arrays)
  72. return Array._new(np.stack(arrays, axis=axis))