import contextlib
import warnings

from torch._C import (
    _len_torch_dispatch_stack,
    _get_dispatch_stack_at,
    _pop_torch_dispatch_stack,
    _push_on_torch_dispatch_stack,
)
# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for _DisableTorchDispatch that
#   is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
class TorchDispatchMode:
    """
    A ``TorchDispatchMode`` allows you to override the meaning of all
    ``__torch_dispatch__`` overrideable functions within a dynamic scope,
    without having to actually create a tensor subclass or manually
    monkey-patch functions in the PyTorch API.  Some common situations
    where you should use a mode:

        * You want to override the meaning of factory functions, or other
          functions that do not otherwise take a tensor as an argument
          (these cannot be overridden with tensor subclasses).

        * You want to override the behavior of all functions without needing
          to wrap your inputs in tensor subclasses; e.g., if you are just
          interested in logging intermediate computations (a minimal example
          of this follows the class definition below).

        * You want to control the order of execution of various tensor
          subclasses explicitly, rather than implicitly via the return of
          ``NotImplemented``.

    Independent subclasses of :class:`TorchDispatchMode` are compositional:
    modes can be pushed onto a stack using ``with MyMode():``.
    When you call functions in the PyTorch API inside your
    ``__torch_dispatch__`` implementation, by default they will forward on to
    the next mode on the mode stack.  If you want to recursively call back into
    your current ``__torch_dispatch__`` implementation, either explicitly
    invoke ``self.__torch_dispatch__(...)``, or use the context manager
    ``__torch_dispatch__(self)`` to make the PyTorch API self-referential
    (beware of infinite loops, in this case!).
    """
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        raise NotImplementedError()

    def __enter__(self):
        _push_mode(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _pop_mode()

    @classmethod
    def push(cls, *args, **kwargs):
        warnings.warn(
            "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`"
        )
        instance = cls(*args, **kwargs)
        return instance
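

# Example (illustrative sketch, not part of the original module): a minimal
# mode built on ``TorchDispatchMode`` that logs every dispatched op before
# forwarding it, matching the "logging intermediate computations" use case
# described in the class docstring above.  The class name
# ``_ExampleLoggingMode`` is hypothetical.
class _ExampleLoggingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Log the op being dispatched, then forward the call; by default the
        # forwarded call continues with the next mode on the stack.
        print(f"TorchDispatchMode saw: {func}")
        return func(*args, **kwargs)

# Hypothetical usage: both the factory call and the add are intercepted.
#
#     with _ExampleLoggingMode():
#         torch.ones(2, 2) + 1
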
def _get_current_dispatch_mode():
    # Return the innermost active mode (top of the stack), or None if no mode
    # is active.
    stack_len = _len_torch_dispatch_stack()
    return _get_dispatch_stack_at(stack_len - 1) if stack_len > 0 else None


def _get_current_dispatch_mode_stack():
    # Return all active modes, ordered from outermost to innermost.
    stack_len = _len_torch_dispatch_stack()
    return [_get_dispatch_stack_at(i) for i in range(stack_len)]
def _push_mode(mode):
    _push_on_torch_dispatch_stack(mode)


def _pop_mode():
    return _pop_torch_dispatch_stack()
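

# Hypothetical usage sketch (not part of the original module): nested ``with``
# blocks push onto the same stack, so the innermost mode is the "current" one.
# ``ModeA`` and ``ModeB`` stand in for any ``TorchDispatchMode`` subclasses.
#
#     with ModeA() as a:
#         with ModeB() as b:
#             assert _get_current_dispatch_mode() is b
#             assert _get_current_dispatch_mode_stack() == [a, b]
#         assert _get_current_dispatch_mode() is a
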
@contextlib.contextmanager
def _pop_mode_temporarily():
    old = _pop_mode()
    try:
        yield old
    finally:
        _push_mode(old)
@contextlib.contextmanager
def _disable_current_modes():
    mode_len = _len_torch_dispatch_stack()
    old_modes = [_pop_mode() for _ in range(mode_len)]
    try:
        yield old_modes
    finally:
        # Re-push in reverse pop order so the stack ends up exactly as it was.
        for mode in reversed(old_modes):
            _push_mode(mode)
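

# Illustrative sketch (hypothetical helper, not part of the original module):
# run a callable with every active dispatch mode removed, e.g. to observe
# unmodified behavior from inside a mode's ``__torch_dispatch__``.
# ``_disable_current_modes`` restores the full stack on exit, even on error.
def _example_run_without_modes(fn, *args, **kwargs):
    with _disable_current_modes():
        return fn(*args, **kwargs)
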
class BaseTorchDispatchMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)
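

# Illustrative sketch (hypothetical, not part of the original module):
# ``BaseTorchDispatchMode`` forwards every op unchanged, so a subclass can add
# behavior for all ops and fall back via ``super()`` for the actual dispatch.
class _ExampleCountingMode(BaseTorchDispatchMode):
    def __init__(self):
        self.counts = {}

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Count how many times each op is hit, then forward it unchanged.
        self.counts[func] = self.counts.get(func, 0) + 1
        return super().__torch_dispatch__(func, types, args, kwargs)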