test_runtime.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. """Test the runtime usage of `numpy.typing`."""
  2. from __future__ import annotations
  3. import sys
  4. from typing import (
  5. get_type_hints,
  6. Union,
  7. NamedTuple,
  8. get_args,
  9. get_origin,
  10. Any,
  11. )
  12. import pytest
  13. import numpy as np
  14. import numpy.typing as npt
  15. import numpy._typing as _npt
  16. class TypeTup(NamedTuple):
  17. typ: type
  18. args: tuple[type, ...]
  19. origin: None | type
  20. if sys.version_info >= (3, 9):
  21. NDArrayTup = TypeTup(npt.NDArray, npt.NDArray.__args__, np.ndarray)
  22. else:
  23. NDArrayTup = TypeTup(npt.NDArray, (), None)
  24. TYPES = {
  25. "ArrayLike": TypeTup(npt.ArrayLike, npt.ArrayLike.__args__, Union),
  26. "DTypeLike": TypeTup(npt.DTypeLike, npt.DTypeLike.__args__, Union),
  27. "NBitBase": TypeTup(npt.NBitBase, (), None),
  28. "NDArray": NDArrayTup,
  29. }
  30. @pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys())
  31. def test_get_args(name: type, tup: TypeTup) -> None:
  32. """Test `typing.get_args`."""
  33. typ, ref = tup.typ, tup.args
  34. out = get_args(typ)
  35. assert out == ref
  36. @pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys())
  37. def test_get_origin(name: type, tup: TypeTup) -> None:
  38. """Test `typing.get_origin`."""
  39. typ, ref = tup.typ, tup.origin
  40. out = get_origin(typ)
  41. assert out == ref
  42. @pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys())
  43. def test_get_type_hints(name: type, tup: TypeTup) -> None:
  44. """Test `typing.get_type_hints`."""
  45. typ = tup.typ
  46. # Explicitly set `__annotations__` in order to circumvent the
  47. # stringification performed by `from __future__ import annotations`
  48. def func(a): pass
  49. func.__annotations__ = {"a": typ, "return": None}
  50. out = get_type_hints(func)
  51. ref = {"a": typ, "return": type(None)}
  52. assert out == ref
  53. @pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys())
  54. def test_get_type_hints_str(name: type, tup: TypeTup) -> None:
  55. """Test `typing.get_type_hints` with string-representation of types."""
  56. typ_str, typ = f"npt.{name}", tup.typ
  57. # Explicitly set `__annotations__` in order to circumvent the
  58. # stringification performed by `from __future__ import annotations`
  59. def func(a): pass
  60. func.__annotations__ = {"a": typ_str, "return": None}
  61. out = get_type_hints(func)
  62. ref = {"a": typ, "return": type(None)}
  63. assert out == ref
  64. def test_keys() -> None:
  65. """Test that ``TYPES.keys()`` and ``numpy.typing.__all__`` are synced."""
  66. keys = TYPES.keys()
  67. ref = set(npt.__all__)
  68. assert keys == ref
  69. PROTOCOLS: dict[str, tuple[type[Any], object]] = {
  70. "_SupportsDType": (_npt._SupportsDType, np.int64(1)),
  71. "_SupportsArray": (_npt._SupportsArray, np.arange(10)),
  72. "_SupportsArrayFunc": (_npt._SupportsArrayFunc, np.arange(10)),
  73. "_NestedSequence": (_npt._NestedSequence, [1]),
  74. }
  75. @pytest.mark.parametrize("cls,obj", PROTOCOLS.values(), ids=PROTOCOLS.keys())
  76. class TestRuntimeProtocol:
  77. def test_isinstance(self, cls: type[Any], obj: object) -> None:
  78. assert isinstance(obj, cls)
  79. assert not isinstance(None, cls)
  80. def test_issubclass(self, cls: type[Any], obj: object) -> None:
  81. if cls is _npt._SupportsDType:
  82. pytest.xfail(
  83. "Protocols with non-method members don't support issubclass()"
  84. )
  85. assert issubclass(type(obj), cls)
  86. assert not issubclass(type(None), cls)