_array_like.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. from __future__ import annotations
# NOTE: Import `Sequence` from `typing` (rather than `collections.abc`) as it
# is needed for a type alias, not an annotation
  4. import sys
  5. from collections.abc import Collection, Callable
  6. from typing import Any, Sequence, Protocol, Union, TypeVar, runtime_checkable
  7. from numpy import (
  8. ndarray,
  9. dtype,
  10. generic,
  11. bool_,
  12. unsignedinteger,
  13. integer,
  14. floating,
  15. complexfloating,
  16. number,
  17. timedelta64,
  18. datetime64,
  19. object_,
  20. void,
  21. str_,
  22. bytes_,
  23. )
  24. from ._nested_sequence import _NestedSequence
  25. _T = TypeVar("_T")
  26. _ScalarType = TypeVar("_ScalarType", bound=generic)
  27. _DType = TypeVar("_DType", bound="dtype[Any]")
  28. _DType_co = TypeVar("_DType_co", covariant=True, bound="dtype[Any]")
  29. # The `_SupportsArray` protocol only cares about the default dtype
  30. # (i.e. `dtype=None` or no `dtype` parameter at all) of the to-be returned
  31. # array.
  32. # Concrete implementations of the protocol are responsible for adding
  33. # any and all remaining overloads
  34. @runtime_checkable
  35. class _SupportsArray(Protocol[_DType_co]):
  36. def __array__(self) -> ndarray[Any, _DType_co]: ...
  37. @runtime_checkable
  38. class _SupportsArrayFunc(Protocol):
  39. """A protocol class representing `~class.__array_function__`."""
  40. def __array_function__(
  41. self,
  42. func: Callable[..., Any],
  43. types: Collection[type[Any]],
  44. args: tuple[Any, ...],
  45. kwargs: dict[str, Any],
  46. ) -> object: ...
# A union of `_T` itself and `_T` wrapped in one to four levels of
# `Sequence`, i.e. a fixed-depth approximation of a nested sequence of `_T`.
# The depth is capped because of the limitation described in the TODO.
# TODO: Wait until mypy supports recursive objects in combination with typevars
_FiniteNestedSequence = Union[
    _T,
    Sequence[_T],
    Sequence[Sequence[_T]],
    Sequence[Sequence[Sequence[_T]]],
    Sequence[Sequence[Sequence[Sequence[_T]]]],
]
# A subset of `npt.ArrayLike` that can be parametrized w.r.t. `np.generic`:
# either an object implementing `__array__` that returns an array of
# `dtype[_ScalarType]`, or a `_NestedSequence` of such objects.
# NOTE(review): the dtype argument is a string forward reference, so at
# runtime `_ScalarType` is presumably not collected as a parameter of this
# alias; the parametrization is then only meaningful to static type checkers.
_ArrayLike = Union[
    _SupportsArray["dtype[_ScalarType]"],
    _NestedSequence[_SupportsArray["dtype[_ScalarType]"]],
]
# A union representing array-like objects; consists of two typevars:
# One representing types that can be parametrized w.r.t. `np.dtype`
# and another one for the rest.
#
# Subscript as `_DualArrayLike[<dtype type>, <other type>]`:
# - the first argument binds `_DType`: objects implementing `__array__`
#   with that dtype, and `_NestedSequence`s of them;
# - the second argument binds `_T`: plain objects of that type, and
#   `_NestedSequence`s of them.
# The parameter order (`_DType`, then `_T`) follows the order in which the
# typevars first appear below, so the member order must not change.
_DualArrayLike = Union[
    _SupportsArray[_DType],
    _NestedSequence[_SupportsArray[_DType]],
    _T,
    _NestedSequence[_T],
]
  69. # TODO: support buffer protocols once
  70. #
  71. # https://bugs.python.org/issue27501
  72. #
  73. # is resolved. See also the mypy issue:
  74. #
  75. # https://github.com/python/typing/issues/593
  76. if sys.version_info[:2] < (3, 9):
  77. ArrayLike = _DualArrayLike[
  78. dtype,
  79. Union[bool, int, float, complex, str, bytes],
  80. ]
  81. else:
  82. ArrayLike = _DualArrayLike[
  83. dtype[Any],
  84. Union[bool, int, float, complex, str, bytes],
  85. ]
# `ArrayLike<X>_co`: array-like objects that can be coerced into `X`
# given the casting rules `same_kind`
#
# Most aliases below are `_DualArrayLike[<dtype>, <builtin scalars>]`: the
# first argument is the dtype accepted from objects implementing
# `__array__`, the second the builtin Python types accepted directly.
# The dtype arguments are written as strings (forward references).

# Booleans: `bool_` dtypes and Python `bool`.
_ArrayLikeBool_co = _DualArrayLike[
    "dtype[bool_]",
    bool,
]
# Unsigned integers (plus booleans). Only `bool` is accepted on the builtin
# side. NOTE(review): presumably because a Python `int` may be negative.
_ArrayLikeUInt_co = _DualArrayLike[
    "dtype[Union[bool_, unsignedinteger[Any]]]",
    bool,
]
# Integers of any signedness (plus booleans).
_ArrayLikeInt_co = _DualArrayLike[
    "dtype[Union[bool_, integer[Any]]]",
    Union[bool, int],
]
# Floating point (plus booleans and integers).
_ArrayLikeFloat_co = _DualArrayLike[
    "dtype[Union[bool_, integer[Any], floating[Any]]]",
    Union[bool, int, float],
]
# Complex floating point (plus booleans, integers and floating point).
_ArrayLikeComplex_co = _DualArrayLike[
    "dtype[Union[bool_, integer[Any], floating[Any], complexfloating[Any, Any]]]",
    Union[bool, int, float, complex],
]
# Any `number` subtype (plus booleans).
_ArrayLikeNumber_co = _DualArrayLike[
    "dtype[Union[bool_, number[Any]]]",
    Union[bool, int, float, complex],
]
# `timedelta64` (plus booleans and integers).
_ArrayLikeTD64_co = _DualArrayLike[
    "dtype[Union[bool_, integer[Any], timedelta64]]",
    Union[bool, int],
]
# `datetime64`: only objects implementing `__array__` with this dtype, or
# `_NestedSequence`s of them; no builtin scalar type is accepted.
_ArrayLikeDT64_co = Union[
    _SupportsArray["dtype[datetime64]"],
    _NestedSequence[_SupportsArray["dtype[datetime64]"]],
]
# `object_`: same structure as the `datetime64` alias above.
_ArrayLikeObject_co = Union[
    _SupportsArray["dtype[object_]"],
    _NestedSequence[_SupportsArray["dtype[object_]"]],
]
# `void`: same structure as the `datetime64` alias above.
_ArrayLikeVoid_co = Union[
    _SupportsArray["dtype[void]"],
    _NestedSequence[_SupportsArray["dtype[void]"]],
]
# `str_` dtypes and Python `str`.
_ArrayLikeStr_co = _DualArrayLike[
    "dtype[str_]",
    str,
]
# `bytes_` dtypes and Python `bytes`.
_ArrayLikeBytes_co = _DualArrayLike[
    "dtype[bytes_]",
    bytes,
]
# Integers only: unlike `_ArrayLikeInt_co`, the dtype side excludes `bool_`
# (note the missing `_co` suffix).
_ArrayLikeInt = _DualArrayLike[
    "dtype[integer[Any]]",
    int,
]
  140. # Extra ArrayLike type so that pyright can deal with NDArray[Any]
  141. # Used as the first overload, should only match NDArray[Any],
  142. # not any actual types.
  143. # https://github.com/numpy/numpy/pull/22193
  144. class _UnknownType:
  145. ...
  146. _ArrayLikeUnknown = _DualArrayLike[
  147. "dtype[_UnknownType]",
  148. _UnknownType,
  149. ]