from abc import ABC, abstractmethod
import contextlib
from typing import Any

import torch
import torch.utils._pytree as pytree
from torch._C._functorch import (
    TransformType,
    RandomnessType,
    CInterpreter,
    CGradInterpreterPtr,
    CFunctionalizeInterpreterPtr,
    CVmapInterpreterPtr,
    CJvpInterpreterPtr,
    pop_dynamic_layer_stack,
    push_dynamic_layer_stack,
)
from torch.autograd.forward_ad import _set_fwd_grad_enabled
- """
- This file contains the functorch integration with PyDispatcher.
- PyDispatcher does not understand functorch's DynamicLayerStack dispatching
- logic because it is entirely implemented in C++ in the fallbacks for two
- dispatch keys, FuncTorchDynamicLayer{Front, Back}Mode (PyDispatcher is unable
- to directly reuse C++ boxed fallbacks).
- Instead of trying to hammer PyDispatcher into understanding those fallbacks,
- we re-implement the logic of peeking the top of the stack for an interpreter,
- selecting the interpreter to dispatch on, etc, in Python. This leads to a
- simpler design.
- The main difference between C++ functorch and PyDispatcher's functorch logic
- is that:
- - C++ functorch needs to manually tweak dispatch keys to ping-pong between
- DynamicLayerFrontMode and DynamicLayerBackMode.
- - PyDispatcher's functorch logic pops an Interpreter from the top of the stack
- and asks it to execute the rule associated with the Interpreter.
- In C++ we do the ping-pong because e.g. vmap rules are associated with the
- batched DispatchKey, but in PyDispatcher we are able to avoid this by asking
- the user to register a batching rule directly to a transform that an
- interpreter then invokes.
- """


# FuncTorchInterpreter is the Python version of Interpreter (recall that
# the DynamicLayerStack is a stack of interpreters).
# It is a wrapper around the actual C++ Interpreter object.
#
# Keep the methods in sync with aten/src/ATen/functorch/Interpreter.h
class FuncTorchInterpreter(ABC):
    def __init__(self, cptr: Any):
        self._cptr = cptr

    # Process an operation. eg for vmap, this is invoking a batching rule.
    # Conceptually this is analogous to Interpreter::process in C++
    @abstractmethod
    def process(self, op, args, kwargs):
        pass

    # Lower an operation from this Interpreter to the next Interpreter on the stack.
    # Concretely, this involves temporarily popping the current Interpreter.
    # Conceptually this is analogous to Interpreter::sendToNextInterpreter in C++
    def lower(self):
        return temporarily_pop_interpreter_stack()

    def level(self):
        return self._cptr.level()

    def key(self):
        return self._cptr.key()


# Temporarily removes the topmost interpreter from the DynamicLayerStack so
# that the operation re-dispatches to the next interpreter, then restores it
# on exit. The pop happens before the try block so that a failed pop does not
# reach the finally clause with `saved` unbound.
@contextlib.contextmanager
def temporarily_pop_interpreter_stack():
    saved = pop_dynamic_layer_stack()
    try:
        yield
    finally:
        push_dynamic_layer_stack(saved)
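
# How a rule typically uses this via lower(): inside a per-transform kernel,
# entering interpreter.lower() pops this interpreter off the DynamicLayerStack
# for the duration of the block, so calls made inside re-dispatch to the next
# transform (or to regular PyTorch if the stack is empty). Informal sketch;
# `some_op` and `unwrapped` are placeholders:
#
#     with interpreter.lower():
#         result = some_op(unwrapped)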


class VmapInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Vmap
        # NOTE: [Interpreter cdata vs cptr]
        # cdata is a generic CInterpreter. We wrap it in a CVmapInterpreterPtr
        # so that we can access methods specific to the vmap interpreter.
        self._cdata = cdata
        self._cptr = CVmapInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Vmap]
        return kernel(self, *args, **kwargs)

    def batch_size(self):
        return self._cptr.batchSize()

    def randomness(self):
        typ = self._cptr.randomness()
        if typ == RandomnessType.Error:
            return "error"
        elif typ == RandomnessType.Same:
            return "same"
        elif typ == RandomnessType.Different:
            return "different"
        raise RuntimeError(f"Unknown RandomnessType: {typ}")
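
# A hypothetical rule can query the vmap-specific state above; this is only
# meant to show the accessors, not a complete batching rule:
#
#     def my_rule(interpreter, x):
#         expanded = x.expand(interpreter.batch_size(), *x.shape)
#         assert interpreter.randomness() in ("error", "same", "different")
#         with interpreter.lower():
#             return expanded.sum(0)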


@contextlib.contextmanager
def nested(*contexts):
    with contextlib.ExitStack() as stack:
        for ctx in contexts:
            stack.enter_context(ctx)
        yield contexts
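
# nested(a, b) enters the given context managers in order and exits them in
# reverse, i.e. it behaves like `with a, b:` for a variable number of
# contexts. For example, GradInterpreter.lower() below uses it to combine
# torch.no_grad() with the temporary pop of the interpreter stack:
#
#     with nested(torch.no_grad(), temporarily_pop_interpreter_stack()):
#         ...  # both contexts are active here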


class GradInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Grad
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CGradInterpreterPtr(cdata)

    def lift(self, args, kwargs):
        args, kwargs = pytree.tree_map_only(
            torch.Tensor, self._cptr.lift, [args, kwargs])
        return args, kwargs

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Grad]
        args, kwargs = self.lift(args, kwargs)
        return kernel(self, *args, **kwargs)

    # GradInterpreter has custom lower because of the no_grad interaction
    # See NOTE [grad and vjp interaction with no_grad]
    # This logic is mirrored from C++ GradInterpreterPtr::sendToNextInterpreter
    def lower(self):
        prev_grad_mode = self.prev_grad_mode()
        if not prev_grad_mode:
            return nested(torch.no_grad(), super().lower())
        return super().lower()

    def prev_grad_mode(self):
        return self._cptr.prevGradMode()
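
# The prev_grad_mode bookkeeping above covers, roughly, the case of entering a
# grad transform inside a no_grad block (informal sketch; `f` is a placeholder
# scalar-valued function):
#
#     with torch.no_grad():
#         torch.func.grad(f)(x)
#
# prevGradMode() then reports False, so lower() re-enters torch.no_grad() and
# the levels beneath this grad transform observe the grad mode that was active
# when the transform was pushed.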


class JvpInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Jvp
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CJvpInterpreterPtr(cdata)

    def lift(self, args, kwargs):
        args, kwargs = pytree.tree_map_only(
            torch.Tensor, self._cptr.lift, [args, kwargs])
        return args, kwargs

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Jvp]
        args, kwargs = self.lift(args, kwargs)
        return kernel(self, *args, **kwargs)

    # Jvp has custom lower because of the no_fwd_grad interaction
    # See NOTE [grad and vjp interaction with no_grad] for related info.
    # This logic is mirrored from C++ JvpInterpreterPtr::sendToNextInterpreter
    def lower(self):
        prev_fwd_grad_mode = self.prev_fwd_grad_mode()
        if not prev_fwd_grad_mode:
            return nested(_set_fwd_grad_enabled(False), super().lower())
        return super().lower()

    def prev_fwd_grad_mode(self):
        return self._cptr.prevFwdGradMode()


class FunctionalizeInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Functionalize
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CFunctionalizeInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Functionalize]
        return kernel(self, *args, **kwargs)

    def functionalize_add_back_views(self):
        return self._cptr.functionalizeAddBackViews()


def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
    key = cinterpreter.key()
    if key == TransformType.Grad:
        return GradInterpreter(cinterpreter)
    if key == TransformType.Vmap:
        return VmapInterpreter(cinterpreter)
    if key == TransformType.Jvp:
        return JvpInterpreter(cinterpreter)
    if key == TransformType.Functionalize:
        return FunctionalizeInterpreter(cinterpreter)
    raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}")


def retrieve_current_functorch_interpreter():
    interpreter = torch._C._functorch.peek_interpreter_stack()
    assert interpreter is not None
    return coerce_cinterpreter(interpreter)


def dispatch_functorch(op, args, kwargs):
    interpreter = retrieve_current_functorch_interpreter()
    # In traditional PyTorch operators, DispatchKey::FuncTorchTensorWrapper's
    # unwrap_dead_tensors fallback handles unwrapping dead tensor wrappers.
    # PyDispatcher sidesteps the PyTorch dispatcher when dealing with functorch
    # transforms, so we manually unwrap the dead tensors here.
    # This logic won't need to exist when we have mode-only functorch.
    args, kwargs = pytree.tree_map_only(
        torch.Tensor, torch._C._functorch.unwrap_if_dead, (args, kwargs))
    return interpreter.process(op, args, kwargs)
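
# Rough end-to-end sketch (informal; `my_op` is the hypothetical rule-bearing
# operator from the note near the top of this file): under an active transform
# the interpreter stack is non-empty, so if `my_op` is dispatched via
# PyDispatcher, a call such as the one below would be expected to land in
# dispatch_functorch, which peeks the top interpreter, wraps it with
# coerce_cinterpreter, and invokes the rule registered for its TransformType.
#
#     xs = torch.randn(3, 4)
#     out = torch.func.vmap(my_op)(xs)  # routed to the Vmap rule for my_op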