cudagraphs.py

import logging
import operator
from collections import defaultdict
from typing import Set

import torch
from torch.fx import GraphModule
from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
from torch.multiprocessing.reductions import StorageWeakRef
from torch.nn import Module
from torch.utils._pytree import tree_map

from .common import aot_autograd
from .registry import register_backend

log = logging.getLogger(__name__)


def cloner(t):
    if isinstance(t, torch.Tensor):
        return t.clone()
    else:
        return t


class CudaGraphModule(Module):
    gm: GraphModule
    mutated_inputs: Set[int]

    def __init__(self, gm, mutated_inputs):
        super().__init__()
        self.gm = gm
        self.mutated_inputs = mutated_inputs

    warmed_up = False

    # these are all None or all filled
    graph = None
    static_inputs = None
    static_outputs = None

    # NB: we override __call__ as we don't need any nn.Module machinery
    # and to reduce overhead
    def __call__(self, *args):
        # TODO: once we've recorded here, we'd like to replace the __call__
        # implementation with compiled bytecode that copies into static, replays
        # the cuda graph, then copies out.  First condition is the hotpath,
        # needs optimizing
        if self.graph is not None:
            assert len(args) == len(self.static_inputs)
            for dst, src in zip(self.static_inputs, args):
                dst.copy_(src)
            self.graph.replay()
            for i in self.mutated_inputs:
                args[i].copy_(self.static_inputs[i])
            return tree_map(cloner, self.static_outputs)

        elif self.warmed_up:
            # record
            self.static_inputs = [x.clone() for x in args]
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph):
                self.static_outputs = self.gm(*self.static_inputs)
            # NB: recording doesn't actually run the operations, so
            # now we immediately replay the graph to serve up the result
            self.graph.replay()
            for i in self.mutated_inputs:
                args[i].copy_(self.static_inputs[i])
            return tree_map(cloner, self.static_outputs)

        else:
            # warmup
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                r = self.gm(*args)
            torch.cuda.current_stream().wait_stream(stream)
            self.warmed_up = True
            return r

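
# Illustrative sketch (not part of the original module): the three-phase
# lifecycle of CudaGraphModule.  The first call runs eagerly on a side stream
# (warmup), the second records a CUDA graph into static buffers and replays it
# once, and every later call copies inputs into those buffers and replays.
# Assumes a CUDA device; the traced function and shapes are made up.
def _example_cudagraph_module_lifecycle():
    def f(x):
        return torch.sin(x) + 1

    gm = torch.fx.symbolic_trace(f)
    mod = CudaGraphModule(gm, mutated_inputs=set())
    x = torch.randn(8, device="cuda")
    mod(x)  # warmup: eager run on a side stream
    mod(x)  # record: captures the CUDA graph, then replays it once
    return mod(x)  # steady state: copy into static inputs and replay
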

# Interpreter versions of these passes can be found at
# https://gist.github.com/ezyang/df2d746cac3b2c7d55c181e37c57ef23


def find_input_mutations(g):
    def meta_fk(meta):
        return meta["val"] if "val" in meta else meta["fake_result"]

    inputs = defaultdict(set)
    input_idx = 0
    mutated_inputs = set()
    for n in g.nodes:
        if n.op == "placeholder":
            inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
            input_idx += 1
        elif n.op == "call_function":
            if n.target is operator.getitem:
                continue
            schema = n.target._schema
            for i, arg in enumerate(schema.arguments):
                if i < len(n.args):
                    argument = n.args[i]
                else:
                    if arg.name not in n.kwargs:
                        continue
                    argument = n.kwargs[arg.name]
                mut_arg = False
                if arg.alias_info:
                    if arg.alias_info.is_write:
                        mut_arg = True
                if mut_arg:
                    # TODO: not correct for args that contain tensors in a struct
                    # like list
                    mutated_inputs |= inputs[
                        StorageWeakRef(meta_fk(argument.meta)._typed_storage())
                    ]
        # TODO: error on unrecognized nodes
    return mutated_inputs

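
# Illustrative sketch (not part of the original module): find_input_mutations
# reports the indices of placeholders whose storage is written to by some op
# in the graph.  This example assumes make_fx populates node.meta["val"]
# (which is what meta_fk looks up); the function f is made up.
def _example_find_input_mutations():
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x, y):
        x.mul_(2)  # in-place op: mutates input 0
        return x + y

    gm = make_fx(f, tracing_mode="fake")(torch.randn(4), torch.randn(4))
    return find_input_mutations(gm.graph)  # expected: {0}
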

# Mutates input graph
def apply_cuda_graphs(gm):
    for n in gm.graph.nodes:
        if n.op == "call_module":
            assert not n.kwargs
            submod = gm.get_submodule(n.target)
            gm.delete_submodule(n.target)
            mutated_inputs = find_input_mutations(submod.graph)
            gm.add_submodule(n.target, CudaGraphModule(submod, mutated_inputs))
    # NB: we didn't actually change the graph, no need for recompile


def cudagraphs(model, inputs):
    model = partition_cudagraphs(model, inputs)
    apply_cuda_graphs(model)
    return model


aot_cudagraphs = aot_autograd(fw_compiler=cudagraphs, bw_compiler=cudagraphs)

# aot_cudagraphs only applies CUDA graphs to the graph.  It is also helpful
# for debugging and can serve as a perf baseline.
# TODO(jansel): rename to just "cudagraphs"?
register_backend(name="cudagraphs", compiler_fn=aot_cudagraphs)

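
# Illustrative sketch (not part of the original module): once registered, the
# backend can be selected by name.  Assumes this module is importable as a
# torch._dynamo backend and that a CUDA device is available; `model` and
# `example_input` are stand-ins supplied by the caller.
def _example_backend_usage(model, example_input):
    opt_model = torch.compile(model, backend="cudagraphs")
    return opt_model(example_input)
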

def cudagraphs_inner(model, inputs, copy_outputs=True):
    """This isn't registered as a backend, but is used in some benchmarks"""
    assert isinstance(inputs, (list, tuple))
    static_inputs = [torch.zeros_like(x) for x in inputs]

    # warmup
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        model(*inputs)
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        static_outputs = model(*static_inputs)
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    def run(*new_inputs):
        assert len(static_inputs) == len(new_inputs)
        for dst, src in zip(static_inputs, new_inputs):
            dst.copy_(src)
        graph.replay()
        if copy_outputs:
            return [x.clone() for x in static_outputs]
        else:
            return static_outputs

    return run
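

# Illustrative sketch (not part of the original module): benchmark-style use of
# cudagraphs_inner.  Assumes a CUDA device; the model and shapes are made up.
def _example_cudagraphs_inner():
    model = torch.nn.Linear(128, 64).cuda()
    inputs = [torch.randn(32, 128, device="cuda")]
    run = cudagraphs_inner(model, inputs, copy_outputs=True)
    # Each call copies fresh inputs into the static buffers and replays the graph.
    return run(torch.randn(32, 128, device="cuda"))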