import contextlib import warnings from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at,\ _pop_torch_dispatch_stack, _push_on_torch_dispatch_stack # TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it: # - We need a better user-facing api for _DisableTorchDispatch that # is able to selectively disable __torch_dispatch__ of a particular class. # - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor) # - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694) class TorchDispatchMode: """ A ``TorchDispatchMode`` allows you to override the meaning of all ``__torch_dispatch__`` overrideable functions within a dynamic scope, without having to actually create a tensor subclass or manually monkey-patch functions in the PyTorch API. Some common situations where you should use a mode: * You want to override the meaning of factory functions, or other functions that do not otherwise take a tensor as an argument (these cannot be overridden with tensor subclasses). * You want to override the behavior of all functions without needing to wrap your inputs in tensor subclasses; e.g., if you are just interested in logging intermediate computations. * You want to control the order of execution of various tensor subclasses explicitly, rather than implicitly via the return of ``NotImplemented``. Independent subclasses of :class:`TorchDispatchMode` are compositional: modes can be pushed onto a stack using ``with MyMode():``. When you call functions in the PyTorch API inside your ``__torch_dispatch__`` implementation, by default, they will forward on to the next mode on the mode stack. If you want recursively call back into your current ``__torch_dispatch__`` implementation, either explicitly invoke ``self.__torch_dispatch__(...)``, or use the context manager ``__torch_dispatch__(self)`` to make PyTorch API self-referential (beware of infinite loops, in this case!) """ def __torch_dispatch__(self, func, types, args=(), kwargs=None): raise NotImplementedError() def __enter__(self): _push_mode(self) return self def __exit__(self, exc_type, exc_val, exc_tb): _pop_mode() @classmethod def push(cls, *args, **kwargs): warnings.warn("`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`") instance = cls(*args, **kwargs) return instance def _get_current_dispatch_mode(): stack_len = _len_torch_dispatch_stack() return _get_dispatch_stack_at(stack_len - 1) if stack_len > 0 else None def _get_current_dispatch_mode_stack(): stack_len = _len_torch_dispatch_stack() return [_get_dispatch_stack_at(i) for i in range(stack_len)] def _push_mode(mode): _push_on_torch_dispatch_stack(mode) def _pop_mode(): return _pop_torch_dispatch_stack() @contextlib.contextmanager def _pop_mode_temporarily(): old = _pop_mode() try: yield old finally: _push_mode(old) @contextlib.contextmanager def _disable_current_modes(): mode_len = _len_torch_dispatch_stack() old_modes = [_pop_mode() for _ in range(mode_len)] try: yield old_modes finally: for mode in reversed(old_modes): _push_mode(mode) class BaseTorchDispatchMode(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} return func(*args, **kwargs)