# pyfunctorch.py

from abc import ABC, abstractmethod
import contextlib
from typing import Any
import torch
import torch.utils._pytree as pytree
from torch._C._functorch import (
    TransformType,
    RandomnessType,
    CInterpreter,
    CGradInterpreterPtr,
    CFunctionalizeInterpreterPtr,
    CVmapInterpreterPtr,
    CJvpInterpreterPtr,
    pop_dynamic_layer_stack,
    push_dynamic_layer_stack,
)
from torch.autograd.forward_ad import _set_fwd_grad_enabled

"""
This file contains the functorch integration with PyDispatcher.

PyDispatcher does not understand functorch's DynamicLayerStack dispatching
logic because it is entirely implemented in C++ in the fallbacks for two
dispatch keys, FuncTorchDynamicLayer{Front, Back}Mode (PyDispatcher is unable
to directly reuse C++ boxed fallbacks).

Instead of trying to hammer PyDispatcher into understanding those fallbacks,
we re-implement the logic of peeking the top of the stack for an interpreter,
selecting the interpreter to dispatch on, etc., in Python. This leads to a
simpler design.

The main difference between C++ functorch and PyDispatcher's functorch logic
is that:
- C++ functorch needs to manually tweak dispatch keys to ping-pong between
  DynamicLayerFrontMode and DynamicLayerBackMode.
- PyDispatcher's functorch logic pops an Interpreter from the top of the stack
  and asks it to execute the rule associated with the Interpreter.

In C++ we do the ping-pong because e.g. vmap rules are associated with the
batched DispatchKey, but in PyDispatcher we are able to avoid this by asking
the user to register a batching rule directly to a transform that an
interpreter then invokes.
"""


# FuncTorchInterpreter is the Python version of Interpreter (recall that
# the DynamicLayerStack is a stack of interpreters).
# It is a wrapper around the actual C++ Interpreter object.
#
# Keep the methods in sync with aten/src/ATen/functorch/Interpreter.h
class FuncTorchInterpreter(ABC):
    def __init__(self, cptr: Any):
        self._cptr = cptr

    # Process an operation. E.g., for vmap, this is invoking a batching rule.
    # Conceptually this is analogous to Interpreter::process in C++
    @abstractmethod
    def process(self, op, args, kwargs):
        pass

    # Lower an operation from this Interpreter to the next Interpreter on the stack.
    # Concretely, this involves temporarily popping the current Interpreter.
    # Conceptually this is analogous to Interpreter::sendToNextInterpreter in C++
    def lower(self):
        return temporarily_pop_interpreter_stack()

    def level(self):
        return self._cptr.level()

    def key(self):
        return self._cptr.key()


@contextlib.contextmanager
def temporarily_pop_interpreter_stack():
    try:
        saved = pop_dynamic_layer_stack()
        yield
    finally:
        push_dynamic_layer_stack(saved)
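

# Illustrative sketch (not part of the original module): how `lower()` is meant
# to be used inside a transform rule. Temporarily popping the current
# interpreter means that re-dispatching the operation hits the next interpreter
# on the stack (or the regular PyTorch dispatcher if the stack is empty).
# `op` and its arguments here are hypothetical stand-ins.
def _example_redispatch_to_next_interpreter(interpreter, op, *args, **kwargs):
    with interpreter.lower():
        return op(*args, **kwargs)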


class VmapInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Vmap
        # NOTE: [Interpreter cdata vs cptr]
        # cdata is a generic CInterpreter. We wrap it in a CVmapInterpreterPtr
        # so that we can access methods specific to the vmap interpreter
        self._cdata = cdata
        self._cptr = CVmapInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Vmap]
        return kernel(self, *args, **kwargs)

    def batch_size(self):
        return self._cptr.batchSize()

    def randomness(self):
        typ = self._cptr.randomness()
        if typ == RandomnessType.Error:
            return "error"
        elif typ == RandomnessType.Same:
            return "same"
        elif typ == RandomnessType.Different:
            return "different"
        raise RuntimeError(f"Unknown RandomnessType: {typ}")
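

# Hedged sketch (not in the original file): the rough shape of a rule that
# `process` above looks up via `op.functorch_table[TransformType.Vmap]`. The
# calling convention (interpreter first, then the op's args/kwargs) follows
# from `kernel(self, *args, **kwargs)`; the body is a stand-in computation,
# not a real batching rule.
def _example_vmap_rule(interpreter, x):
    # A real rule would consult interpreter.batch_size() / interpreter.level()
    # and handle batch dimensions; here we just lower and run the computation
    # unbatched to show the plumbing.
    with interpreter.lower():
        return torch.sin(x)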


@contextlib.contextmanager
def nested(*contexts):
    with contextlib.ExitStack() as stack:
        for ctx in contexts:
            stack.enter_context(ctx)
        yield contexts
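

# Illustrative usage (not part of the original file): `nested` enters the given
# context managers in order and exits them in reverse, so several contexts can
# be treated as one. The contexts below are arbitrary stand-ins.
def _example_nested_usage(x):
    with nested(torch.no_grad(), contextlib.nullcontext()):
        return x * 2  # runs with autograd disabled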


class GradInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Grad
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CGradInterpreterPtr(cdata)

    def lift(self, args, kwargs):
        args, kwargs = pytree.tree_map_only(torch.Tensor, self._cptr.lift, [args, kwargs])
        return args, kwargs

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Grad]
        args, kwargs = self.lift(args, kwargs)
        return kernel(self, *args, **kwargs)

    # GradInterpreter has custom lower because of the no_grad interaction
    # See NOTE [grad and vjp interaction with no_grad]
    # This logic is mirrored from C++ GradInterpreterPtr::sendToNextInterpreter
    def lower(self):
        prev_grad_mode = self.prev_grad_mode()
        if not prev_grad_mode:
            return nested(torch.no_grad(), super().lower())
        return super().lower()

    def prev_grad_mode(self):
        return self._cptr.prevGradMode()
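

# Hedged sketch (not in the original file): the rough shape of a rule looked up
# via `op.functorch_table[TransformType.Grad]`. By the time the rule runs,
# `process` has already lifted tensor arguments to this grad level; inside the
# rule, `interpreter.lower()` runs computation under the next interpreter,
# re-disabling grad if the transform was entered under no_grad (see
# GradInterpreter.lower above). The body is a stand-in.
def _example_grad_rule(interpreter, x):
    with interpreter.lower():
        return torch.cos(x)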


class JvpInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Jvp
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CJvpInterpreterPtr(cdata)

    def lift(self, args, kwargs):
        args, kwargs = pytree.tree_map_only(torch.Tensor, self._cptr.lift, [args, kwargs])
        return args, kwargs

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Jvp]
        args, kwargs = self.lift(args, kwargs)
        return kernel(self, *args, **kwargs)

    # Jvp has custom lower because of the no_fwd_grad interaction
    # See NOTE [grad and vjp interaction with no_grad] for related info.
    # This logic is mirrored from C++ JvpInterpreterPtr::sendToNextInterpreter
    def lower(self):
        prev_fwd_grad_mode = self.prev_fwd_grad_mode()
        if not prev_fwd_grad_mode:
            return nested(_set_fwd_grad_enabled(False), super().lower())
        return super().lower()

    def prev_fwd_grad_mode(self):
        return self._cptr.prevFwdGradMode()


class FunctionalizeInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Functionalize
        self._cdata = cdata
        self._cptr = CFunctionalizeInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Functionalize]
        return kernel(self, *args, **kwargs)

    def functionalize_add_back_views(self):
        return self._cptr.functionalizeAddBackViews()


def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
    key = cinterpreter.key()
    if key == TransformType.Grad:
        return GradInterpreter(cinterpreter)
    if key == TransformType.Vmap:
        return VmapInterpreter(cinterpreter)
    if key == TransformType.Jvp:
        return JvpInterpreter(cinterpreter)
    if key == TransformType.Functionalize:
        return FunctionalizeInterpreter(cinterpreter)
    raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}")
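

# Hedged example (not part of the original file), assuming torch.func.vmap is
# available: while a function is being vmapped, the top of the dynamic layer
# stack is a vmap CInterpreter, which coerce_cinterpreter turns into a
# VmapInterpreter whose extra methods (e.g. batch_size) become accessible.
def _example_coerce_inside_vmap():
    def f(x):
        cinterpreter = torch._C._functorch.peek_interpreter_stack()
        interpreter = coerce_cinterpreter(cinterpreter)
        assert isinstance(interpreter, VmapInterpreter)
        return x * interpreter.batch_size()

    return torch.func.vmap(f)(torch.ones(4))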


def retrieve_current_functorch_interpreter():
    interpreter = torch._C._functorch.peek_interpreter_stack()
    assert interpreter is not None
    return coerce_cinterpreter(interpreter)


def dispatch_functorch(op, args, kwargs):
    interpreter = retrieve_current_functorch_interpreter()
    # In traditional PyTorch operators, DispatchKey::FuncTorchTensorWrapper's
    # unwrap_dead_tensors fallback handles unwrapping dead tensor wrappers.
    # PyDispatcher sidesteps the PyTorch dispatcher when dealing with functorch
    # transforms, so we manually unwrap the dead tensors here.
    # This logic won't need to exist when we have mode-only functorch.
    args, kwargs = pytree.tree_map_only(
        torch.Tensor, torch._C._functorch.unwrap_if_dead, (args, kwargs))
    return interpreter.process(op, args, kwargs)
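

# Hedged usage sketch (not part of the original file): the only thing
# `dispatch_functorch` relies on from `op` (via `process`) is a
# `functorch_table` mapping TransformType to a rule, so a hypothetical
# operator-like object is enough to show the wiring. `_ExampleOp` and
# `_example_vmap_rule` (defined above) are illustrative stand-ins, and the
# call is only meaningful while a vmap transform is active, since the topmost
# interpreter determines which table entry gets used.
class _ExampleOp:
    functorch_table = {TransformType.Vmap: _example_vmap_rule}


def _example_dispatch(x):
    return dispatch_functorch(_ExampleOp(), (x,), {})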