# checkpoint.py

import torch
import warnings
import weakref
from typing import Any, Iterable, List, Tuple

__all__ = [
    "checkpoint", "checkpoint_sequential", "CheckpointFunction",
    "check_backward_validity", "detach_variable", "get_device_states",
    "set_device_states",
]


def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got unsupported input type: ", type(inputs).__name__)


def check_backward_validity(inputs: Iterable[Any]) -> None:
    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")


# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases. Compromise: Stash the RNG state for
# the device of all Tensor args.
#
# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_gpu_devices = list({arg.get_device() for arg in args
                            if isinstance(arg, torch.Tensor) and arg.is_cuda})

    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())

    return fwd_gpu_devices, fwd_gpu_states


def set_device_states(devices, states) -> None:
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)
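
# Illustrative sketch (assumes at least one CUDA device is available): the two
# helpers above form a stash/restore pair, so random ops replayed during
# recomputation see the same per-device CUDA RNG state as the original forward.
#
#     if torch.cuda.is_available():
#         t = torch.randn(4, device="cuda")
#         devices, states = get_device_states(t)   # stash RNG state of t's device
#         before = torch.rand(3, device="cuda")
#         set_device_states(devices, states)       # restore the stashed state
#         after = torch.rand(3, device="cuda")
#         assert torch.equal(before, after)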


def _get_autocast_kwargs():
    gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                           "dtype": torch.get_autocast_gpu_dtype(),
                           "cache_enabled": torch.is_autocast_cache_enabled()}

    cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
                           "dtype": torch.get_autocast_cpu_dtype(),
                           "cache_enabled": torch.is_autocast_cache_enabled()}

    return gpu_autocast_kwargs, cpu_autocast_kwargs
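
# Illustrative sketch (not part of the original module): the kwargs captured above
# are replayed during recomputation so the second forward runs under the same
# autocast settings as the first, e.g.:
#
#     gpu_kwargs, cpu_kwargs = _get_autocast_kwargs()
#     ...  # later, while recomputing:
#     with torch.cuda.amp.autocast(**gpu_kwargs), torch.cpu.amp.autocast(**cpu_kwargs):
#         pass  # re-run the checkpointed function here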


class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument.")
        # Copy the list to avoid modifying the original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with the appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(tuple(inputs))
            with torch.enable_grad(), \
                 torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
                 torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only the tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of the outputs has requires_grad=True,"
                " this checkpoint() is not necessary")
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)

        return (None, None) + grads
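
# Illustrative sketch (not part of the original module): checkpoint() below is the
# public entry point, but the reentrant path boils down to a direct apply() call.
# The first two arguments map to run_function and preserve_rng_state; the rest are
# the checkpointed function's inputs.
#
#     x = torch.randn(2, 3, requires_grad=True)
#     y = CheckpointFunction.apply(torch.sin, True, x)
#     y.sum().backward()               # torch.sin(x) is recomputed inside backward()
#     assert torch.allclose(x.grad, torch.cos(x))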


def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
    r"""Checkpoint a model or part of the model

    Checkpointing works by trading compute for memory. Rather than storing all
    intermediate activations of the entire computation graph for computing
    backward, the checkpointed part does **not** save intermediate activations,
    and instead recomputes them in the backward pass. It can be applied on any
    part of a model.

    Specifically, in the forward pass, :attr:`function` will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. Instead, the forward pass saves the inputs tuple and the
    :attr:`function` parameter. In the backwards pass, the saved inputs and
    :attr:`function` are retrieved, and the forward pass is computed on
    :attr:`function` again, now tracking the intermediate activations, and then
    the gradients are calculated using these activation values.

    The output of :attr:`function` can contain non-Tensor values and gradient
    recording is only performed for the Tensor values. Note that if the output
    consists of nested structures (ex: custom objects, lists, dicts, etc.)
    consisting of Tensors, these Tensors nested in custom structures will not
    be considered as part of autograd.

    .. warning::
        If the :attr:`function` invocation during backward does anything
        different from the one during forward, e.g., due to some global
        variable, the checkpointed version won't be equivalent, and
        unfortunately this can't be detected.

    .. warning::
        If ``use_reentrant=True`` is specified, then if the checkpointed segment
        contains tensors detached from the computational graph by ``detach()`` or
        ``torch.no_grad()``, the backward pass will raise an error. This is
        because ``checkpoint`` makes all the outputs require gradients, which
        causes issues when a tensor is defined to have no gradient in the model.
        To circumvent this, detach the tensors outside of the ``checkpoint``
        function. Note that the checkpointed segment can contain tensors
        detached from the computational graph if ``use_reentrant=False`` is
        specified.

    .. warning::
        If ``use_reentrant=True`` is specified, at least one of the inputs needs
        to have :code:`requires_grad=True` if grads are needed for model inputs,
        otherwise the checkpointed part of the model won't have gradients. At
        least one of the outputs needs to have :code:`requires_grad=True` as
        well. Note that this does not apply if ``use_reentrant=False`` is
        specified.

    .. warning::
        If ``use_reentrant=True`` is specified, checkpointing currently only
        supports :func:`torch.autograd.backward` and only if its ``inputs``
        argument is not passed. :func:`torch.autograd.grad`
        is not supported. If ``use_reentrant=False`` is specified, checkpointing
        will work with :func:`torch.autograd.grad`.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if the user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): if ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``
        use_reentrant(bool, optional): Use checkpointing
            implementation that requires re-entrant autograd.
            If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
            implementation that does not require re-entrant autograd. This
            allows ``checkpoint`` to support additional functionality, such as
            working as expected with ``torch.autograd.grad`` and support for
            keyword arguments passed into the checkpointed function. Note that
            future versions of PyTorch will default to ``use_reentrant=False``.
            Default: ``True``
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs and use_reentrant:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    if use_reentrant:
        return CheckpointFunction.apply(function, preserve, *args)
    else:
        return _checkpoint_without_reentrant(
            function,
            preserve,
            *args,
            **kwargs,
        )
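
# Minimal usage sketch (hypothetical two-layer model, not part of this module):
# only the inputs of the wrapped segment are kept alive; its activations are
# recomputed when backward() reaches the checkpointed node.
#
#     import torch.nn as nn
#     net = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
#     x = torch.randn(8, 16, requires_grad=True)
#     out = checkpoint(net, x, use_reentrant=True)   # forward runs under no_grad
#     out.sum().backward()                           # net(x) is recomputed here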


def checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwargs):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model in various segments
    and checkpoint each segment. All segments except the last will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. The inputs of each checkpointed segment will be saved for
    re-running the segment in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its ``inputs`` argument is not passed. :func:`torch.autograd.grad`
        is not supported.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    .. warning::
        Since PyTorch 1.4, ``checkpoint_sequential`` accepts only a single Tensor
        as the input and produces a single Tensor as the intermediate outputs,
        just like :class:`torch.nn.Sequential`.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        input: A Tensor that is input to :attr:`functions`
        preserve_rng_state(bool, optional): if ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``
        use_reentrant(bool, optional): Use checkpointing
            implementation that requires re-entrant autograd.
            If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
            implementation that does not require re-entrant autograd. This
            allows ``checkpoint`` to support additional functionality, such as
            working as expected with ``torch.autograd.grad`` and support for
            keyword arguments passed into the checkpointed function.
            Default: ``True``

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`input`

    Example:
        >>> # xdoctest: +SKIP("stub")
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input)
            return input
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # the last chunk has to be non-volatile
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        input = checkpoint(
            run_function(start, end, functions),
            input,
            use_reentrant=use_reentrant,
            preserve_rng_state=preserve
        )
    return run_function(end + 1, len(functions) - 1, functions)(input)
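
# Minimal usage sketch (hypothetical model, not part of this module): the model is
# split into `segments` chunks and every chunk except the last is checkpointed.
#
#     import torch.nn as nn
#     model = nn.Sequential(*[nn.Linear(32, 32) for _ in range(6)])
#     x = torch.randn(4, 32, requires_grad=True)
#     out = checkpoint_sequential(model, segments=3, input=x)
#     out.sum().backward()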


def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args, **kwargs):
    """Checkpointing without re-entrant autograd

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if the user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): if ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``
        *args: Arguments to pass in to the given ``function``.
        **kwargs: Keyword arguments to pass into the given ``function``.
    """
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_cuda_in_fwd = False
        if torch.cuda._initialized:
            had_cuda_in_fwd = True
            fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)

    # Custom class to be able to take weak references
    class Holder():
        pass

    # The Holder object for each of the saved objects is saved directly on the
    # SavedVariable and is cleared when reset_data() is called on it. We MUST make
    # sure that this is the only object having an owning reference to ensure that
    # the Tensor stored in storage is deleted as soon as the corresponding SavedVariable
    # data is cleared.
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    weak_holder_list = []

    def pack(x):
        # TODO(varal7): Instead of returning an abstract object, we can return tensor
        # metadata (such as size, device, ...) to catch certain cases of
        # non-deterministic behavior of the forward
        res = Holder()
        weak_holder_list.append(weakref.ref(res))
        return res

    def unpack(x):
        unpack_counter = 0
        if len(storage) == 0:

            def inner_pack(inner):
                nonlocal unpack_counter
                unpack_counter += 1
                # If the holder went out of scope, the SavedVariable is dead and so
                # the value will never be read from the storage. Skip filling it.
                if weak_holder_list[unpack_counter - 1]() is None:
                    return

                # Use detach here to ensure we don't keep the temporary autograd
                # graph created during the second forward
                storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
                return

            def inner_unpack(packed):
                raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")

            # Stash the surrounding rng state, and mimic the state that was
            # present at this time during forward. Restore the surrounding state
            # when we're done.
            rng_devices = []
            if preserve_rng_state and had_cuda_in_fwd:
                rng_devices = fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
                if preserve_rng_state:
                    torch.set_rng_state(fwd_cpu_state)
                    if had_cuda_in_fwd:
                        set_device_states(fwd_gpu_devices, fwd_gpu_states)

                with torch.enable_grad(), \
                     torch.cuda.amp.autocast(**gpu_autocast_kwargs), \
                     torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
                    _unused = function(*args, **kwargs)

        if x not in storage:
            raise RuntimeError(
                "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                " recomputation being triggered in between, this is not currently supported. Please"
                " open an issue with details on your use case so that we can prioritize adding this."
            )

        return storage[x]

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = function(*args, **kwargs)
        if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:
            # Cuda was not initialized before running the forward, so we didn't
            # stash the CUDA state.
            raise RuntimeError(
                "PyTorch's CUDA state was initialized in the forward pass "
                "of a Checkpoint, which is not allowed. Please open an issue "
                "if you need this feature.")

    return output
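
# Minimal usage sketch (hypothetical function, not part of this module): the
# non-reentrant path accepts keyword arguments for the checkpointed function and
# also works with torch.autograd.grad.
#
#     def scale_and_sin(x, scale=1.0):
#         return torch.sin(scale * x)
#
#     x = torch.randn(5, requires_grad=True)
#     out = checkpoint(scale_and_sin, x, use_reentrant=False, scale=2.0)
#     (grad,) = torch.autograd.grad(out.sum(), x)   # supported when use_reentrant=False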