test_generic_alias.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. from __future__ import annotations
  2. import sys
  3. import copy
  4. import types
  5. import pickle
  6. import weakref
  7. from typing import TypeVar, Any, Union, Callable
  8. import pytest
  9. import numpy as np
  10. from numpy._typing._generic_alias import _GenericAlias
  11. from typing_extensions import Unpack
  12. ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
  13. T1 = TypeVar("T1")
  14. T2 = TypeVar("T2")
  15. DType = _GenericAlias(np.dtype, (ScalarType,))
  16. NDArray = _GenericAlias(np.ndarray, (Any, DType))
  17. # NOTE: The `npt._GenericAlias` *class* isn't quite stable on python >=3.11.
  18. # This is not a problem during runtime (as it's 3.8-exclusive), but we still
  19. # need it for the >=3.9 in order to verify its semantics match
  20. # `types.GenericAlias` replacement. xref numpy/numpy#21526
  21. if sys.version_info >= (3, 9):
  22. DType_ref = types.GenericAlias(np.dtype, (ScalarType,))
  23. NDArray_ref = types.GenericAlias(np.ndarray, (Any, DType_ref))
  24. FuncType = Callable[["_GenericAlias | types.GenericAlias"], Any]
  25. else:
  26. DType_ref = Any
  27. NDArray_ref = Any
  28. FuncType = Callable[["_GenericAlias"], Any]
  29. GETATTR_NAMES = sorted(set(dir(np.ndarray)) - _GenericAlias._ATTR_EXCEPTIONS)
  30. BUFFER = np.array([1], dtype=np.int64)
  31. BUFFER.setflags(write=False)
  32. def _get_subclass_mro(base: type) -> tuple[type, ...]:
  33. class Subclass(base): # type: ignore[misc,valid-type]
  34. pass
  35. return Subclass.__mro__[1:]
  36. class TestGenericAlias:
  37. """Tests for `numpy._typing._generic_alias._GenericAlias`."""
  38. @pytest.mark.parametrize("name,func", [
  39. ("__init__", lambda n: n),
  40. ("__init__", lambda n: _GenericAlias(np.ndarray, Any)),
  41. ("__init__", lambda n: _GenericAlias(np.ndarray, (Any,))),
  42. ("__init__", lambda n: _GenericAlias(np.ndarray, (Any, Any))),
  43. ("__init__", lambda n: _GenericAlias(np.ndarray, T1)),
  44. ("__init__", lambda n: _GenericAlias(np.ndarray, (T1,))),
  45. ("__init__", lambda n: _GenericAlias(np.ndarray, (T1, T2))),
  46. ("__origin__", lambda n: n.__origin__),
  47. ("__args__", lambda n: n.__args__),
  48. ("__parameters__", lambda n: n.__parameters__),
  49. ("__mro_entries__", lambda n: n.__mro_entries__([object])),
  50. ("__hash__", lambda n: hash(n)),
  51. ("__repr__", lambda n: repr(n)),
  52. ("__getitem__", lambda n: n[np.float64]),
  53. ("__getitem__", lambda n: n[ScalarType][np.float64]),
  54. ("__getitem__", lambda n: n[Union[np.int64, ScalarType]][np.float64]),
  55. ("__getitem__", lambda n: n[Union[T1, T2]][np.float32, np.float64]),
  56. ("__eq__", lambda n: n == n),
  57. ("__ne__", lambda n: n != np.ndarray),
  58. ("__call__", lambda n: n((1,), np.int64, BUFFER)),
  59. ("__call__", lambda n: n(shape=(1,), dtype=np.int64, buffer=BUFFER)),
  60. ("subclassing", lambda n: _get_subclass_mro(n)),
  61. ("pickle", lambda n: n == pickle.loads(pickle.dumps(n))),
  62. ])
  63. def test_pass(self, name: str, func: FuncType) -> None:
  64. """Compare `types.GenericAlias` with its numpy-based backport.
  65. Checker whether ``func`` runs as intended and that both `GenericAlias`
  66. and `_GenericAlias` return the same result.
  67. """
  68. value = func(NDArray)
  69. if sys.version_info >= (3, 9):
  70. value_ref = func(NDArray_ref)
  71. assert value == value_ref
  72. @pytest.mark.parametrize("name,func", [
  73. ("__copy__", lambda n: n == copy.copy(n)),
  74. ("__deepcopy__", lambda n: n == copy.deepcopy(n)),
  75. ])
  76. def test_copy(self, name: str, func: FuncType) -> None:
  77. value = func(NDArray)
  78. # xref bpo-45167
  79. GE_398 = (
  80. sys.version_info[:2] == (3, 9) and sys.version_info >= (3, 9, 8)
  81. )
  82. if GE_398 or sys.version_info >= (3, 10, 1):
  83. value_ref = func(NDArray_ref)
  84. assert value == value_ref
  85. def test_dir(self) -> None:
  86. value = dir(NDArray)
  87. if sys.version_info < (3, 9):
  88. return
  89. # A number attributes only exist in `types.GenericAlias` in >= 3.11
  90. if sys.version_info < (3, 11, 0, "beta", 3):
  91. value.remove("__typing_unpacked_tuple_args__")
  92. if sys.version_info < (3, 11, 0, "beta", 1):
  93. value.remove("__unpacked__")
  94. assert value == dir(NDArray_ref)
  95. @pytest.mark.parametrize("name,func,dev_version", [
  96. ("__iter__", lambda n: len(list(n)), ("beta", 1)),
  97. ("__iter__", lambda n: next(iter(n)), ("beta", 1)),
  98. ("__unpacked__", lambda n: n.__unpacked__, ("beta", 1)),
  99. ("Unpack", lambda n: Unpack[n], ("beta", 1)),
  100. # The right operand should now have `__unpacked__ = True`,
  101. # and they are thus now longer equivalent
  102. ("__ne__", lambda n: n != next(iter(n)), ("beta", 1)),
  103. # >= beta3
  104. ("__typing_unpacked_tuple_args__",
  105. lambda n: n.__typing_unpacked_tuple_args__, ("beta", 3)),
  106. # >= beta4
  107. ("__class__", lambda n: n.__class__ == type(n), ("beta", 4)),
  108. ])
  109. def test_py311_features(
  110. self,
  111. name: str,
  112. func: FuncType,
  113. dev_version: tuple[str, int],
  114. ) -> None:
  115. """Test Python 3.11 features."""
  116. value = func(NDArray)
  117. if sys.version_info >= (3, 11, 0, *dev_version):
  118. value_ref = func(NDArray_ref)
  119. assert value == value_ref
  120. def test_weakref(self) -> None:
  121. """Test ``__weakref__``."""
  122. value = weakref.ref(NDArray)()
  123. if sys.version_info >= (3, 9, 1): # xref bpo-42332
  124. value_ref = weakref.ref(NDArray_ref)()
  125. assert value == value_ref
  126. @pytest.mark.parametrize("name", GETATTR_NAMES)
  127. def test_getattr(self, name: str) -> None:
  128. """Test that `getattr` wraps around the underlying type,
  129. aka ``__origin__``.
  130. """
  131. value = getattr(NDArray, name)
  132. value_ref1 = getattr(np.ndarray, name)
  133. if sys.version_info >= (3, 9):
  134. value_ref2 = getattr(NDArray_ref, name)
  135. assert value == value_ref1 == value_ref2
  136. else:
  137. assert value == value_ref1
  138. @pytest.mark.parametrize("name,exc_type,func", [
  139. ("__getitem__", TypeError, lambda n: n[()]),
  140. ("__getitem__", TypeError, lambda n: n[Any, Any]),
  141. ("__getitem__", TypeError, lambda n: n[Any][Any]),
  142. ("isinstance", TypeError, lambda n: isinstance(np.array(1), n)),
  143. ("issublass", TypeError, lambda n: issubclass(np.ndarray, n)),
  144. ("setattr", AttributeError, lambda n: setattr(n, "__origin__", int)),
  145. ("setattr", AttributeError, lambda n: setattr(n, "test", int)),
  146. ("getattr", AttributeError, lambda n: getattr(n, "test")),
  147. ])
  148. def test_raise(
  149. self,
  150. name: str,
  151. exc_type: type[BaseException],
  152. func: FuncType,
  153. ) -> None:
  154. """Test operations that are supposed to raise."""
  155. with pytest.raises(exc_type):
  156. func(NDArray)
  157. if sys.version_info >= (3, 9):
  158. with pytest.raises(exc_type):
  159. func(NDArray_ref)