graph.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. import torch
  2. import contextlib
  3. from typing import Callable, Any, Dict, Tuple, Optional, Sequence, List, Set
  4. from torch.utils.hooks import RemovableHandle
  5. from torch.utils._python_dispatch import TorchDispatchMode
  6. from collections import defaultdict
  7. import weakref
  8. import abc
  # Public names exported by a star-import of this module.
__all__ = [
    "saved_tensors_hooks",
    "save_on_cpu",
    "disable_saved_tensors_hooks",
    "register_multi_grad_hook",
    "allow_mutation_on_saved_tensors",
    "Node",
]
  # Abstract interface for autograd graph nodes such as ``tensor.grad_fn``.
class Node(abc.ABC):
    """Abstract base class for nodes of the autograd graph.

    Nodes are not instantiated from Python directly. ``isinstance(obj, Node)``
    is decided by :meth:`__subclasshook__`. It is true for classes exposed
    under their own name in ``torch._C._functions``, and for subclasses of
    ``torch.autograd.function.BackwardCFunction``.
    """

    @abc.abstractmethod
    def name(self) -> str:
        r"""Returns the name.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> print(b.grad_fn.name())
            CloneBackward0
        """
        ...

    @property
    @abc.abstractmethod
    def next_functions(self) -> Tuple[Tuple[Optional["Node"], int], ...]:
        r"""Returns the edges from this node to the next nodes in the graph.

        Each entry is a ``(node, index)`` pair, where ``node`` is a
        :class:`Node` or ``None`` and ``index`` is an integer.
        """
        ...

    @abc.abstractmethod
    def metadata(self) -> dict:
        r"""Returns the metadata dictionary of this node."""
        ...

    @abc.abstractmethod
    def _register_hook_dict(self, tensor: torch.Tensor) -> None:
        # Private; implemented by the concrete node types.
        # NOTE(review): registers hooks associated with ``tensor`` on this
        # node. Exact semantics live in the C++ implementation; confirm there.
        ...

    @abc.abstractmethod
    def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Registers a backward hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its arguments, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_inputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on how and when
            this hook is executed, and how its execution is ordered relative to
            other hooks.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove() # Removes the hook
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @abc.abstractmethod
    def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Registers a backward pre-hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_outputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on how and when
            this hook is executed, and how its execution is ordered relative to
            other hooks.

        Example::

            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_prehook(lambda gO: (gO[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove()
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @classmethod
    def __subclasshook__(cls, C):
        # Structural check used by isinstance/issubclass against Node only.
        # Accepts classes exposed under their own name in torch._C._functions,
        # and subclasses of BackwardCFunction. Otherwise defer to the default
        # ABC machinery.
        if cls is Node:
            if ((C is not None and C is getattr(torch._C._functions, C.__name__, None))
                    or issubclass(C, torch.autograd.function.BackwardCFunction)):
                return True
        return NotImplemented
  # Context manager installing a pack/unpack hook pair for saved tensors.
class saved_tensors_hooks:
    """Context-manager that sets a pair of pack / unpack hooks for saved tensors.

    Use this context-manager to define how intermediary results of an operation
    should be packed before saving, and unpacked on retrieval.

    While the context is active, ``pack_hook`` is called every time an
    operation saves a tensor for backward. That covers tensors saved with
    :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` as
    well as those recorded by PyTorch-defined operations. Whatever
    ``pack_hook`` returns is stored in the computation graph in place of the
    original tensor.

    ``unpack_hook`` is called when the saved tensor is needed again, i.e. while
    executing :func:`torch.Tensor.backward()` or :func:`torch.autograd.grad()`.
    It receives the *packed* object returned by ``pack_hook``. It should return
    a tensor with the same content as the tensor originally passed to
    ``pack_hook``.

    The hooks should have the following signatures:

        pack_hook(tensor: Tensor) -> Any

        unpack_hook(Any) -> Tensor

    where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.
    In general, ``unpack_hook(pack_hook(t))`` should equal ``t`` in value,
    size, dtype and device.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pack_hook(x):
        ...     print("Packing", x)
        ...     return x
        >>>
        >>> def unpack_hook(x):
        ...     print("Unpacking", x)
        ...     return x
        >>>
        >>> a = torch.ones(5, requires_grad=True)
        >>> b = torch.ones(5, requires_grad=True) * 2
        >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        ...     y = a * b
        Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
        >>> y.sum().backward()
        Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)

    .. warning ::
        Performing an inplace operation on the input to either hook may lead
        to undefined behavior.

    .. warning ::
        Only one pair of hooks is allowed at a time. When recursively nesting this
        context-manager, only the inner-most pair of hooks will be applied.
    """

    def __init__(self, pack_hook: Callable[[torch.Tensor], Any], unpack_hook: Callable[[Any], torch.Tensor]):
        self.pack_hook, self.unpack_hook = pack_hook, unpack_hook

    def __enter__(self):
        # Push this pair as the default hooks; __exit__ pops it again.
        autograd_c = torch._C._autograd
        autograd_c._push_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)
        return None

    def __exit__(self, *args: Any):
        autograd_c = torch._C._autograd
        autograd_c._pop_saved_tensors_default_hooks()
        return None
  # saved_tensors_hooks pair that offloads saved tensors to CPU memory.
class save_on_cpu(saved_tensors_hooks):
    """Context-manager under which tensors saved by the forward pass will be
    stored on cpu, then retrieved for backward.

    Inside this context, every tensor the forward pass saves for backward is
    moved to CPU. It is copied back to its original device when backward needs
    it. Use this to trade host/device transfer time for GPU memory, e.g. when
    your model does not fit in GPU memory during training.

    With ``pin_memory=False``, a saved tensor that is already on CPU is stored
    as-is (``Tensor.cpu()`` returns the tensor itself), so no copy is made.
    With ``pin_memory=True``, every saved tensor is copied into a newly
    allocated CPU buffer.

    Args:
        pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory
            during packing and copied to GPU asynchronously during unpacking.
            Pinning is only requested when CUDA is available and the tensor is
            not sparse. Defaults to ``False``.
            Also see :ref:`cuda-memory-pinning`.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> a = torch.randn(5, requires_grad=True, device="cuda")
        >>> b = torch.randn(5, requires_grad=True, device="cuda")
        >>> c = torch.randn(5, requires_grad=True, device="cuda")
        >>>
        >>> def f(a, b, c):
        ...     prod_1 = a * b           # a and b are saved on GPU
        ...     with torch.autograd.graph.save_on_cpu():
        ...         prod_2 = prod_1 * c  # prod_1 and c are saved on CPU
        ...     y = prod_2 * a           # prod_2 and a are saved on GPU
        ...     return y
        >>>
        >>> y = f(a, b, c)
        >>> del a, b, c  # for illustration only
        >>> # the content of a, b, and prod_2 are still alive on GPU
        >>> # the content of prod_1 and c only live on CPU
        >>> y.sum().backward()  # all CPU tensors are moved back to GPU, for backward
        >>> # all intermediary tensors are released (deleted) after the call to backward
    """

    def __init__(self, pin_memory=False):
        def _to_cpu(tensor):
            # Keep the source device so unpacking can send the data back there.
            src_device = tensor.device
            if pin_memory:
                # Allocate a CPU buffer (pinned when possible) and copy into it.
                staging = torch.empty(
                    tensor.size(),
                    dtype=tensor.dtype,
                    layout=tensor.layout,
                    pin_memory=(torch.cuda.is_available() and not tensor.is_sparse),
                )
                staging.copy_(tensor)
                return (src_device, staging)
            return (src_device, tensor.cpu())

        def _from_cpu(packed):
            original_device, cpu_tensor = packed
            # Non-blocking transfer is only requested for pinned buffers.
            return cpu_tensor.to(original_device, non_blocking=pin_memory)

        super().__init__(_to_cpu, _from_cpu)
  # Context manager that temporarily disables the saved-tensors default hooks.
@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message):
    """Context-manager that disables the saved tensors default hooks feature.

    Useful if you are creating a feature that does not work with saved
    tensors default hooks.

    The state in effect on entry is restored on exit: either enabled, or
    disabled with the previous message. This lets nested uses compose.

    Args:
        error_message (str): When saved tensors default hooks are used while
            they are disabled, a RuntimeError with this error message is
            raised.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> message = "saved tensors default hooks are disabled"
        >>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
        ...     # Raises RuntimeError: saved tensors default hooks are disabled
        ...     with torch.autograd.graph.save_on_cpu():
        ...         pass
    """
    # Read the previous state *before* entering ``try``. If this call fails,
    # nothing has been changed yet. It also ensures ``finally`` never runs with
    # ``maybe_prev_message`` unbound, which would raise UnboundLocalError and
    # hide the original error.
    maybe_prev_message = torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
    try:
        torch._C._autograd._saved_tensors_hooks_disable(error_message)
        yield
    finally:
        # See NOTE: [disabled_error_message invariant]
        if maybe_prev_message is None:
            torch._C._autograd._saved_tensors_hooks_enable()
        else:
            torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)
  # Register one hook that fires once per backward call with the grads of several tensors.
def register_multi_grad_hook(tensors: Sequence[torch.Tensor], fn: Callable[[Sequence[Optional[torch.Tensor]]], None]):
    r"""Registers a multi-grad backward hook.

    The hook will be called after gradients with respect to every tensor in
    :attr:`tensors` have been computed. A tensor in :attr:`tensors` is ignored,
    and the hook does not wait for its gradient, if either:

    - it is not part of the graph, or
    - it is not needed to compute the gradients for any ``inputs`` specified
      for the current ``.backward()`` or ``.grad()`` call.

    After every non-ignored tensor's gradient has been computed, :attr:`fn` will be
    called with those gradients. ``None`` will be passed for tensors that did not
    have their gradients computed.

    The hook should not modify its arguments.

    This function returns a handle with a method ``handle.remove()`` that removes the hook.

    .. note::
        See :ref:`backward-hooks-execution` for more information on how and when this
        hook is executed, and how its execution is ordered relative to other hooks.

    Example::

        >>> import torch
        >>>
        >>> a = torch.rand(2, 3, requires_grad=True)
        >>> b = torch.rand(2, 3, requires_grad=True)
        >>> c = a * b
        >>> d = a * b
        >>>
        >>> def fn(grads):
        ...     print([g is not None for g in grads])
        ...
        >>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn)
        >>>
        >>> c.sum().backward(retain_graph=True)
        [True, True, True, False]
        >>> c.sum().backward(inputs=(a,), retain_graph=True)
        [True, False, True, False]
        >>>
    """
    nb_tensors = len(tensors)
    # All bookkeeping is keyed by the id of the running graph task. Two
    # backward passes in flight at once (e.g. a reentrant backward) then keep
    # separate counters, expected-call counts and buffers.
    count: Dict[int, int] = dict()
    nb_calls: Dict[int, int] = dict()
    buffer: Dict[int, List[Optional[torch.Tensor]]] = dict()

    def get_grad_fn(t):
        # Non-leaf tensors expose their node as ``grad_fn``. A leaf that
        # requires grad has none, so reach its gradient accumulator through
        # the grad_fn of a temporary clone.
        if t.requires_grad and t.grad_fn is None:
            return t.clone().grad_fn.next_functions[0][0]
        return t.grad_fn

    grad_fns = [get_grad_fn(t) for t in tensors]

    def get_inner_hook(idx):
        def inner_hook(grad: torch.Tensor):
            task_id = torch._C._current_graph_task_id()
            assert task_id != -1, "expected this hook to be called inside a backward call"
            if task_id not in count:
                # First hook of this graph task: count how many of the tensors
                # the engine will actually produce a gradient for.
                count[task_id] = 0
                nb_calls[task_id] = sum(torch._C._will_engine_execute_node(g) for g in grad_fns)  # type: ignore[attr-defined]
                buffer[task_id] = [None] * nb_tensors
            buffer[task_id][idx] = grad
            count[task_id] += 1
            if count[task_id] == nb_calls[task_id]:
                # Release the per-task state before calling ``fn`` so it is
                # not leaked if ``fn`` raises.
                grads = buffer.pop(task_id)
                del count[task_id]
                del nb_calls[task_id]
                fn(grads)
        return inner_hook

    class Handle(RemovableHandle):
        # Removes every per-tensor hook at once.
        handles: Tuple[RemovableHandle, ...]

        def __init__(self, handles: Tuple[RemovableHandle, ...]):
            self.handles = handles

        def remove(self):
            for handle in self.handles:
                handle.remove()

        def __getstate__(self):
            return self.handles

        def __setstate__(self, state):
            self.handles = state

    handles = tuple(t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors))
    return Handle(handles)
  # NOTE [Allow mutation on tensors saved for backward]
#
# 1. When a tensor is saved for backward:
#    - remember its python object id, data_ptr and version (its "tid")
#    - remember aliasing information: data_ptr + version (its "sid")
#    - keep a reference to the original so we control its lifetime
# 2. Whenever a tensor is mutated in-place:
#    - for each saved tensor aliased to it (same sid):
#      - if it has not been cloned yet, clone it
#      - drop the reference to the original
# 3. During backward:
#    - if a clone exists, the tensor must have been modified in-place, so the
#      clone (the pre-mutation value) is used instead of the original

# True only while an ``allow_mutation_on_saved_tensors`` context is active.
_allow_mutation_on_saved_tensors_enabled: bool = False
  # Identity of a specific tensor object at a specific version.
def _get_tid(t) -> Tuple[int, int, int]:
    """Return ``(id(t), t.data_ptr(), t._version)``."""
    object_id = id(t)
    storage_ptr = t.data_ptr()
    version = t._version
    return object_id, storage_ptr, version
  # Aliasing key: data pointer plus version, shared by tensors viewing the same data.
def _get_sid(t) -> Tuple[int, int]:
    """Return ``(t.data_ptr(), t._version)``."""
    data_address = t.data_ptr()
    return data_address, t._version
  # Opaque token stored in the autograd graph in place of a saved tensor.
class _Handle:
    """Opaque, hashable, weak-referenceable token for one saved tensor.

    Instances serve as keys of the weak dictionaries that map a saved tensor
    to its original or cloned value.
    """
  # saved_tensors_hooks pair used by allow_mutation_on_saved_tensors.
class _swap_with_cloned(saved_tensors_hooks):
    """Pack saved tensors as ``_Handle`` tokens and resolve them through ``ctx``.

    ``pack_hook`` records the tensor's identity and aliasing key in ``ctx`` and
    returns a handle. ``unpack_hook`` returns the clone taken before an
    in-place mutation if one exists; otherwise it returns the original tensor.
    """

    def __init__(self, ctx):
        def pack_hook(t):
            tid = _get_tid(t)
            sid = _get_sid(t)
            # Save aliasing information
            ctx.sid_to_tid[sid].add(tid)
            # NB: The same tensor (of the same version) can be saved multiple
            # times; reuse its handle so every save shares one original/clone.
            # ``.get`` avoids a check-then-index race on the weak-valued dict.
            handle = ctx.tid_to_weakhandle.get(tid)
            if handle is None:
                handle = _Handle()
                ctx.tid_to_weakhandle[tid] = handle
                ctx.original[handle] = t
            # The returned handle is what gets stored in the graph. That strong
            # reference keeps the weak-dict entries keyed on it alive.
            return handle

        def unpack_hook(handle):
            error_msg = (
                "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context "
                "in which the graph was originally recorded.")
            assert _allow_mutation_on_saved_tensors_enabled, error_msg
            if handle in ctx.cloned:
                # The tensor was mutated in-place after being saved.
                return ctx.cloned[handle]
            assert handle in ctx.original, error_msg
            return ctx.original[handle]

        super().__init__(pack_hook, unpack_hook)
  # Dispatch mode that clones saved tensors right before they are mutated in place.
class _CloneArgBeforeMutateMode(TorchDispatchMode):
    """Preserve tensors saved for backward across in-place mutation.

    Before running an op, each argument its schema marks as written to
    (``alias_info.is_write``) is checked against ``ctx``. Every saved tensor
    sharing that argument's ``(data_ptr, version)`` that has not been cloned
    yet is cloned, and the reference to the original is dropped.

    Handles single tensors, lists of tensors (foreach ops) and keyword-only
    out arguments.
    """

    def __init__(self, ctx):
        super().__init__()
        self.ctx = ctx

    def _maybe_clone(self, t):
        """Clone every not-yet-cloned saved tensor that ``t`` currently aliases."""
        if not isinstance(t, torch.Tensor):
            # Optional mutable arguments may be passed as None.
            return
        ctx = self.ctx
        sid = _get_sid(t)
        # ``in`` does not insert into the defaultdict, unlike indexing.
        if sid not in ctx.sid_to_tid:
            return
        for tid in ctx.sid_to_tid[sid]:
            handle = ctx.tid_to_weakhandle.get(tid)
            if handle is None:
                # We know that if tid is in sid_to_tid, then it must also be in
                # tid_to_weakhandle. However, it is possible for the tensor to be
                # saved at one point, but cleared by backward before it is modified
                # in-place. Consider the following example:
                #
                # >>> a = torch.randn(2, 3, requires_grad=True).clone()
                # >>> out = (a**2).sum()
                # >>> out.backward()
                # >>> a.sin_()
                continue
            if handle in ctx.cloned:
                # The same exact tensor has been cloned already
                continue
            ctx.cloned[handle] = ctx.original[handle].clone()
            del ctx.original[handle]

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        for idx, arg in enumerate(func._schema.arguments):
            if arg.alias_info is None or not arg.alias_info.is_write:
                continue
            if arg.is_out:
                # Out arguments are keyword-only and keyed by their schema
                # name, which is not always "out" (multi-output ops differ).
                target = kwargs.get(arg.name)
            elif idx < len(args):
                target = args[idx]
            else:
                target = kwargs.get(arg.name)
            if isinstance(target, (list, tuple)):
                # Foreach ops (``Tensor(a!)[]``) mutate a whole list in place.
                for t in target:
                    self._maybe_clone(t)
            else:
                self._maybe_clone(target)
        return func(*args, **kwargs)
  # Bookkeeping shared by the hooks and the dispatch mode of one context.
class _AllowMutationOnSavedContext:
    """State for one ``allow_mutation_on_saved_tensors`` context.

    Attributes:
        cloned: handle -> clone taken before an in-place mutation (weak keys).
        original: handle -> tensor as originally saved (weak keys).
        tid_to_weakhandle: ``(id, data_ptr, version)`` -> handle (weak values).
        sid_to_tid: ``(data_ptr, version)`` -> set of ``(id, data_ptr, version)``.
    """

    def __init__(self):
        self.cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.tid_to_weakhandle: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
        self.sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = defaultdict(set)

    def clear(self):
        """Empty all four tables."""
        for table in (self.cloned, self.original, self.tid_to_weakhandle, self.sid_to_tid):
            table.clear()
  # Public context manager that allows in-place mutation of tensors saved for backward.
@contextlib.contextmanager
def allow_mutation_on_saved_tensors():
    """Context manager under which mutating tensors saved for backward is allowed.

    Under this context manager, tensors saved for backward are cloned on mutation,
    so the original version can still be used during backward. Normally, mutating a tensor
    saved for backward will result in an error raised when it's used during backward.

    To ensure the correct behavior, both the forward and backward should be run under
    the same context manager.

    Returns (yields):
        An _AllowMutationOnSavedContext object storing the state managed by this
        context manager. This object can be useful for debugging purposes. The state
        managed by the context manager is automatically cleared upon exiting.

    Raises:
        RuntimeError: if entered while another ``allow_mutation_on_saved_tensors``
            context is already active (nesting is not supported). The active
            outer context is left untouched.

    Example::

        >>> import torch
        >>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
        ...     # forward
        ...     a = torch.ones(2, 3, requires_grad=True)
        ...     b = a.clone()
        ...     out = (b**2).sum()
        ...     b.sin_()
        ...     # backward
        ...     out.sum().backward()
        ...
        tensor([[0.8415, 0.8415, 0.8415],
                [0.8415, 0.8415, 0.8415]], grad_fn=<SinBackward0>)
    """
    global _allow_mutation_on_saved_tensors_enabled

    # Reject nesting before touching any state. A failing inner context must
    # neither install hooks/modes nor reset the flag owned by the outer one.
    if _allow_mutation_on_saved_tensors_enabled:
        raise RuntimeError("allow_mutation_on_saved_tensors contexts cannot be nested")

    ctx = _AllowMutationOnSavedContext()
    with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx):
        try:
            _allow_mutation_on_saved_tensors_enabled = True
            yield ctx
        finally:
            ctx.clear()
            _allow_mutation_on_saved_tensors_enabled = False