_generic_alias.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. from __future__ import annotations
  2. import sys
  3. import types
  4. from collections.abc import Generator, Iterable, Iterator
  5. from typing import (
  6. Any,
  7. ClassVar,
  8. NoReturn,
  9. TypeVar,
  10. TYPE_CHECKING,
  11. )
  12. import numpy as np
  13. __all__ = ["_GenericAlias", "NDArray"]
  14. _T = TypeVar("_T", bound="_GenericAlias")
  15. def _to_str(obj: object) -> str:
  16. """Helper function for `_GenericAlias.__repr__`."""
  17. if obj is Ellipsis:
  18. return '...'
  19. elif isinstance(obj, type) and not isinstance(obj, _GENERIC_ALIAS_TYPE):
  20. if obj.__module__ == 'builtins':
  21. return obj.__qualname__
  22. else:
  23. return f'{obj.__module__}.{obj.__qualname__}'
  24. else:
  25. return repr(obj)
  26. def _parse_parameters(args: Iterable[Any]) -> Generator[TypeVar, None, None]:
  27. """Search for all typevars and typevar-containing objects in `args`.
  28. Helper function for `_GenericAlias.__init__`.
  29. """
  30. for i in args:
  31. if hasattr(i, "__parameters__"):
  32. yield from i.__parameters__
  33. elif isinstance(i, TypeVar):
  34. yield i
  35. def _reconstruct_alias(alias: _T, parameters: Iterator[TypeVar]) -> _T:
  36. """Recursively replace all typevars with those from `parameters`.
  37. Helper function for `_GenericAlias.__getitem__`.
  38. """
  39. args = []
  40. for i in alias.__args__:
  41. if isinstance(i, TypeVar):
  42. value: Any = next(parameters)
  43. elif isinstance(i, _GenericAlias):
  44. value = _reconstruct_alias(i, parameters)
  45. elif hasattr(i, "__parameters__"):
  46. prm_tup = tuple(next(parameters) for _ in i.__parameters__)
  47. value = i[prm_tup]
  48. else:
  49. value = i
  50. args.append(value)
  51. cls = type(alias)
  52. return cls(alias.__origin__, tuple(args), alias.__unpacked__)
  53. class _GenericAlias:
  54. """A python-based backport of the `types.GenericAlias` class.
  55. E.g. for ``t = list[int]``, ``t.__origin__`` is ``list`` and
  56. ``t.__args__`` is ``(int,)``.
  57. See Also
  58. --------
  59. :pep:`585`
  60. The PEP responsible for introducing `types.GenericAlias`.
  61. """
  62. __slots__ = (
  63. "__weakref__",
  64. "_origin",
  65. "_args",
  66. "_parameters",
  67. "_hash",
  68. "_starred",
  69. )
  70. @property
  71. def __origin__(self) -> type:
  72. return super().__getattribute__("_origin")
  73. @property
  74. def __args__(self) -> tuple[object, ...]:
  75. return super().__getattribute__("_args")
  76. @property
  77. def __parameters__(self) -> tuple[TypeVar, ...]:
  78. """Type variables in the ``GenericAlias``."""
  79. return super().__getattribute__("_parameters")
  80. @property
  81. def __unpacked__(self) -> bool:
  82. return super().__getattribute__("_starred")
  83. @property
  84. def __typing_unpacked_tuple_args__(self) -> tuple[object, ...] | None:
  85. # NOTE: This should return `__args__` if `__origin__` is a tuple,
  86. # which should never be the case with how `_GenericAlias` is used
  87. # within numpy
  88. return None
  89. def __init__(
  90. self,
  91. origin: type,
  92. args: object | tuple[object, ...],
  93. starred: bool = False,
  94. ) -> None:
  95. self._origin = origin
  96. self._args = args if isinstance(args, tuple) else (args,)
  97. self._parameters = tuple(_parse_parameters(self.__args__))
  98. self._starred = starred
  99. @property
  100. def __call__(self) -> type[Any]:
  101. return self.__origin__
  102. def __reduce__(self: _T) -> tuple[
  103. type[_T],
  104. tuple[type[Any], tuple[object, ...], bool],
  105. ]:
  106. cls = type(self)
  107. return cls, (self.__origin__, self.__args__, self.__unpacked__)
  108. def __mro_entries__(self, bases: Iterable[object]) -> tuple[type[Any]]:
  109. return (self.__origin__,)
  110. def __dir__(self) -> list[str]:
  111. """Implement ``dir(self)``."""
  112. cls = type(self)
  113. dir_origin = set(dir(self.__origin__))
  114. return sorted(cls._ATTR_EXCEPTIONS | dir_origin)
  115. def __hash__(self) -> int:
  116. """Return ``hash(self)``."""
  117. # Attempt to use the cached hash
  118. try:
  119. return super().__getattribute__("_hash")
  120. except AttributeError:
  121. self._hash: int = (
  122. hash(self.__origin__) ^
  123. hash(self.__args__) ^
  124. hash(self.__unpacked__)
  125. )
  126. return super().__getattribute__("_hash")
  127. def __instancecheck__(self, obj: object) -> NoReturn:
  128. """Check if an `obj` is an instance."""
  129. raise TypeError("isinstance() argument 2 cannot be a "
  130. "parameterized generic")
  131. def __subclasscheck__(self, cls: type) -> NoReturn:
  132. """Check if a `cls` is a subclass."""
  133. raise TypeError("issubclass() argument 2 cannot be a "
  134. "parameterized generic")
  135. def __repr__(self) -> str:
  136. """Return ``repr(self)``."""
  137. args = ", ".join(_to_str(i) for i in self.__args__)
  138. origin = _to_str(self.__origin__)
  139. prefix = "*" if self.__unpacked__ else ""
  140. return f"{prefix}{origin}[{args}]"
  141. def __getitem__(self: _T, key: object | tuple[object, ...]) -> _T:
  142. """Return ``self[key]``."""
  143. key_tup = key if isinstance(key, tuple) else (key,)
  144. if len(self.__parameters__) == 0:
  145. raise TypeError(f"There are no type variables left in {self}")
  146. elif len(key_tup) > len(self.__parameters__):
  147. raise TypeError(f"Too many arguments for {self}")
  148. elif len(key_tup) < len(self.__parameters__):
  149. raise TypeError(f"Too few arguments for {self}")
  150. key_iter = iter(key_tup)
  151. return _reconstruct_alias(self, key_iter)
  152. def __eq__(self, value: object) -> bool:
  153. """Return ``self == value``."""
  154. if not isinstance(value, _GENERIC_ALIAS_TYPE):
  155. return NotImplemented
  156. return (
  157. self.__origin__ == value.__origin__ and
  158. self.__args__ == value.__args__ and
  159. self.__unpacked__ == getattr(
  160. value, "__unpacked__", self.__unpacked__
  161. )
  162. )
  163. def __iter__(self: _T) -> Generator[_T, None, None]:
  164. """Return ``iter(self)``."""
  165. cls = type(self)
  166. yield cls(self.__origin__, self.__args__, True)
  167. _ATTR_EXCEPTIONS: ClassVar[frozenset[str]] = frozenset({
  168. "__origin__",
  169. "__args__",
  170. "__parameters__",
  171. "__mro_entries__",
  172. "__reduce__",
  173. "__reduce_ex__",
  174. "__copy__",
  175. "__deepcopy__",
  176. "__unpacked__",
  177. "__typing_unpacked_tuple_args__",
  178. "__class__",
  179. })
  180. def __getattribute__(self, name: str) -> Any:
  181. """Return ``getattr(self, name)``."""
  182. # Pull the attribute from `__origin__` unless its
  183. # name is in `_ATTR_EXCEPTIONS`
  184. cls = type(self)
  185. if name in cls._ATTR_EXCEPTIONS:
  186. return super().__getattribute__(name)
  187. return getattr(self.__origin__, name)
  188. # See `_GenericAlias.__eq__`
  189. if sys.version_info >= (3, 9):
  190. _GENERIC_ALIAS_TYPE = (_GenericAlias, types.GenericAlias)
  191. else:
  192. _GENERIC_ALIAS_TYPE = (_GenericAlias,)
# Type variable for the scalar type carried by an array's dtype; bound to
# `np.generic` and declared covariant.
ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)

# `NDArray[T]` is `np.ndarray` parameterized with `Any` as its first
# parameter (presumably the shape -- confirm against `np.ndarray`'s
# generic signature) and `np.dtype[T]` as its dtype.
# For static type checkers and on Python >= 3.9 the numpy classes are
# subscripted directly; otherwise the `_GenericAlias` backport is used
# (presumably because runtime subscription is unavailable there -- TODO
# confirm). The exact statement form is kept so type checkers can follow
# the `TYPE_CHECKING` branch.
# NOTE(review): the first branch repeats `np.dtype[ScalarType]` instead of
# reusing `_DType`.
if TYPE_CHECKING or sys.version_info >= (3, 9):
    _DType = np.dtype[ScalarType]
    NDArray = np.ndarray[Any, np.dtype[ScalarType]]
else:
    _DType = _GenericAlias(np.dtype, (ScalarType,))
    NDArray = _GenericAlias(np.ndarray, (Any, _DType))