# test_typing.py
  1. from __future__ import annotations
  2. import importlib.util
  3. import itertools
  4. import os
  5. import re
  6. import shutil
  7. from collections import defaultdict
  8. from collections.abc import Iterator
  9. from typing import IO, TYPE_CHECKING
  10. import pytest
  11. import numpy as np
  12. import numpy.typing as npt
  13. from numpy.typing.mypy_plugin import (
  14. _PRECISION_DICT,
  15. _EXTENDED_PRECISION_LIST,
  16. _C_INTP,
  17. )
  18. try:
  19. from mypy import api
  20. except ImportError:
  21. NO_MYPY = True
  22. else:
  23. NO_MYPY = False
  24. if TYPE_CHECKING:
  25. # We need this as annotation, but it's located in a private namespace.
  26. # As a compromise, do *not* import it during runtime
  27. from _pytest.mark.structures import ParameterSet
  28. DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
  29. PASS_DIR = os.path.join(DATA_DIR, "pass")
  30. FAIL_DIR = os.path.join(DATA_DIR, "fail")
  31. REVEAL_DIR = os.path.join(DATA_DIR, "reveal")
  32. MISC_DIR = os.path.join(DATA_DIR, "misc")
  33. MYPY_INI = os.path.join(DATA_DIR, "mypy.ini")
  34. CACHE_DIR = os.path.join(DATA_DIR, ".mypy_cache")
  35. #: A dictionary with file names as keys and lists of the mypy stdout as values.
  36. #: To-be populated by `run_mypy`.
  37. OUTPUT_MYPY: dict[str, list[str]] = {}
  38. def _key_func(key: str) -> str:
  39. """Split at the first occurrence of the ``:`` character.
  40. Windows drive-letters (*e.g.* ``C:``) are ignored herein.
  41. """
  42. drive, tail = os.path.splitdrive(key)
  43. return os.path.join(drive, tail.split(":", 1)[0])
  44. def _strip_filename(msg: str) -> str:
  45. """Strip the filename from a mypy message."""
  46. _, tail = os.path.splitdrive(msg)
  47. return tail.split(":", 1)[-1]
  48. def strip_func(match: re.Match[str]) -> str:
  49. """`re.sub` helper function for stripping module names."""
  50. return match.groups()[1]
  51. @pytest.mark.slow
  52. @pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
  53. @pytest.fixture(scope="module", autouse=True)
  54. def run_mypy() -> None:
  55. """Clears the cache and run mypy before running any of the typing tests.
  56. The mypy results are cached in `OUTPUT_MYPY` for further use.
  57. The cache refresh can be skipped using
  58. NUMPY_TYPING_TEST_CLEAR_CACHE=0 pytest numpy/typing/tests
  59. """
  60. if (
  61. os.path.isdir(CACHE_DIR)
  62. and bool(os.environ.get("NUMPY_TYPING_TEST_CLEAR_CACHE", True))
  63. ):
  64. shutil.rmtree(CACHE_DIR)
  65. for directory in (PASS_DIR, REVEAL_DIR, FAIL_DIR, MISC_DIR):
  66. # Run mypy
  67. stdout, stderr, exit_code = api.run([
  68. "--config-file",
  69. MYPY_INI,
  70. "--cache-dir",
  71. CACHE_DIR,
  72. directory,
  73. ])
  74. if stderr:
  75. pytest.fail(f"Unexpected mypy standard error\n\n{stderr}")
  76. elif exit_code not in {0, 1}:
  77. pytest.fail(f"Unexpected mypy exit code: {exit_code}\n\n{stdout}")
  78. stdout = stdout.replace('*', '')
  79. # Parse the output
  80. iterator = itertools.groupby(stdout.split("\n"), key=_key_func)
  81. OUTPUT_MYPY.update((k, list(v)) for k, v in iterator if k)
  82. def get_test_cases(directory: str) -> Iterator[ParameterSet]:
  83. for root, _, files in os.walk(directory):
  84. for fname in files:
  85. short_fname, ext = os.path.splitext(fname)
  86. if ext in (".pyi", ".py"):
  87. fullpath = os.path.join(root, fname)
  88. yield pytest.param(fullpath, id=short_fname)
  89. @pytest.mark.slow
  90. @pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
  91. @pytest.mark.parametrize("path", get_test_cases(PASS_DIR))
  92. def test_success(path) -> None:
  93. # Alias `OUTPUT_MYPY` so that it appears in the local namespace
  94. output_mypy = OUTPUT_MYPY
  95. if path in output_mypy:
  96. msg = "Unexpected mypy output\n\n"
  97. msg += "\n".join(_strip_filename(v) for v in output_mypy[path])
  98. raise AssertionError(msg)
  99. @pytest.mark.slow
  100. @pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
  101. @pytest.mark.parametrize("path", get_test_cases(FAIL_DIR))
  102. def test_fail(path: str) -> None:
  103. __tracebackhide__ = True
  104. with open(path) as fin:
  105. lines = fin.readlines()
  106. errors = defaultdict(lambda: "")
  107. output_mypy = OUTPUT_MYPY
  108. assert path in output_mypy
  109. for error_line in output_mypy[path]:
  110. error_line = _strip_filename(error_line).split("\n", 1)[0]
  111. match = re.match(
  112. r"(?P<lineno>\d+): (error|note): .+$",
  113. error_line,
  114. )
  115. if match is None:
  116. raise ValueError(f"Unexpected error line format: {error_line}")
  117. lineno = int(match.group('lineno'))
  118. errors[lineno] += f'{error_line}\n'
  119. for i, line in enumerate(lines):
  120. lineno = i + 1
  121. if (
  122. line.startswith('#')
  123. or (" E:" not in line and lineno not in errors)
  124. ):
  125. continue
  126. target_line = lines[lineno - 1]
  127. if "# E:" in target_line:
  128. expression, _, marker = target_line.partition(" # E: ")
  129. expected_error = errors[lineno].strip()
  130. marker = marker.strip()
  131. _test_fail(path, expression, marker, expected_error, lineno)
  132. else:
  133. pytest.fail(
  134. f"Unexpected mypy output at line {lineno}\n\n{errors[lineno]}"
  135. )
  136. _FAIL_MSG1 = """Extra error at line {}
  137. Expression: {}
  138. Extra error: {!r}
  139. """
  140. _FAIL_MSG2 = """Error mismatch at line {}
  141. Expression: {}
  142. Expected error: {!r}
  143. Observed error: {!r}
  144. """
  145. def _test_fail(
  146. path: str,
  147. expression: str,
  148. error: str,
  149. expected_error: None | str,
  150. lineno: int,
  151. ) -> None:
  152. if expected_error is None:
  153. raise AssertionError(_FAIL_MSG1.format(lineno, expression, error))
  154. elif error not in expected_error:
  155. raise AssertionError(_FAIL_MSG2.format(
  156. lineno, expression, expected_error, error
  157. ))
  158. def _construct_ctypes_dict() -> dict[str, str]:
  159. dct = {
  160. "ubyte": "c_ubyte",
  161. "ushort": "c_ushort",
  162. "uintc": "c_uint",
  163. "uint": "c_ulong",
  164. "ulonglong": "c_ulonglong",
  165. "byte": "c_byte",
  166. "short": "c_short",
  167. "intc": "c_int",
  168. "int_": "c_long",
  169. "longlong": "c_longlong",
  170. "single": "c_float",
  171. "double": "c_double",
  172. "longdouble": "c_longdouble",
  173. }
  174. # Match `ctypes` names to the first ctypes type with a given kind and
  175. # precision, e.g. {"c_double": "c_double", "c_longdouble": "c_double"}
  176. # if both types represent 64-bit floats.
  177. # In this context "first" is defined by the order of `dct`
  178. ret = {}
  179. visited: dict[tuple[str, int], str] = {}
  180. for np_name, ct_name in dct.items():
  181. np_scalar = getattr(np, np_name)()
  182. # Find the first `ctypes` type for a given `kind`/`itemsize` combo
  183. key = (np_scalar.dtype.kind, np_scalar.dtype.itemsize)
  184. ret[ct_name] = visited.setdefault(key, f"ctypes.{ct_name}")
  185. return ret
  186. def _construct_format_dict() -> dict[str, str]:
  187. dct = {k.split(".")[-1]: v.replace("numpy", "numpy._typing") for
  188. k, v in _PRECISION_DICT.items()}
  189. return {
  190. "uint8": "numpy.unsignedinteger[numpy._typing._8Bit]",
  191. "uint16": "numpy.unsignedinteger[numpy._typing._16Bit]",
  192. "uint32": "numpy.unsignedinteger[numpy._typing._32Bit]",
  193. "uint64": "numpy.unsignedinteger[numpy._typing._64Bit]",
  194. "uint128": "numpy.unsignedinteger[numpy._typing._128Bit]",
  195. "uint256": "numpy.unsignedinteger[numpy._typing._256Bit]",
  196. "int8": "numpy.signedinteger[numpy._typing._8Bit]",
  197. "int16": "numpy.signedinteger[numpy._typing._16Bit]",
  198. "int32": "numpy.signedinteger[numpy._typing._32Bit]",
  199. "int64": "numpy.signedinteger[numpy._typing._64Bit]",
  200. "int128": "numpy.signedinteger[numpy._typing._128Bit]",
  201. "int256": "numpy.signedinteger[numpy._typing._256Bit]",
  202. "float16": "numpy.floating[numpy._typing._16Bit]",
  203. "float32": "numpy.floating[numpy._typing._32Bit]",
  204. "float64": "numpy.floating[numpy._typing._64Bit]",
  205. "float80": "numpy.floating[numpy._typing._80Bit]",
  206. "float96": "numpy.floating[numpy._typing._96Bit]",
  207. "float128": "numpy.floating[numpy._typing._128Bit]",
  208. "float256": "numpy.floating[numpy._typing._256Bit]",
  209. "complex64": ("numpy.complexfloating"
  210. "[numpy._typing._32Bit, numpy._typing._32Bit]"),
  211. "complex128": ("numpy.complexfloating"
  212. "[numpy._typing._64Bit, numpy._typing._64Bit]"),
  213. "complex160": ("numpy.complexfloating"
  214. "[numpy._typing._80Bit, numpy._typing._80Bit]"),
  215. "complex192": ("numpy.complexfloating"
  216. "[numpy._typing._96Bit, numpy._typing._96Bit]"),
  217. "complex256": ("numpy.complexfloating"
  218. "[numpy._typing._128Bit, numpy._typing._128Bit]"),
  219. "complex512": ("numpy.complexfloating"
  220. "[numpy._typing._256Bit, numpy._typing._256Bit]"),
  221. "ubyte": f"numpy.unsignedinteger[{dct['_NBitByte']}]",
  222. "ushort": f"numpy.unsignedinteger[{dct['_NBitShort']}]",
  223. "uintc": f"numpy.unsignedinteger[{dct['_NBitIntC']}]",
  224. "uintp": f"numpy.unsignedinteger[{dct['_NBitIntP']}]",
  225. "uint": f"numpy.unsignedinteger[{dct['_NBitInt']}]",
  226. "ulonglong": f"numpy.unsignedinteger[{dct['_NBitLongLong']}]",
  227. "byte": f"numpy.signedinteger[{dct['_NBitByte']}]",
  228. "short": f"numpy.signedinteger[{dct['_NBitShort']}]",
  229. "intc": f"numpy.signedinteger[{dct['_NBitIntC']}]",
  230. "intp": f"numpy.signedinteger[{dct['_NBitIntP']}]",
  231. "int_": f"numpy.signedinteger[{dct['_NBitInt']}]",
  232. "longlong": f"numpy.signedinteger[{dct['_NBitLongLong']}]",
  233. "half": f"numpy.floating[{dct['_NBitHalf']}]",
  234. "single": f"numpy.floating[{dct['_NBitSingle']}]",
  235. "double": f"numpy.floating[{dct['_NBitDouble']}]",
  236. "longdouble": f"numpy.floating[{dct['_NBitLongDouble']}]",
  237. "csingle": ("numpy.complexfloating"
  238. f"[{dct['_NBitSingle']}, {dct['_NBitSingle']}]"),
  239. "cdouble": ("numpy.complexfloating"
  240. f"[{dct['_NBitDouble']}, {dct['_NBitDouble']}]"),
  241. "clongdouble": (
  242. "numpy.complexfloating"
  243. f"[{dct['_NBitLongDouble']}, {dct['_NBitLongDouble']}]"
  244. ),
  245. # numpy.typing
  246. "_NBitInt": dct['_NBitInt'],
  247. # numpy.ctypeslib
  248. "c_intp": f"ctypes.{_C_INTP}"
  249. }
  250. #: A dictionary with all supported format keys (as keys)
  251. #: and matching values
  252. FORMAT_DICT: dict[str, str] = _construct_format_dict()
  253. FORMAT_DICT.update(_construct_ctypes_dict())
  254. def _parse_reveals(file: IO[str]) -> tuple[npt.NDArray[np.str_], list[str]]:
  255. """Extract and parse all ``" # E: "`` comments from the passed
  256. file-like object.
  257. All format keys will be substituted for their respective value
  258. from `FORMAT_DICT`, *e.g.* ``"{float64}"`` becomes
  259. ``"numpy.floating[numpy._typing._64Bit]"``.
  260. """
  261. string = file.read().replace("*", "")
  262. # Grab all `# E:`-based comments and matching expressions
  263. expression_array, _, comments_array = np.char.partition(
  264. string.split("\n"), sep=" # E: "
  265. ).T
  266. comments = "/n".join(comments_array)
  267. # Only search for the `{*}` pattern within comments, otherwise
  268. # there is the risk of accidentally grabbing dictionaries and sets
  269. key_set = set(re.findall(r"\{(.*?)\}", comments))
  270. kwargs = {
  271. k: FORMAT_DICT.get(k, f"<UNRECOGNIZED FORMAT KEY {k!r}>") for
  272. k in key_set
  273. }
  274. fmt_str = comments.format(**kwargs)
  275. return expression_array, fmt_str.split("/n")
  276. @pytest.mark.slow
  277. @pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
  278. @pytest.mark.parametrize("path", get_test_cases(REVEAL_DIR))
  279. def test_reveal(path: str) -> None:
  280. """Validate that mypy correctly infers the return-types of
  281. the expressions in `path`.
  282. """
  283. __tracebackhide__ = True
  284. with open(path) as fin:
  285. expression_array, reveal_list = _parse_reveals(fin)
  286. output_mypy = OUTPUT_MYPY
  287. assert path in output_mypy
  288. for error_line in output_mypy[path]:
  289. error_line = _strip_filename(error_line)
  290. match = re.match(
  291. r"(?P<lineno>\d+): note: .+$",
  292. error_line,
  293. )
  294. if match is None:
  295. raise ValueError(f"Unexpected reveal line format: {error_line}")
  296. lineno = int(match.group('lineno')) - 1
  297. assert "Revealed type is" in error_line
  298. marker = reveal_list[lineno]
  299. expression = expression_array[lineno]
  300. _test_reveal(path, expression, marker, error_line, 1 + lineno)
  301. _REVEAL_MSG = """Reveal mismatch at line {}
  302. Expression: {}
  303. Expected reveal: {!r}
  304. Observed reveal: {!r}
  305. """
  306. _STRIP_PATTERN = re.compile(r"(\w+\.)+(\w+)")
  307. def _test_reveal(
  308. path: str,
  309. expression: str,
  310. reveal: str,
  311. expected_reveal: str,
  312. lineno: int,
  313. ) -> None:
  314. """Error-reporting helper function for `test_reveal`."""
  315. stripped_reveal = _STRIP_PATTERN.sub(strip_func, reveal)
  316. stripped_expected_reveal = _STRIP_PATTERN.sub(strip_func, expected_reveal)
  317. if stripped_reveal not in stripped_expected_reveal:
  318. raise AssertionError(
  319. _REVEAL_MSG.format(lineno,
  320. expression,
  321. stripped_expected_reveal,
  322. stripped_reveal)
  323. )
  324. @pytest.mark.slow
  325. @pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
  326. @pytest.mark.parametrize("path", get_test_cases(PASS_DIR))
  327. def test_code_runs(path: str) -> None:
  328. """Validate that the code in `path` properly during runtime."""
  329. path_without_extension, _ = os.path.splitext(path)
  330. dirname, filename = path.split(os.sep)[-2:]
  331. spec = importlib.util.spec_from_file_location(
  332. f"{dirname}.{filename}", path
  333. )
  334. assert spec is not None
  335. assert spec.loader is not None
  336. test_module = importlib.util.module_from_spec(spec)
  337. spec.loader.exec_module(test_module)
  338. LINENO_MAPPING = {
  339. 3: "uint128",
  340. 4: "uint256",
  341. 6: "int128",
  342. 7: "int256",
  343. 9: "float80",
  344. 10: "float96",
  345. 11: "float128",
  346. 12: "float256",
  347. 14: "complex160",
  348. 15: "complex192",
  349. 16: "complex256",
  350. 17: "complex512",
  351. }
  352. @pytest.mark.slow
  353. @pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
  354. def test_extended_precision() -> None:
  355. path = os.path.join(MISC_DIR, "extended_precision.pyi")
  356. output_mypy = OUTPUT_MYPY
  357. assert path in output_mypy
  358. with open(path, "r") as f:
  359. expression_list = f.readlines()
  360. for _msg in output_mypy[path]:
  361. *_, _lineno, msg_typ, msg = _msg.split(":")
  362. msg = _strip_filename(msg)
  363. lineno = int(_lineno)
  364. expression = expression_list[lineno - 1].rstrip("\n")
  365. msg_typ = msg_typ.strip()
  366. assert msg_typ in {"error", "note"}
  367. if LINENO_MAPPING[lineno] in _EXTENDED_PRECISION_LIST:
  368. if msg_typ == "error":
  369. raise ValueError(f"Unexpected reveal line format: {lineno}")
  370. else:
  371. marker = FORMAT_DICT[LINENO_MAPPING[lineno]]
  372. _test_reveal(path, expression, marker, msg, lineno)
  373. else:
  374. if msg_typ == "error":
  375. marker = "Module has no attribute"
  376. _test_fail(path, expression, marker, msg, lineno)