_warnings.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. from __future__ import annotations
  2. from contextlib import (
  3. contextmanager,
  4. nullcontext,
  5. )
  6. import re
  7. import sys
  8. from typing import (
  9. Generator,
  10. Literal,
  11. Sequence,
  12. Type,
  13. cast,
  14. )
  15. import warnings
  16. @contextmanager
  17. def assert_produces_warning(
  18. expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None = Warning,
  19. filter_level: Literal[
  20. "error", "ignore", "always", "default", "module", "once"
  21. ] = "always",
  22. check_stacklevel: bool = True,
  23. raise_on_extra_warnings: bool = True,
  24. match: str | None = None,
  25. ) -> Generator[list[warnings.WarningMessage], None, None]:
  26. """
  27. Context manager for running code expected to either raise a specific warning,
  28. multiple specific warnings, or not raise any warnings. Verifies that the code
  29. raises the expected warning(s), and that it does not raise any other unexpected
  30. warnings. It is basically a wrapper around ``warnings.catch_warnings``.
  31. Parameters
  32. ----------
  33. expected_warning : {Warning, False, tuple[Warning, ...], None}, default Warning
  34. The type of Exception raised. ``exception.Warning`` is the base
  35. class for all warnings. To raise multiple types of exceptions,
  36. pass them as a tuple. To check that no warning is returned,
  37. specify ``False`` or ``None``.
  38. filter_level : str or None, default "always"
  39. Specifies whether warnings are ignored, displayed, or turned
  40. into errors.
  41. Valid values are:
  42. * "error" - turns matching warnings into exceptions
  43. * "ignore" - discard the warning
  44. * "always" - always emit a warning
  45. * "default" - print the warning the first time it is generated
  46. from each location
  47. * "module" - print the warning the first time it is generated
  48. from each module
  49. * "once" - print the warning the first time it is generated
  50. check_stacklevel : bool, default True
  51. If True, displays the line that called the function containing
  52. the warning to show were the function is called. Otherwise, the
  53. line that implements the function is displayed.
  54. raise_on_extra_warnings : bool, default True
  55. Whether extra warnings not of the type `expected_warning` should
  56. cause the test to fail.
  57. match : str, optional
  58. Match warning message.
  59. Examples
  60. --------
  61. >>> import warnings
  62. >>> with assert_produces_warning():
  63. ... warnings.warn(UserWarning())
  64. ...
  65. >>> with assert_produces_warning(False):
  66. ... warnings.warn(RuntimeWarning())
  67. ...
  68. Traceback (most recent call last):
  69. ...
  70. AssertionError: Caused unexpected warning(s): ['RuntimeWarning'].
  71. >>> with assert_produces_warning(UserWarning):
  72. ... warnings.warn(RuntimeWarning())
  73. Traceback (most recent call last):
  74. ...
  75. AssertionError: Did not see expected warning of class 'UserWarning'.
  76. ..warn:: This is *not* thread-safe.
  77. """
  78. __tracebackhide__ = True
  79. with warnings.catch_warnings(record=True) as w:
  80. warnings.simplefilter(filter_level)
  81. try:
  82. yield w
  83. finally:
  84. if expected_warning:
  85. expected_warning = cast(Type[Warning], expected_warning)
  86. _assert_caught_expected_warning(
  87. caught_warnings=w,
  88. expected_warning=expected_warning,
  89. match=match,
  90. check_stacklevel=check_stacklevel,
  91. )
  92. if raise_on_extra_warnings:
  93. _assert_caught_no_extra_warnings(
  94. caught_warnings=w,
  95. expected_warning=expected_warning,
  96. )
  97. def maybe_produces_warning(warning: type[Warning], condition: bool, **kwargs):
  98. """
  99. Return a context manager that possibly checks a warning based on the condition
  100. """
  101. if condition:
  102. return assert_produces_warning(warning, **kwargs)
  103. else:
  104. return nullcontext()
  105. def _assert_caught_expected_warning(
  106. *,
  107. caught_warnings: Sequence[warnings.WarningMessage],
  108. expected_warning: type[Warning],
  109. match: str | None,
  110. check_stacklevel: bool,
  111. ) -> None:
  112. """Assert that there was the expected warning among the caught warnings."""
  113. saw_warning = False
  114. matched_message = False
  115. unmatched_messages = []
  116. for actual_warning in caught_warnings:
  117. if issubclass(actual_warning.category, expected_warning):
  118. saw_warning = True
  119. if check_stacklevel:
  120. _assert_raised_with_correct_stacklevel(actual_warning)
  121. if match is not None:
  122. if re.search(match, str(actual_warning.message)):
  123. matched_message = True
  124. else:
  125. unmatched_messages.append(actual_warning.message)
  126. if not saw_warning:
  127. raise AssertionError(
  128. f"Did not see expected warning of class "
  129. f"{repr(expected_warning.__name__)}"
  130. )
  131. if match and not matched_message:
  132. raise AssertionError(
  133. f"Did not see warning {repr(expected_warning.__name__)} "
  134. f"matching '{match}'. The emitted warning messages are "
  135. f"{unmatched_messages}"
  136. )
  137. def _assert_caught_no_extra_warnings(
  138. *,
  139. caught_warnings: Sequence[warnings.WarningMessage],
  140. expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None,
  141. ) -> None:
  142. """Assert that no extra warnings apart from the expected ones are caught."""
  143. extra_warnings = []
  144. for actual_warning in caught_warnings:
  145. if _is_unexpected_warning(actual_warning, expected_warning):
  146. # GH#38630 pytest.filterwarnings does not suppress these.
  147. if actual_warning.category == ResourceWarning:
  148. # GH 44732: Don't make the CI flaky by filtering SSL-related
  149. # ResourceWarning from dependencies
  150. if "unclosed <ssl.SSLSocket" in str(actual_warning.message):
  151. continue
  152. # GH 44844: Matplotlib leaves font files open during the entire process
  153. # upon import. Don't make CI flaky if ResourceWarning raised
  154. # due to these open files.
  155. if any("matplotlib" in mod for mod in sys.modules):
  156. continue
  157. extra_warnings.append(
  158. (
  159. actual_warning.category.__name__,
  160. actual_warning.message,
  161. actual_warning.filename,
  162. actual_warning.lineno,
  163. )
  164. )
  165. if extra_warnings:
  166. raise AssertionError(f"Caused unexpected warning(s): {repr(extra_warnings)}")
  167. def _is_unexpected_warning(
  168. actual_warning: warnings.WarningMessage,
  169. expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None,
  170. ) -> bool:
  171. """Check if the actual warning issued is unexpected."""
  172. if actual_warning and not expected_warning:
  173. return True
  174. expected_warning = cast(Type[Warning], expected_warning)
  175. return bool(not issubclass(actual_warning.category, expected_warning))
  176. def _assert_raised_with_correct_stacklevel(
  177. actual_warning: warnings.WarningMessage,
  178. ) -> None:
  179. from inspect import (
  180. getframeinfo,
  181. stack,
  182. )
  183. caller = getframeinfo(stack()[4][0])
  184. msg = (
  185. "Warning not set with correct stacklevel. "
  186. f"File where warning is raised: {actual_warning.filename} != "
  187. f"{caller.filename}. Warning message: {actual_warning.message}"
  188. )
  189. assert actual_warning.filename == caller.filename, msg