shape_base.pyi 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. from collections.abc import Callable, Sequence
  2. from typing import TypeVar, Any, overload, SupportsIndex, Protocol
  3. from numpy import (
  4. generic,
  5. integer,
  6. ufunc,
  7. bool_,
  8. unsignedinteger,
  9. signedinteger,
  10. floating,
  11. complexfloating,
  12. object_,
  13. )
  14. from numpy._typing import (
  15. ArrayLike,
  16. NDArray,
  17. _ShapeLike,
  18. _ArrayLike,
  19. _ArrayLikeBool_co,
  20. _ArrayLikeUInt_co,
  21. _ArrayLikeInt_co,
  22. _ArrayLikeFloat_co,
  23. _ArrayLikeComplex_co,
  24. _ArrayLikeObject_co,
  25. )
  26. from numpy.core.shape_base import vstack
  27. _SCT = TypeVar("_SCT", bound=generic)
  28. # The signatures of `__array_wrap__` and `__array_prepare__` are the same;
  29. # give them unique names for the sake of clarity
  30. class _ArrayWrap(Protocol):
  31. def __call__(
  32. self,
  33. array: NDArray[Any],
  34. context: None | tuple[ufunc, tuple[Any, ...], int] = ...,
  35. /,
  36. ) -> Any: ...
  37. class _ArrayPrepare(Protocol):
  38. def __call__(
  39. self,
  40. array: NDArray[Any],
  41. context: None | tuple[ufunc, tuple[Any, ...], int] = ...,
  42. /,
  43. ) -> Any: ...
  44. class _SupportsArrayWrap(Protocol):
  45. @property
  46. def __array_wrap__(self) -> _ArrayWrap: ...
  47. class _SupportsArrayPrepare(Protocol):
  48. @property
  49. def __array_prepare__(self) -> _ArrayPrepare: ...
  50. __all__: list[str]
  51. row_stack = vstack
  52. def take_along_axis(
  53. arr: _SCT | NDArray[_SCT],
  54. indices: NDArray[integer[Any]],
  55. axis: None | int,
  56. ) -> NDArray[_SCT]: ...
  57. def put_along_axis(
  58. arr: NDArray[_SCT],
  59. indices: NDArray[integer[Any]],
  60. values: ArrayLike,
  61. axis: None | int,
  62. ) -> None: ...
  63. # TODO: Use PEP 612 `ParamSpec` once mypy supports `Concatenate`
  64. # xref python/mypy#8645
  65. @overload
  66. def apply_along_axis(
  67. func1d: Callable[..., _ArrayLike[_SCT]],
  68. axis: SupportsIndex,
  69. arr: ArrayLike,
  70. *args: Any,
  71. **kwargs: Any,
  72. ) -> NDArray[_SCT]: ...
  73. @overload
  74. def apply_along_axis(
  75. func1d: Callable[..., ArrayLike],
  76. axis: SupportsIndex,
  77. arr: ArrayLike,
  78. *args: Any,
  79. **kwargs: Any,
  80. ) -> NDArray[Any]: ...
  81. def apply_over_axes(
  82. func: Callable[[NDArray[Any], int], NDArray[_SCT]],
  83. a: ArrayLike,
  84. axes: int | Sequence[int],
  85. ) -> NDArray[_SCT]: ...
  86. @overload
  87. def expand_dims(
  88. a: _ArrayLike[_SCT],
  89. axis: _ShapeLike,
  90. ) -> NDArray[_SCT]: ...
  91. @overload
  92. def expand_dims(
  93. a: ArrayLike,
  94. axis: _ShapeLike,
  95. ) -> NDArray[Any]: ...
  96. @overload
  97. def column_stack(tup: Sequence[_ArrayLike[_SCT]]) -> NDArray[_SCT]: ...
  98. @overload
  99. def column_stack(tup: Sequence[ArrayLike]) -> NDArray[Any]: ...
  100. @overload
  101. def dstack(tup: Sequence[_ArrayLike[_SCT]]) -> NDArray[_SCT]: ...
  102. @overload
  103. def dstack(tup: Sequence[ArrayLike]) -> NDArray[Any]: ...
  104. @overload
  105. def array_split(
  106. ary: _ArrayLike[_SCT],
  107. indices_or_sections: _ShapeLike,
  108. axis: SupportsIndex = ...,
  109. ) -> list[NDArray[_SCT]]: ...
  110. @overload
  111. def array_split(
  112. ary: ArrayLike,
  113. indices_or_sections: _ShapeLike,
  114. axis: SupportsIndex = ...,
  115. ) -> list[NDArray[Any]]: ...
  116. @overload
  117. def split(
  118. ary: _ArrayLike[_SCT],
  119. indices_or_sections: _ShapeLike,
  120. axis: SupportsIndex = ...,
  121. ) -> list[NDArray[_SCT]]: ...
  122. @overload
  123. def split(
  124. ary: ArrayLike,
  125. indices_or_sections: _ShapeLike,
  126. axis: SupportsIndex = ...,
  127. ) -> list[NDArray[Any]]: ...
  128. @overload
  129. def hsplit(
  130. ary: _ArrayLike[_SCT],
  131. indices_or_sections: _ShapeLike,
  132. ) -> list[NDArray[_SCT]]: ...
  133. @overload
  134. def hsplit(
  135. ary: ArrayLike,
  136. indices_or_sections: _ShapeLike,
  137. ) -> list[NDArray[Any]]: ...
  138. @overload
  139. def vsplit(
  140. ary: _ArrayLike[_SCT],
  141. indices_or_sections: _ShapeLike,
  142. ) -> list[NDArray[_SCT]]: ...
  143. @overload
  144. def vsplit(
  145. ary: ArrayLike,
  146. indices_or_sections: _ShapeLike,
  147. ) -> list[NDArray[Any]]: ...
  148. @overload
  149. def dsplit(
  150. ary: _ArrayLike[_SCT],
  151. indices_or_sections: _ShapeLike,
  152. ) -> list[NDArray[_SCT]]: ...
  153. @overload
  154. def dsplit(
  155. ary: ArrayLike,
  156. indices_or_sections: _ShapeLike,
  157. ) -> list[NDArray[Any]]: ...
  158. @overload
  159. def get_array_prepare(*args: _SupportsArrayPrepare) -> _ArrayPrepare: ...
  160. @overload
  161. def get_array_prepare(*args: object) -> None | _ArrayPrepare: ...
  162. @overload
  163. def get_array_wrap(*args: _SupportsArrayWrap) -> _ArrayWrap: ...
  164. @overload
  165. def get_array_wrap(*args: object) -> None | _ArrayWrap: ...
  166. @overload
  167. def kron(a: _ArrayLikeBool_co, b: _ArrayLikeBool_co) -> NDArray[bool_]: ... # type: ignore[misc]
  168. @overload
  169. def kron(a: _ArrayLikeUInt_co, b: _ArrayLikeUInt_co) -> NDArray[unsignedinteger[Any]]: ... # type: ignore[misc]
  170. @overload
  171. def kron(a: _ArrayLikeInt_co, b: _ArrayLikeInt_co) -> NDArray[signedinteger[Any]]: ... # type: ignore[misc]
  172. @overload
  173. def kron(a: _ArrayLikeFloat_co, b: _ArrayLikeFloat_co) -> NDArray[floating[Any]]: ... # type: ignore[misc]
  174. @overload
  175. def kron(a: _ArrayLikeComplex_co, b: _ArrayLikeComplex_co) -> NDArray[complexfloating[Any, Any]]: ...
  176. @overload
  177. def kron(a: _ArrayLikeObject_co, b: Any) -> NDArray[object_]: ...
  178. @overload
  179. def kron(a: Any, b: _ArrayLikeObject_co) -> NDArray[object_]: ...
  180. @overload
  181. def tile(
  182. A: _ArrayLike[_SCT],
  183. reps: int | Sequence[int],
  184. ) -> NDArray[_SCT]: ...
  185. @overload
  186. def tile(
  187. A: ArrayLike,
  188. reps: int | Sequence[int],
  189. ) -> NDArray[Any]: ...