checkpoint_activation.py

from contextlib import contextmanager
from functools import partial
from typing import Any, List, Optional, Tuple
from weakref import ref, ReferenceType, WeakKeyDictionary

import torch
import torch.nn as nn
from torch.utils.checkpoint import detach_variable, get_device_states, set_device_states

from .contract import contract


@contextmanager
def _no_hook(module: nn.Module):
    r"""
    Disable hooks installed by checkpoint to avoid unintentional recursion
    during backward recomputation.
    """
    orig_enable_hook = checkpoint.state(module).enable_hook
    checkpoint.state(module).enable_hook = False
    try:
        yield
    except Exception:
        raise
    finally:
        checkpoint.state(module).enable_hook = orig_enable_hook
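

# Illustrative sketch (not a separate code path in this module): the checkpoint
# hooks registered on ``module`` would fire again when the module is re-run for
# recomputation, so the recomputation call is wrapped in ``_no_hook``, mirroring
# the backward pass below:
#
#     with _no_hook(module), torch.enable_grad():
#         outputs = module(*detached_inputs)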


class _ModuleHookCheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, module: nn.Module, output: Any, *inputs: Any) -> Any:  # type: ignore[override]
        ctx.module = module
        # Save non-tensor inputs in ctx and keep a placeholder None for each
        # tensor, to be filled in during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, inp in enumerate(inputs):
            if torch.is_tensor(inp):
                tensor_inputs.append(inp)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(inp)

        ctx.save_for_backward(*tensor_inputs)

        return output

    @staticmethod
    def backward(ctx, output_grads: Tuple[Optional[torch.Tensor]]) -> Any:  # type: ignore[override]
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an "
                "`inputs` parameter is passed to .backward(). Please use "
                ".backward() and do not pass its `inputs` argument."
            )

        # Copy the list to avoid modifying the original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with the appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state and mimic the state that was present
        # at this point during the forward. Restore the surrounding state when
        # we're done.
        rng_devices = []
        if checkpoint.state(ctx.module).had_cuda_in_fwd:
            rng_devices = checkpoint.state(ctx.module).fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(checkpoint.state(ctx.module).fwd_cpu_state)
            if checkpoint.state(ctx.module).had_cuda_in_fwd:
                set_device_states(
                    checkpoint.state(ctx.module).fwd_gpu_devices,
                    checkpoint.state(ctx.module).fwd_gpu_states,
                )
            detached_inputs = detach_variable(tuple(inputs))
            with torch.enable_grad(), _no_hook(ctx.module):
                outputs = ctx.module(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        if isinstance(output_grads, torch.Tensor):
            output_grads = (output_grads,)

        # Run backward() only with the outputs that require grad.
        outputs_requires_grad: List[torch.Tensor] = []
        output_grad_tensors: List[torch.Tensor] = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_requires_grad.append(outputs[i])
                assert (
                    output_grads[i] is not None
                ), f"expecting grad for output at index {i}, but got None."
                output_grad_tensors.append(output_grads[i])  # type: ignore[arg-type]
        if len(outputs_requires_grad) == 0:
            raise RuntimeError(
                "none of the outputs has requires_grad=True,"
                " this checkpoint() is not necessary"
            )

        torch.autograd.backward(outputs_requires_grad, output_grad_tensors)

        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in detached_inputs
        )

        # The two leading Nones correspond to the forward arguments ``module``
        # and ``output``, respectively.
        return (None, None) + grads
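

# Call-site sketch (the actual call site is ``forward_hook`` inside
# ``checkpoint`` below): the module's forward runs with grad disabled, and this
# Function is appended to the autograd graph afterwards so that its backward
# recomputes the forward and routes gradients to the original inputs:
#
#     output = _ModuleHookCheckpointFunction.apply(module, output, *inputs)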


class _Holder:
    pass


def _pack(
    x: torch.Tensor,
    *,
    weak_holder_list: List[ReferenceType],
) -> _Holder:
    res = _Holder()
    weak_holder_list.append(ref(res))
    return res


def _unpack(
    holder: _Holder,
    *,
    storage: WeakKeyDictionary,
    weak_holder_list: List[ReferenceType],
    module: nn.Module,
    inputs: Tuple[Any],
) -> torch.Tensor:
    holder_index = 0
    if len(storage) == 0:

        def inner_pack(inner: torch.Tensor):
            nonlocal holder_index
            if weak_holder_list[holder_index]() is None:
                # If the holder went out of scope, the SavedVariable is dead
                # and so the value will never be read from the storage. Skip
                # filling it.
                pass
            else:
                # Use detach here to ensure we don't keep the temporary
                # autograd graph created during the second forward.
                storage[weak_holder_list[holder_index]()] = inner.detach()
            holder_index += 1
            return

        def inner_unpack(holder: _Holder):
            raise RuntimeError(
                "You are calling backwards on a tensor that is never exposed. "
                "Please open an issue."
            )

        with _no_hook(
            module
        ), torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(
            inner_pack, inner_unpack
        ):
            _unused = module(*inputs)

    if holder not in storage:
        raise RuntimeError(
            "Attempt to retrieve a tensor saved by autograd multiple times "
            "without checkpoint recomputation being triggered in between, this "
            "is not currently supported. Please open an issue with details on "
            "your use case so that we can prioritize adding this."
        )

    return storage[holder]
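

# Wiring sketch for the non-reentrant path (a minimal illustration; the actual
# call site is ``forward_pre_hook`` inside ``checkpoint`` below, which uses
# explicit ``__enter__``/``__exit__`` because the context must span the pre-
# and post-forward hooks). ``_pack`` replaces each activation saved by autograd
# with a lightweight ``_Holder`` placeholder; the first ``_unpack`` during
# backward reruns the module's forward once to repopulate ``storage``:
#
#     storage: WeakKeyDictionary = WeakKeyDictionary()
#     weak_holder_list: List[ReferenceType] = []
#     hooks = torch.autograd.graph.saved_tensors_hooks(
#         partial(_pack, weak_holder_list=weak_holder_list),
#         partial(
#             _unpack,
#             storage=storage,
#             weak_holder_list=weak_holder_list,
#             module=module,
#             inputs=inputs,
#         ),
#     )
#     with hooks:
#         output = module(*inputs)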


@contract()
def checkpoint(module: nn.Module, *, use_reentrant: bool = True) -> nn.Module:
    r"""
    This is a composable activation checkpointing API. Unlike functional
    activation checkpointing APIs, this one does not require changing model
    source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs,
    this one does not modify model structure or fully-qualified names either.
    Under the hood, it registers activation checkpointing logic as pre- and
    post-forward hooks. Hence, this API can be easily applied to any model or
    sub-modules in the model.

    Args:
        module (nn.Module): the target model or sub-module to apply activation
            checkpointing to.
        use_reentrant (bool): whether to apply activation checkpointing using
            reentrant autograd.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> model = MyModel()
        >>> checkpoint(model.l1)  # apply activation checkpointing only to l1
        >>> model(torch.zeros(2, 10)).sum().backward()
    """

    def forward_pre_hook(module: nn.Module, inputs: Tuple[Any, ...]) -> None:
        if checkpoint.state(module).enable_hook:
            checkpoint.state(module).orig_grad_enabled = torch.is_grad_enabled()
            if checkpoint.state(module).use_reentrant:
                torch.set_grad_enabled(False)
                checkpoint.state(module).fwd_cpu_state = torch.get_rng_state()
                # Don't eagerly initialize the cuda context by accident.
                # (If the user intends that the context is initialized later,
                # within their run_function, we SHOULD actually stash the cuda
                # state here. Unfortunately, we have no way to anticipate this
                # will happen before we run the function.)
                checkpoint.state(module).had_cuda_in_fwd = False
                if torch.cuda._initialized:
                    checkpoint.state(module).had_cuda_in_fwd = True
                    (
                        checkpoint.state(module).fwd_gpu_devices,
                        checkpoint.state(module).fwd_gpu_states,
                    ) = get_device_states(*inputs)
            else:
                # The Holder object for each of the saved objects is stored
                # directly on the SavedVariable and is cleared when reset_data()
                # is called on it. We MUST make sure that this is the only
                # object holding an owning reference, so that the Tensor stored
                # in storage is deleted as soon as the corresponding
                # SavedVariable data is cleared.
                storage: WeakKeyDictionary = WeakKeyDictionary()
                weak_holder_list: List[ReferenceType] = []
                saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks(
                    partial(_pack, weak_holder_list=weak_holder_list),
                    partial(
                        _unpack,
                        storage=storage,
                        weak_holder_list=weak_holder_list,
                        module=module,
                        inputs=inputs,
                    ),
                )
                saved_tensor_hooks.__enter__()
                checkpoint.state(module).saved_tensor_hooks = saved_tensor_hooks

    def forward_hook(module: nn.Module, inputs: Tuple[Any, ...], output: Any) -> Any:
        if checkpoint.state(module).enable_hook:
            torch.set_grad_enabled(checkpoint.state(module).orig_grad_enabled)
            if checkpoint.state(module).use_reentrant:
                return _ModuleHookCheckpointFunction.apply(module, output, *inputs)
            else:
                checkpoint.state(module).saved_tensor_hooks.__exit__()
                checkpoint.state(module).saved_tensor_hooks = None

        return output

    # These hooks do the following things:
    # 1. detach outputs from the autograd graph to discard activations
    # 2. insert an autograd.Function after the forward pass to recompute
    #    activations during the backward pass.
    checkpoint.state(module).enable_hook = True
    checkpoint.state(module).use_reentrant = use_reentrant
    module.register_forward_pre_hook(forward_pre_hook)
    # Use prepend to make sure we restore the original grad enabled state right
    # after the module forward invocation.
    module.register_forward_hook(forward_hook, prepend=True)
    return module
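

# Example usage (illustrative only; ``MyModel`` is the hypothetical module from
# the docstring above):
#
#     model = MyModel()
#     checkpoint(model.l1)                        # reentrant variant (default)
#     checkpoint(model.l2, use_reentrant=False)   # saved_tensors_hooks variant
#     model(torch.zeros(2, 10)).sum().backward()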