import gc
import torch

from ._utils import _dummy_type
from torch.utils._pytree import tree_flatten as _tree_flatten
from torch.utils._pytree import tree_unflatten as _tree_unflatten

if not hasattr(torch._C, '_CudaStreamBase'):
    # Define dummy base classes
    torch._C.__dict__['_CUDAGraph'] = _dummy_type('_CUDAGraph')
    torch._C.__dict__['_graph_pool_handle'] = _dummy_type('_graph_pool_handle')
    torch._C.__dict__['_cuda_isCurrentStreamCapturing'] = _dummy_type('_cuda_isCurrentStreamCapturing')

from torch._C import _CUDAGraph  # noqa: F401
from torch._C import _graph_pool_handle
from torch._C import _cuda_isCurrentStreamCapturing


def is_current_stream_capturing():
    r"""
    Returns True if CUDA graph capture is underway on the current CUDA stream, False otherwise.

    If a CUDA context does not exist on the current device, returns False without initializing the context.
    """
    return _cuda_isCurrentStreamCapturing()


# Python shim helps Sphinx process docstrings more reliably.
def graph_pool_handle():
    r"""
    Returns an opaque token representing the id of a graph memory pool.
    See :ref:`Graph memory management<graph-memory-management>`.

    .. warning::
        This API is in beta and may change in future releases.
    """
    return _graph_pool_handle()

# Python shim helps Sphinx process docstrings more reliably.
class CUDAGraph(torch._C._CUDAGraph):
    r"""
    Wrapper around a CUDA graph.

    .. warning::
        This API is in beta and may change in future releases.
    """
    def __new__(cls):
        return super(CUDAGraph, cls).__new__(cls)

    def capture_begin(self, pool=None):
        r"""
        Begins capturing CUDA work on the current stream.

        Typically, you shouldn't call ``capture_begin`` yourself.
        Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
        which call ``capture_begin`` internally.

        Arguments:
            pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
                :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
                with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
        """
        # I'm not sure if pybind11 converts a None arg to the default defined on the C++ side,
        # so I'm not taking any chances.
        if pool is None:
            super().capture_begin()
        else:
            super().capture_begin(pool)

    def capture_end(self):
        r"""
        Ends CUDA graph capture on the current stream.
        After ``capture_end``, ``replay`` may be called on this instance.

        Typically, you shouldn't call ``capture_end`` yourself.
        Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
        which call ``capture_end`` internally.
        """
        super().capture_end()

    def replay(self):
        r"""
        Replays the CUDA work captured by this graph.
        """
        super().replay()

    def reset(self):
        r"""
        Deletes the graph currently held by this instance.
        """
        super().reset()

    def pool(self):
        r"""
        Returns an opaque token representing the id of this graph's memory pool.
        This id can optionally be passed to another graph's ``capture_begin``,
        which hints the other graph may share the same memory pool.
        """
        return super().pool()

    def enable_debug_mode(self):
        r"""
        Enables debugging mode for CUDAGraph.debug_dump.
        """
        return super().enable_debug_mode()

    def debug_dump(self, debug_path):
        r"""
        Arguments:
            debug_path (required): Path to dump the graph to.

        Calls a debugging function to dump the graph if debugging is
        enabled via CUDAGraph.enable_debug_mode().
        """
        return super().debug_dump(debug_path)
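

# A minimal usage sketch, not part of the module's public API: capture and replay of a
# simple static-shape computation through ``torch.cuda.graph`` (defined below).
# The helper name and the toy computation are illustrative placeholders, and a CUDA
# device is assumed to be available when the function is called.
def _example_capture_replay():
    static_input = torch.randn(8, 16, device="cuda")

    # Warm up on a side stream so lazy-initialization work does not end up in the capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_output = static_input.relu()
    torch.cuda.current_stream().wait_stream(s)

    # Tensors created during capture are allocated from the graph's private pool
    # and act as static outputs.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_output = static_input.relu()

    # To run on new data, copy it into the static input tensor and replay.
    static_input.copy_(torch.randn(8, 16, device="cuda"))
    g.replay()
    return static_output
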
class graph:
    r"""
    Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph`
    object for later replay.

    See :ref:`CUDA Graphs <cuda-graph-semantics>` for a general introduction,
    detailed use, and constraints.

    Arguments:
        cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture.
        pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or
            :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) hinting this graph's capture
            may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
        stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
            If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.

    .. note::
        For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
        used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.

    .. warning::
        This API is in beta and may change in future releases.
    """
    default_capture_stream = None

    def __init__(self,
                 cuda_graph,
                 pool=None,
                 stream=None):
        # Lazy-init of default_capture_stream helps avoid circular-import errors.
        # Not thread safe, but graphs already have the general (explicitly documented)
        # restriction that only one capture may be underway at a time in the process.
        if self.__class__.default_capture_stream is None:
            self.__class__.default_capture_stream = torch.cuda.Stream()

        self.pool = () if pool is None else (pool,)
        self.capture_stream = stream if stream is not None else self.__class__.default_capture_stream
        assert self.capture_stream is not None
        self.stream_ctx = torch.cuda.stream(self.capture_stream)
        self.cuda_graph = cuda_graph

    def __enter__(self):
        # Free as much memory as we can for the graph
        torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()

        # Stackoverflow seems comfortable with this pattern
        # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
        self.stream_ctx.__enter__()

        self.cuda_graph.capture_begin(*self.pool)

    def __exit__(self, exc_type, exc_value, traceback):
        self.cuda_graph.capture_end()
        self.stream_ctx.__exit__(exc_type, exc_value, traceback)
        # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
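

# A minimal sketch, not part of the public API, of the ``pool`` and ``stream`` arguments
# documented above: two captures share one memory pool and use the same explicit capture
# stream, as the note recommends. ``static_a`` and ``static_b`` are assumed to be CUDA
# tensors used as static inputs; warmup is omitted for brevity (see the sketch after
# :class:`CUDAGraph`).
def _example_shared_pool_capture(static_a, static_b):
    capture_stream = torch.cuda.Stream()
    g1 = CUDAGraph()
    g2 = CUDAGraph()

    with graph(g1, stream=capture_stream):
        out_a = static_a * 2.0

    # Passing ``pool=g1.pool()`` hints that g2's capture may share g1's memory pool.
    with graph(g2, pool=g1.pool(), stream=capture_stream):
        out_b = static_b + out_a

    return g1, g2, out_a, out_b
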
def make_graphed_callables(callables, sample_args, num_warmup_iters=3, allow_unused_input=False):
    r"""
    Accepts callables (functions or :class:`nn.Module<torch.nn.Module>`\ s)
    and returns graphed versions.

    Each graphed callable's forward pass runs its source callable's
    forward CUDA work as a CUDA graph inside a single autograd node.

    The graphed callable's forward pass also appends
    a backward node to the autograd graph. During backward, this node runs the
    callable's backward work as a CUDA graph.

    Therefore, each graphed callable should be a drop-in replacement for its source callable
    in an autograd-enabled training loop.

    See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.

    If you pass a tuple of several callables, their captures will use the same memory pool.
    See :ref:`Graph memory management<graph-memory-management>` for when this is appropriate.

    Arguments:
        callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
            See :ref:`Graph memory management<graph-memory-management>` for when passing a tuple of callables
            is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order
            they'll run in the live workload.
        sample_args (tuple of Tensors, or tuple of tuples of Tensors): Sample args for each callable.
            If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
            If a tuple of callables was passed, ``sample_args`` must be a tuple of tuples of argument Tensors.
        num_warmup_iters (int): The number of warmup iterations. Currently, ``DistributedDataParallel`` needs
            11 iterations for warm up. Default: ``3``.
        allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs
            (and therefore their grad is always zero) is an error. Defaults to False.

    .. note::
        The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
        that's expected for the corresponding real input in the training loop.

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        ``sample_args`` for each callable must contain only Tensors. Other types are not allowed.

    .. warning::
        Returned callables do not support higher order differentiation (e.g., double backward).

    .. warning::
        In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
        may be trainable. Buffers must have ``requires_grad=False``.

    .. warning::
        After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
        you may not add or remove any of that Module's parameters or buffers.

    .. warning::
        :class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks
        registered on them at the time they are passed. However, registering hooks on modules *after* passing them
        through :func:`~torch.cuda.make_graphed_callables` is allowed.

    .. warning::
        When running a graphed callable, you must pass its arguments in the same order and format
        they appeared in that callable's ``sample_args``.

    .. warning::
        Automatic mixed precision is supported in :func:`~torch.cuda.make_graphed_callables` only with
        caching disabled: the ``torch.cuda.amp.autocast()`` context manager must have ``cache_enabled=False``.
    """
    if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
        raise RuntimeError("make_graphed_callables does not support autocast caching. Please set `cache_enabled=False`.")

    just_one_callable = False

    if not isinstance(callables, tuple):
        just_one_callable = True
        callables = (callables,)
        sample_args = (sample_args,)

    flatten_sample_args = []

    for c, args in zip(callables, sample_args):
        if isinstance(c, torch.nn.Module):
            assert len(c._backward_hooks) == 0 and len(c._forward_hooks) == 0 and len(c._forward_pre_hooks) == 0, \
                "Modules must not have hooks registered at the time they are passed. However, registering hooks " + \
                "on modules after passing them through make_graphed_callables is allowed."
            assert all(b.requires_grad is False for b in c.buffers()), "In any :class:`~torch.nn.Module` passed to " + \
                ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have " + \
                "``requires_grad=False``."
        flatten_arg, _ = _tree_flatten(args)
        flatten_sample_args.append(tuple(flatten_arg))
        assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), "In the beta API, sample_args " + \
            "for each callable must contain only Tensors. Other types are not allowed."

    # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
    # passes to forward (ie, its sample_args) AND the module's parameter attributes.
    per_callable_len_user_args = [len(args) for args in flatten_sample_args]
    per_callable_module_params = [tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
                                  for c in callables]
    per_callable_static_input_surfaces = [flatten_sample_args[i] + per_callable_module_params[i]
                                          for i in range(len(callables))]

    fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
    bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]

    mempool = graph_pool_handle()

    # Warmup
    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
    # from ending up in any captures.
    torch.cuda.synchronize()
    with torch.cuda.stream(torch.cuda.Stream()):
        for func, args, static_input_surface in zip(callables,
                                                     sample_args,
                                                     per_callable_static_input_surfaces):
            for _ in range(num_warmup_iters):
                outputs, _ = _tree_flatten(func(*args))
                grad_inputs = torch.autograd.grad(outputs=tuple(o for o in outputs if o.requires_grad),
                                                  inputs=tuple(i for i in static_input_surface if i.requires_grad),
                                                  grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
                                                  only_inputs=True,
                                                  allow_unused=allow_unused_input)
            del outputs, grad_inputs
    torch.cuda.synchronize()

    # All captures here share a mempool. To avoid replays corrupting each other's memory,
    # the safest approach is to capture all passes in the same order they'll run:
    # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.

    # Capture forward graphs
    per_callable_static_outputs = []
    per_callable_output_unflatten_spec = []
    for func, args, fwd_graph in zip(callables,
                                     sample_args,
                                     fwd_graphs):
        with torch.cuda.graph(fwd_graph, pool=mempool):
            outputs = func(*args)

        flatten_outputs, spec = _tree_flatten(outputs)
        per_callable_static_outputs.append(tuple(flatten_outputs))
        per_callable_output_unflatten_spec.append(spec)

    # Capture backward graphs in reverse order
    per_callable_static_grad_outputs = []
    per_callable_static_grad_inputs = []
    for static_input_surface, static_outputs, bwd_graph, module_params in \
            zip(reversed(per_callable_static_input_surfaces),
                reversed(per_callable_static_outputs),
                reversed(bwd_graphs),
                reversed(per_callable_module_params)):

        # For now, assumes all static_outputs require grad
        # assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad."
        static_grad_outputs = tuple(torch.empty_like(o) if o.requires_grad else None for o in static_outputs)

        with torch.cuda.graph(bwd_graph, pool=mempool):
            grad_inputs = torch.autograd.grad(outputs=tuple(o for o in static_outputs if o.requires_grad),
                                              inputs=tuple(i for i in static_input_surface if i.requires_grad),
                                              grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
                                              only_inputs=True,
                                              allow_unused=allow_unused_input)

        # Constructs a tuple suitable for returning from Graphed.backward:
        # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
        # I couldn't think of a slick one-liner for this pattern.
        static_grad_inputs = []
        grad_idx = 0
        for arg in static_input_surface:
            if arg.requires_grad:
                static_grad_inputs.append(grad_inputs[grad_idx])
                grad_idx += 1
            else:
                static_grad_inputs.append(None)  # type: ignore[arg-type]
        static_grad_inputs = tuple(static_grad_inputs)  # type: ignore[assignment]

        per_callable_static_grad_outputs.append(static_grad_outputs)
        per_callable_static_grad_inputs.append(static_grad_inputs)

    # Reverses the most recent two lists
    per_callable_static_grad_outputs = list(reversed(per_callable_static_grad_outputs))
    per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs))

    # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.

    def make_graphed_autograd_function(fwd_graph,
                                       bwd_graph,
                                       module_params,
                                       len_user_args,
                                       output_unflatten_spec,
                                       static_input_surface,
                                       static_outputs,
                                       static_grad_outputs,
                                       static_grad_inputs):
        class Graphed(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *inputs):
                # At this stage, only the user args may (potentially) be new tensors.
                for i in range(len_user_args):
                    if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
                        static_input_surface[i].copy_(inputs[i])
                fwd_graph.replay()
                assert isinstance(static_outputs, tuple)
                return tuple(o.detach() for o in static_outputs)

            @staticmethod
            @torch.autograd.function.once_differentiable
            def backward(ctx, *grads):
                assert len(grads) == len(static_grad_outputs)
                for g, grad in zip(static_grad_outputs, grads):
                    if g is not None:
                        # don't copy if autograd gods have been kind and the
                        # incoming grad is already in the right place
                        if g.data_ptr() != grad.data_ptr():
                            g.copy_(grad)
                bwd_graph.replay()

                # Input args that didn't require grad expect a None gradient.
                assert isinstance(static_grad_inputs, tuple)
                return tuple(b.detach() if b is not None else b for b in static_grad_inputs)

        def functionalized(*user_args):
            # Runs the autograd function with inputs == all inputs to the graph that might require grad
            # (explicit user args + module parameters)
            # Assumes module params didn't change since capture.
            flatten_user_args, _ = _tree_flatten(user_args)
            out = Graphed.apply(*(tuple(flatten_user_args) + module_params))
            return _tree_unflatten(out, output_unflatten_spec)

        return functionalized

    # Put together the final graphed callables
    ret = []
    for i, func in enumerate(callables):
        graphed = make_graphed_autograd_function(fwd_graphs[i],
                                                 bwd_graphs[i],
                                                 per_callable_module_params[i],
                                                 per_callable_len_user_args[i],
                                                 per_callable_output_unflatten_spec[i],
                                                 per_callable_static_input_surfaces[i],
                                                 per_callable_static_outputs[i],
                                                 per_callable_static_grad_outputs[i],
                                                 per_callable_static_grad_inputs[i])

        if isinstance(func, torch.nn.Module):
            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
                def new_fwd(*user_args):
                    # If the module's training-or-eval state matches what we graphed,
                    # run the graph, otherwise run the original forward method
                    if func.training == graph_training_state:
                        return graphed(*user_args)
                    else:
                        return orig_fwd(*user_args)
                return new_fwd
            func.forward = make_graphed_forward(func, func.training, graphed, func.forward)  # type: ignore[assignment]
            ret.append(func)
        else:
            ret.append(graphed)

    if just_one_callable:
        return ret[0]

    return tuple(ret)
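

# A minimal sketch, not part of the module, of partial-network capture with
# ``make_graphed_callables``: two sections of a model are graphed and then used as
# drop-in replacements in an ordinary training loop. The layer sizes, optimizer, loss,
# and iteration count are illustrative placeholders; a CUDA device is assumed to be
# available when the function is called.
def _example_make_graphed_callables():
    section1 = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU()).cuda()
    section2 = torch.nn.Sequential(torch.nn.Linear(32, 8), torch.nn.ReLU()).cuda()
    loss_fn = torch.nn.MSELoss()
    opt = torch.optim.SGD(list(section1.parameters()) + list(section2.parameters()), lr=0.1)

    # Sample args must match the shapes and requires_grad state of the real inputs:
    # the raw data does not require grad, but section2's input (section1's output) does.
    sample_x = torch.randn(4, 16, device="cuda")
    sample_h = torch.randn(4, 32, device="cuda", requires_grad=True)
    section1, section2 = make_graphed_callables((section1, section2),
                                                ((sample_x,), (sample_h,)))

    for _ in range(3):
        x = torch.randn(4, 16, device="cuda")
        target = torch.randn(4, 8, device="cuda")
        opt.zero_grad(set_to_none=True)
        # The graphed sections replay their captured forward (and, during .backward(),
        # their captured backward) CUDA work.
        y = section2(section1(x))
        loss_fn(y, target).backward()
        opt.step()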