_contextlib.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # Extra utilities for working with context managers that should have been
  2. # in the standard library but are not
  3. import functools
  4. import inspect
  5. import warnings
  6. import sys
  7. from typing import Any, Callable, TypeVar, cast
  8. # Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
  9. # 'no_grad' and 'enable_grad').
  10. # See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
  11. FuncType = Callable[..., Any]
  12. F = TypeVar('F', bound=FuncType)
  13. def _wrap_generator(ctx_factory, func):
  14. """
  15. Wrap each generator invocation with the context manager factory.
  16. The input should be a function that returns a context manager,
  17. not a context manager itself, to handle one-shot context managers.
  18. """
  19. @functools.wraps(func)
  20. def generator_context(*args, **kwargs):
  21. gen = func(*args, **kwargs)
  22. # Generators are suspended and unsuspended at `yield`, hence we
  23. # make sure the grad mode is properly set every time the execution
  24. # flow returns into the wrapped generator and restored when it
  25. # returns through our `yield` to our caller (see PR #49017).
  26. try:
  27. # Issuing `None` to a generator fires it up
  28. with ctx_factory():
  29. response = gen.send(None)
  30. while True:
  31. try:
  32. # Forward the response to our caller and get its next request
  33. request = yield response
  34. except GeneratorExit:
  35. # Inform the still active generator about its imminent closure
  36. with ctx_factory():
  37. gen.close()
  38. raise
  39. except BaseException:
  40. # Propagate the exception thrown at us by the caller
  41. with ctx_factory():
  42. response = gen.throw(*sys.exc_info())
  43. else:
  44. # Pass the last request to the generator and get its response
  45. with ctx_factory():
  46. response = gen.send(request)
  47. # We let the exceptions raised above by the generator's `.throw` or
  48. # `.send` methods bubble up to our caller, except for StopIteration
  49. except StopIteration as e:
  50. # The generator informed us that it is done: take whatever its
  51. # returned value (if any) was and indicate that we're done too
  52. # by returning it (see docs for python's return-statement).
  53. return e.value
  54. return generator_context
  55. def context_decorator(ctx, func):
  56. """
  57. Like contextlib.ContextDecorator, but:
  58. 1. Is done by wrapping, rather than inheritance, so it works with context
  59. managers that are implemented from C and thus cannot easily inherit from
  60. Python classes
  61. 2. Wraps generators in the intuitive way (c.f. https://bugs.python.org/issue37743)
  62. 3. Errors out if you try to wrap a class, because it is ambiguous whether
  63. or not you intended to wrap only the constructor
  64. The input argument can either be a context manager (in which case it must
  65. be a multi-shot context manager that can be directly invoked multiple times)
  66. or a callable that produces a context manager.
  67. """
  68. assert not (callable(ctx) and hasattr(ctx, '__enter__')), (
  69. f"Passed in {ctx} is both callable and also a valid context manager "
  70. "(has __enter__), making it ambiguous which interface to use. If you "
  71. "intended to pass a context manager factory, rewrite your call as "
  72. "context_decorator(lambda: ctx()); if you intended to pass a context "
  73. "manager directly, rewrite your call as context_decorator(lambda: ctx)"
  74. )
  75. if not callable(ctx):
  76. def ctx_factory():
  77. return ctx
  78. else:
  79. ctx_factory = ctx
  80. if inspect.isclass(func):
  81. raise RuntimeError(
  82. "Cannot decorate classes; it is ambiguous whether or not only the "
  83. "constructor or all methods should have the context manager applied; "
  84. "additionally, decorating a class at definition-site will prevent "
  85. "use of the identifier as a conventional type. "
  86. "To specify which methods to decorate, decorate each of them "
  87. "individually."
  88. )
  89. if inspect.isgeneratorfunction(func):
  90. return _wrap_generator(ctx_factory, func)
  91. @functools.wraps(func)
  92. def decorate_context(*args, **kwargs):
  93. with ctx_factory():
  94. return func(*args, **kwargs)
  95. return decorate_context
  96. class _DecoratorContextManager:
  97. """Allow a context manager to be used as a decorator"""
  98. def __call__(self, orig_func: F) -> F:
  99. if inspect.isclass(orig_func):
  100. warnings.warn("Decorating classes is deprecated and will be disabled in "
  101. "future versions. You should only decorate functions or methods. "
  102. "To preserve the current behavior of class decoration, you can "
  103. "directly decorate the `__init__` method and nothing else.")
  104. func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs))
  105. else:
  106. func = orig_func
  107. return cast(F, context_decorator(self.clone, func))
  108. def __enter__(self) -> None:
  109. raise NotImplementedError
  110. def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
  111. raise NotImplementedError
  112. def clone(self):
  113. # override this method if your children class takes __init__ parameters
  114. return self.__class__()