_python_dispatch.py

import contextlib
import warnings

from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at, \
    _pop_torch_dispatch_stack, _push_on_torch_dispatch_stack


# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for _DisableTorchDispatch that
#   is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
class TorchDispatchMode:
    """
    A ``TorchDispatchMode`` allows you to override the meaning of all
    ``__torch_dispatch__`` overrideable functions within a dynamic scope,
    without having to actually create a tensor subclass or manually
    monkey-patch functions in the PyTorch API. Some common situations
    where you should use a mode:

        * You want to override the meaning of factory functions, or other
          functions that do not otherwise take a tensor as an argument
          (these cannot be overridden with tensor subclasses).

        * You want to override the behavior of all functions without needing
          to wrap your inputs in tensor subclasses; e.g., if you are just
          interested in logging intermediate computations.

        * You want to control the order of execution of various tensor
          subclasses explicitly, rather than implicitly via the return of
          ``NotImplemented``.

    Independent subclasses of :class:`TorchDispatchMode` are compositional:
    modes can be pushed onto a stack using ``with MyMode():``.
    When you call functions in the PyTorch API inside your
    ``__torch_dispatch__`` implementation, by default, they will forward on to
    the next mode on the mode stack. If you want to recursively call back into
    your current ``__torch_dispatch__`` implementation, either explicitly
    invoke ``self.__torch_dispatch__(...)``, or use the context manager
    ``__torch_dispatch__(self)`` to make the PyTorch API self-referential
    (beware of infinite loops, in this case!)
    """

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        raise NotImplementedError()

    def __enter__(self):
        _push_mode(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _pop_mode()

    @classmethod
    def push(cls, *args, **kwargs):
        warnings.warn("`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`")
        instance = cls(*args, **kwargs)
        return instance
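

# Illustrative example (hypothetical names, not part of the original module):
# a minimal sketch of how a ``TorchDispatchMode`` subclass is typically used.
# The mode logs every op that reaches the Python dispatcher and then forwards
# it. Because modes also intercept factory functions, the ``torch.ones`` call
# below goes through ``__torch_dispatch__`` even though no tensor subclass is
# involved.
def _example_logging_mode():
    import torch

    class LoggingMode(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print(f"dispatch: {func}")
            return func(*args, **kwargs)

    with LoggingMode():
        x = torch.ones(2, 2)   # the factory op is logged
        return x + 1           # the add is logged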


def _get_current_dispatch_mode():
    stack_len = _len_torch_dispatch_stack()
    return _get_dispatch_stack_at(stack_len - 1) if stack_len > 0 else None


def _get_current_dispatch_mode_stack():
    stack_len = _len_torch_dispatch_stack()
    return [_get_dispatch_stack_at(i) for i in range(stack_len)]
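

# Illustrative example (hypothetical helper name, not part of the original
# file): while a mode is active it is the innermost entry of the dispatch
# stack, so it is what the two getters above report.
def _example_inspect_mode_stack():
    with TorchDispatchMode() as mode:
        # Entering the base class just pushes it; no ops are dispatched here.
        print(_get_current_dispatch_mode())             # this ``mode``
        print(len(_get_current_dispatch_mode_stack()))  # at least 1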


def _push_mode(mode):
    _push_on_torch_dispatch_stack(mode)


def _pop_mode():
    return _pop_torch_dispatch_stack()


@contextlib.contextmanager
def _pop_mode_temporarily():
    old = _pop_mode()
    try:
        yield old
    finally:
        _push_mode(old)
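

# Illustrative example (hypothetical helper name, not part of the original
# file): ``_pop_mode_temporarily`` takes the innermost mode off the stack for
# the duration of a ``with`` block and restores it afterwards.  It assumes at
# least one mode is currently active.
def _example_pop_mode_temporarily():
    with TorchDispatchMode() as mode:
        with _pop_mode_temporarily() as popped:
            # ``popped`` is ``mode``; ops run here no longer see it.
            assert _get_current_dispatch_mode() is not mode
        assert _get_current_dispatch_mode() is mode   # restored on exit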


@contextlib.contextmanager
def _disable_current_modes():
    mode_len = _len_torch_dispatch_stack()
    old_modes = [_pop_mode() for _ in range(mode_len)]
    try:
        yield old_modes
    finally:
        for mode in reversed(old_modes):
            _push_mode(mode)
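

# Illustrative example (hypothetical helper name, not part of the original
# file): ``_disable_current_modes`` clears the whole mode stack for the body
# of the ``with`` block, e.g. so that a plain tensor can be constructed while
# modes are active, and then restores the stack in its original order.
def _example_disable_current_modes():
    import torch
    with _disable_current_modes() as old_modes:
        # ``old_modes`` holds the modes that were active; none of them see
        # the ops executed in this block.
        plain = torch.empty(3)
    return plain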


class BaseTorchDispatchMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)
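

# Illustrative example (hypothetical names, not part of the original file):
# subclassing ``BaseTorchDispatchMode`` adds behavior around every op while
# inheriting the default "just call ``func``" forwarding above.
def _example_counting_mode():
    import torch

    class CountingMode(BaseTorchDispatchMode):
        def __init__(self):
            self.count = 0

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            self.count += 1
            return super().__torch_dispatch__(func, types, args, kwargs)

    with CountingMode() as mode:
        torch.arange(4).sum()
    return mode.count   # number of ops dispatched inside the block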