123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- # Extra utilities for working with context managers that should have been
- # in the standard library but are not
- import functools
- import inspect
- import warnings
- import sys
- from typing import Any, Callable, TypeVar, cast
# Generic aliases for annotating the decorator form of _DecoratorContextManager
# (e.g. 'no_grad' and 'enable_grad'): a decorator typed as ``(F) -> F`` tells
# the type checker the decorated callable keeps its original signature.
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)
- def _wrap_generator(ctx_factory, func):
- """
- Wrap each generator invocation with the context manager factory.
- The input should be a function that returns a context manager,
- not a context manager itself, to handle one-shot context managers.
- """
- @functools.wraps(func)
- def generator_context(*args, **kwargs):
- gen = func(*args, **kwargs)
- # Generators are suspended and unsuspended at `yield`, hence we
- # make sure the grad mode is properly set every time the execution
- # flow returns into the wrapped generator and restored when it
- # returns through our `yield` to our caller (see PR #49017).
- try:
- # Issuing `None` to a generator fires it up
- with ctx_factory():
- response = gen.send(None)
- while True:
- try:
- # Forward the response to our caller and get its next request
- request = yield response
- except GeneratorExit:
- # Inform the still active generator about its imminent closure
- with ctx_factory():
- gen.close()
- raise
- except BaseException:
- # Propagate the exception thrown at us by the caller
- with ctx_factory():
- response = gen.throw(*sys.exc_info())
- else:
- # Pass the last request to the generator and get its response
- with ctx_factory():
- response = gen.send(request)
- # We let the exceptions raised above by the generator's `.throw` or
- # `.send` methods bubble up to our caller, except for StopIteration
- except StopIteration as e:
- # The generator informed us that it is done: take whatever its
- # returned value (if any) was and indicate that we're done too
- # by returning it (see docs for python's return-statement).
- return e.value
- return generator_context
def context_decorator(ctx, func):
    """
    Wrap ``func`` so that every call to it runs inside a context manager.

    This is similar to ``contextlib.ContextDecorator``, with three differences:

    1. It wraps rather than relying on inheritance. This means context
       managers implemented in C, which cannot easily inherit from Python
       classes, are supported too.
    2. Generator functions are wrapped intuitively: the context is re-entered
       every time the generator resumes
       (cf. https://bugs.python.org/issue37743).
    3. Classes are rejected, because it would be ambiguous whether only the
       constructor or every method was meant to run inside the context.

    Args:
        ctx: either a multi-shot context manager that can be entered
            repeatedly, or a zero-argument callable producing a context
            manager. An object that is both callable and defines
            ``__enter__`` is rejected as ambiguous.
        func: the function or generator function to decorate.

    Returns:
        The decorated callable, carrying ``func``'s metadata.

    Raises:
        AssertionError: if ``ctx`` is ambiguous (only checked when assertions
            are enabled).
        RuntimeError: if ``func`` is a class.
    """
    assert not (callable(ctx) and hasattr(ctx, '__enter__')), (
        f"Passed in {ctx} is both callable and also a valid context manager "
        "(has __enter__), making it ambiguous which interface to use. If you "
        "intended to pass a context manager factory, rewrite your call as "
        "context_decorator(lambda: ctx()); if you intended to pass a context "
        "manager directly, rewrite your call as context_decorator(lambda: ctx)"
    )

    if inspect.isclass(func):
        raise RuntimeError(
            "Cannot decorate classes; it is ambiguous whether or not only the "
            "constructor or all methods should have the context manager applied; "
            "additionally, decorating a class at definition-site will prevent "
            "use of the identifier as a conventional type. "
            "To specify which methods to decorate, decorate each of them "
            "individually."
        )

    # Normalise ``ctx`` into a zero-argument factory so both the plain and the
    # generator wrappers obtain their context manager the same way.
    make_ctx = ctx if callable(ctx) else (lambda: ctx)

    if inspect.isgeneratorfunction(func):
        return _wrap_generator(make_ctx, func)

    @functools.wraps(func)
    def decorate_context(*args, **kwargs):
        with make_ctx():
            return func(*args, **kwargs)

    return decorate_context
class _DecoratorContextManager:
    """Allow a context manager to be used as a decorator.

    Subclasses provide ``__enter__``/``__exit__``. Applying an instance to a
    function returns a wrapper that, on every call, enters a context manager
    obtained from :meth:`clone` and runs the function inside it.
    """

    def __call__(self, orig_func: F) -> F:
        """Decorate ``orig_func`` so each call runs inside ``self.clone()``.

        Decorating a class is deprecated. A warning is issued, and the class
        is replaced by a plain function that forwards to its constructor, so
        only construction runs inside the context.
        """
        if inspect.isclass(orig_func):
            # stacklevel=2 attributes the warning to the code that applied the
            # decorator rather than to this module, so users can locate it.
            warnings.warn("Decorating classes is deprecated and will be disabled in "
                          "future versions. You should only decorate functions or methods. "
                          "To preserve the current behavior of class decoration, you can "
                          "directly decorate the `__init__` method and nothing else.",
                          stacklevel=2)
            # context_decorator refuses classes outright, so hand it a plain
            # function that forwards to the constructor instead.
            func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs))
        else:
            func = orig_func
        # Pass the bound ``clone`` as a factory rather than ``self``.
        # ``self`` is both callable and has ``__enter__``, which
        # context_decorator would reject as ambiguous. The factory is called
        # once per invocation of the decorated function.
        return cast(F, context_decorator(self.clone, func))

    def __enter__(self) -> None:
        # Subclasses must implement entering the context.
        raise NotImplementedError

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Subclasses must implement leaving the context.
        raise NotImplementedError

    def clone(self):
        # Return a new instance of the same class. Override this method if
        # your child class's __init__ takes parameters, so they are carried over.
        return self.__class__()
|