checkpoint.py

# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
  7. """Checkpointing with preceding recomputation.
  8. PyTorch already provides the official checkpointing utilities in
  9. :mod:`torch.utils.checkpoint`. The official checkpointing combines
  10. recomputation and recursive backpropagation into one autograd function named
  11. ``CheckpointFunction``. Hence, the recomputation can be started only when the
  12. gradients arrive to the function. In Pipe, the recomputation needs to precede
  13. the gradient arrival to minimize the GPU idle time.
  14. We solve this problem by introducing separate autograd functions named
  15. :class:`Recompute` and :class:`Checkpoint`. Each function represents
  16. recomputation and recursive backpropagation, respectively. We can manipulate
  17. the control flow in aspect of both the autograd engine and CUDA with a pair of
  18. the functions.
  19. Specifically, we place CUDA stream synchronization between :class:`Recompute`
  20. and :class:`Checkpoint` to delay only :class:`Checkpoint` until the gradient is
  21. copied entirely.
  22. """
from collections import deque
from contextlib import contextmanager
import threading
from typing import (
    Any,
    Deque,
    Generator,
    List,
    Optional,
    Protocol,
    Union,
    Sequence,
    Tuple
)

import torch
from torch import Tensor
import torch.autograd

from .dependency import fork, join
from .microbatch import Batch
from .phony import get_phony

__all__ = ["Function", "checkpoint", "Checkpointing", "ThreadLocal", "enable_checkpointing",
           "enable_recomputing", "is_checkpointing", "is_recomputing", "Context", "save_rng_states",
           "restore_rng_states", "Checkpoint", "Recompute"]


Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

# Types for shared memory between Checkpoint and Recompute.
Recomputed = Tuple[TensorOrTensors, Tensors]  # (output, input_leaf)
RNGStates = Tuple[Tensor, Optional[Tensor]]  # (cpu_rng_state, gpu_rng_state)


# Protocol with __call__ instead of Callable can be used as an attribute type.
# See: https://github.com/python/mypy/issues/708#issuecomment-561735949
class Function(Protocol):
    def __call__(self, input: TensorOrTensors) -> TensorOrTensors:
        ...


def checkpoint(function: Function, input):
    """Makes a checkpoint with a simple interface like
    :func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
    :class:`Checkpoint` and :class:`Recompute` without boilerplate.
    """
    batch = Batch(input)

    chk = Checkpointing(function, batch)
    batch = chk.checkpoint()
    chk.recompute(batch)

    return batch.values
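
# A minimal usage sketch of this test helper, assuming a toy module and a
# single CPU input tensor (`net` and `x` are illustrative names only, and the
# returned value mirrors the structure of the input):
#
#     net = torch.nn.Linear(4, 4)
#     x = torch.randn(2, 4, requires_grad=True)
#     out = checkpoint(net, x)   # forward runs under no_grad, so intermediate
#                                # activations inside `net` are not kept
#     out.sum().backward()       # Recompute rebuilds the graph, then Checkpoint
#                                # backpropagates through it; x.grad is populated

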
class Checkpointing:
    """Generates a pair of :class:`Checkpoint` and :class:`Recompute`."""

    def __init__(self, function: Function, batch: Batch) -> None:
        self.function = function
        self.batch = batch

        # Shared memory between Checkpoint and Recompute. A 1-length deque is
        # used for mutability and length limitation.
        self.recomputed: Deque[Recomputed] = deque(maxlen=1)
        self.rng_states: Deque[RNGStates] = deque(maxlen=1)

    def checkpoint(self) -> Batch:
        """Returns a batch applied by :class:`Checkpoint`."""
        input_atomic = self.batch.atomic
        inputs = tuple(self.batch)

        # Use a phony which requires grad to ensure that Checkpoint can be
        # tracked by the autograd engine even when none of the input tensors
        # require grad.
        phony = get_phony(self.batch.get_device(), requires_grad=True)

        output = Checkpoint.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *inputs)

        # Gradients are only supported for float Tensors.
        if isinstance(output, tuple):
            output = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in output])

        return Batch(output)

    def recompute(self, batch: Batch) -> None:
        """Applies :class:`Recompute` to the batch in place."""
        input_atomic = self.batch.atomic
        inputs = tuple(self.batch)

        # Use a tensor in the batch to tie together the fork-join pair.
        tensor_idx = batch.find_tensor_idx()

        # batch[tensor_idx] always requires grad, because it has been passed
        # through checkpoint with a phony requiring grad.
        batch[tensor_idx], phony = fork(batch[tensor_idx])
        phony = Recompute.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *inputs)
        batch[tensor_idx] = join(batch[tensor_idx], phony)
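
# Rough shape of the graph that checkpoint() followed by recompute() builds
# around the tensor at tensor_idx (a sketch of the code above, with x denoting
# the Checkpoint output):
#
#     input --Checkpoint--> x --fork--> x' -------------------join--> x''
#                                  \--phony--> Recompute --phony'--/
#
# Because join() makes x'' depend on Recompute's phony output, a gradient
# arriving at x'' forces the autograd engine to finish Recompute.backward
# (which re-runs `function` and fills `recomputed`) before fork and
# Checkpoint.backward can run. This is how recomputation gets to start ahead
# of the actual backpropagation.

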
class ThreadLocal(threading.local):
    def __init__(self) -> None:
        self.is_checkpointing = False
        self.is_recomputing = False


thread_local = ThreadLocal()


@contextmanager
def enable_checkpointing() -> Generator[None, None, None]:
    """Makes :func:`is_checkpointing` return :data:`True` within a context."""
    orig = thread_local.is_checkpointing
    thread_local.is_checkpointing = True
    try:
        yield
    finally:
        thread_local.is_checkpointing = orig


@contextmanager
def enable_recomputing() -> Generator[None, None, None]:
    """Makes :func:`is_recomputing` return :data:`True` within a context."""
    orig = thread_local.is_recomputing
    thread_local.is_recomputing = True
    try:
        yield
    finally:
        thread_local.is_recomputing = orig


def is_checkpointing() -> bool:
    """Whether the current forward propagation is under checkpointing.

    Returns:
        bool: :data:`True` if it's under checkpointing.

    """
    return thread_local.is_checkpointing


def is_recomputing() -> bool:
    """Whether the current forward propagation is under checkpoint
    recomputation. Use this to prevent duplicated side-effects at forward
    propagation::

        class Counter(nn.Module):
            def __init__(self):
                super().__init__()
                self.counter = 0

            def forward(self, input):
                if not is_recomputing():
                    self.counter += 1
                return input

    Returns:
        bool: :data:`True` if it's under checkpoint recomputation.

    .. seealso:: :ref:`Detecting Recomputation`

    """
    return thread_local.is_recomputing


class Context:
    """The common interface between the :class:`Checkpoint` and
    :class:`Recompute` context.
    """

    recomputed: Deque[Recomputed]
    rng_states: Deque[RNGStates]
    function: Function
    input_atomic: bool
    inputs: Sequence[Any]

    saved_tensors: Tuple[Tensor, ...]

    def save_for_backward(self, *tensors: Tensor) -> None:  # pragma: no cover
        pass


def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None:
    """:meth:`Checkpoint.forward` captures the current CPU and GPU random
    number generator states so that :meth:`Recompute.backward` can reuse them.

    .. seealso:: :ref:`Referential Transparency`

    """
    cpu_rng_state = torch.get_rng_state()

    gpu_rng_state: Optional[Tensor]
    if device.type == "cuda":
        gpu_rng_state = torch.cuda.get_rng_state(device)
    else:
        gpu_rng_state = None

    rng_states.append((cpu_rng_state, gpu_rng_state))


@contextmanager
def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]:
    """:meth:`Recompute.backward` restores the random number generator states
    captured by :func:`save_rng_states` within its context.

    .. seealso:: :ref:`Referential Transparency`

    """
    cpu_rng_state, gpu_rng_state = rng_states.pop()

    gpu_devices: List[torch.device] = []
    if device.type == "cuda":
        gpu_devices.append(device)

    with torch.random.fork_rng(gpu_devices):
        torch.set_rng_state(cpu_rng_state)
        if gpu_rng_state is not None:
            torch.cuda.set_rng_state(gpu_rng_state, device)
        yield
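
# How the two RNG helpers pair up (a simplified sketch; the real calls live in
# Checkpoint.forward and Recompute.backward below, sharing a 1-length deque):
#
#     rng_states: Deque[RNGStates] = deque(maxlen=1)
#     save_rng_states(x.device, rng_states)           # snapshot CPU/CUDA RNG
#     out = function(x)                               # dropout etc. draws noise
#     ...
#     with restore_rng_states(x.device, rng_states):  # later, during backward
#         out_again = function(x)                     # sees the same noise

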
class Checkpoint(torch.autograd.Function):
    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context,
        phony: Tensor,
        recomputed: Deque[Recomputed],
        rng_states: Deque[RNGStates],
        function: Function,
        input_atomic: bool,
        *inputs,
    ):
        ctx.recomputed = recomputed
        ctx.rng_states = rng_states

        save_rng_states(phony.device, ctx.rng_states)

        ctx.function = function
        ctx.input_atomic = input_atomic
        if input_atomic:
            tensors = [inputs[0]]
        else:
            tensors = []
            for input in inputs:
                if torch.is_tensor(input):
                    tensors.append(input)

        ctx.save_for_backward(*tensors)

        with torch.no_grad(), enable_checkpointing():
            if input_atomic:
                assert len(inputs) == 1
                output = function(inputs[0])
            else:
                output = function(*inputs)
        return output

    @staticmethod
    def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]:  # pragma: no cover
        output, input_leaf = ctx.recomputed.pop()

        if isinstance(output, tuple):
            outputs = output
        else:
            outputs = (output,)
        if any(torch.is_tensor(y) and y.requires_grad for y in outputs):
            tensors = tuple([x for x in outputs if torch.is_tensor(x) and x.requires_grad])
            torch.autograd.backward(tensors, grad_output)

        # The first five Nones correspond to the non-differentiable forward
        # arguments: phony, recomputed, rng_states, function and input_atomic.
        grad_input: List[Optional[Tensor]] = [None, None, None, None, None]
        grad_input.extend(x.grad if torch.is_tensor(x) else None for x in input_leaf)
        return tuple(grad_input)
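
# Handshake between the two backward passes (a summary of the code above and
# below, not new behaviour): Recompute.backward re-runs `function` on detached
# leaf copies of the inputs and appends (output, inputs_leaf) to the shared
# `recomputed` deque; Checkpoint.backward then pops that pair, backpropagates
# through the freshly recomputed graph, and returns the gradients collected on
# those leaf copies as the gradients of its own inputs.

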
class Recompute(torch.autograd.Function):
    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context,
        phony: Tensor,
        recomputed: Deque[Recomputed],
        rng_states: Deque[RNGStates],
        function: Function,
        input_atomic: bool,
        *inputs,
    ) -> Tensor:
        ctx.recomputed = recomputed
        ctx.rng_states = rng_states

        ctx.function = function
        ctx.input_atomic = input_atomic
        ctx.inputs = inputs
        if input_atomic:
            tensors = [inputs[0]]
        else:
            tensors = []
            for input in inputs:
                if torch.is_tensor(input):
                    tensors.append(input)
        ctx.save_for_backward(*tensors)

        return phony

    @staticmethod
    def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]:  # pragma: no cover
        inputs = ctx.inputs
        inputs_leaf = tuple(x.detach().requires_grad_(x.requires_grad) if torch.is_tensor(x) else x for x in inputs)

        # Get the device for the inputs from a tensor.
        device = None
        for input in inputs:
            if torch.is_tensor(input):
                device = input.device
                break

        if device is None:
            raise RuntimeError(f'No tensors found in {inputs}')

        with restore_rng_states(device, ctx.rng_states):
            with torch.enable_grad(), enable_recomputing():
                if ctx.input_atomic:
                    assert len(inputs_leaf) == 1
                    output = ctx.function(inputs_leaf[0])
                else:
                    output = ctx.function(*inputs_leaf)

        ctx.recomputed.append((output, inputs_leaf))

        # Recompute itself produces no gradients: one None per forward argument
        # (phony, recomputed, rng_states, function, input_atomic, *inputs).
        grad_input: List[None] = [None, None, None, None, None]
        grad_input.extend(None for _ in ctx.inputs)
        return tuple(grad_input)