import logging
import traceback
from dataclasses import dataclass, field
from typing import Any, List, Optional

import torch
from torch import fx
from torch._dynamo.output_graph import GraphCompileReason
from torch._dynamo.utils import deepcopy_to_fake_tensor, fake_mode_from_tensors
from torch.fx.node import Node

log = logging.getLogger(__name__)


def args_str(args):
    # a debug helper
    if torch.is_tensor(args):
        return f"T[{args.shape}]"
    elif isinstance(args, tuple):
        return f"tuple({', '.join([args_str(x) for x in args])})"
    elif isinstance(args, list):
        return f"list({', '.join([args_str(x) for x in args])})"
    else:
        return str(args)


@dataclass
class Bucket:
    size: int = 0
    params: List[str] = field(default_factory=list)
    nodes: List[fx.Node] = field(default_factory=list)

    # param_ids is just used for unit testing
    param_ids: List = field(default_factory=list)


def pretty_print_buckets(buckets: List[Bucket]):
    headers = ("Index", "Size (b)", "Param Names")
    rows = []
    for idx, bucket in enumerate(reversed(buckets)):
        if len(bucket.params) > 0:
            rows.append((idx, bucket.size, bucket.params[0]))
            for param in bucket.params[1:]:
                rows.append((None, None, param))
    try:
        from tabulate import tabulate

        log.info(
            "\nDDPOptimizer bucket assignments\n"
            + tabulate(rows, headers=headers, tablefmt="simple_grid")
        )
    except ImportError:
        log.info(
            "Please `pip install tabulate` in order to pretty-print ddp bucket sizes"
        )


class DDPOptimizer:
    """Note [DDPOptimizer]
    DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP),
    breaking the dynamo graph into chunks to compile separately, with the breaks aligning to
    the boundaries of gradient-allreduce buckets chosen by DDP.

    Background/Motivation
    - DDP uses allreduce collectives to synchronize partial gradients computed on different workers.
    - DDP groups gradient allreduces into 'buckets' to optimize the communication efficiency of allreduce.
    - Parameters grouped into buckets are assumed to be adjacent in time, so they become ready
      at around the same time during backward and thus can share the same allreduce efficiently.
    - Allreduces must overlap with backward compute for optimal training performance.
    - DDP schedules allreduces using 'hooks' fired from the c++ autograd engine in pytorch, which
      operates when individual grads become 'ready'.
    - Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of the
      autograd engine, such that all gradients become 'ready' at the same time.  Hooks fire after the whole
      fused backward function executes, preventing any overlap of compute and communication.

    Algorithm
    - DDPOptimizer starts off with an FX graph traced by dynamo which represents forward.  It can traverse
      this graph in reverse order to determine the true order that gradients will become ready during backward.
    - Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new bucket is
      started and a graph break introduced.
    - Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused back
      together into an outer module that is returned to the user.

    Notes
    - It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are used by DDP,
      and that DDP does not need to detect or optimize bucket order by observing execution at runtime, as it does
      in eager.
    - If Dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm will
      currently produce splits that do not necessarily align with the buckets used by DDP.  This should result in
      performance degradation approaching the baseline case where graph-splits are not used, but not worse.
    - If the backend compiler fails to compile a single subgraph, it will execute eagerly despite the rest of the
      subgraphs being compiled.
    - DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by reading markers
      left by DDP on individual parameters.  In cases where other transformations, such as reparameterization, are
      also used, the ignore markers could be lost.  If DDPOptimizer fails to ignore a parameter ignored by DDP,
      it is not catastrophic but could impact performance by choosing sub-optimal bucket splits.
    - DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not require
      gradients, and therefore aren't allreduced by DDP.  (They are broadcast during forward, but this is not
      covered by DDPOptimizer.)

    Debugging
    - Generally, it is easiest to debug DDPOptimizer in a single-process program, using pdb.
    - In many cases, the log messages are helpful (they show bucket size assignments) -
      just set torch._dynamo.config.log_level to info or debug.
    - See `benchmarks/dynamo/distributed.py` for a simple harness that will run a toy model or a torchbench model
      in a single process (or with torchrun, in multiple processes).

    Args:
        bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graphbreaks.  Should be
            set to match the equivalent parameter on the original DDP module.

        backend_compile_fn (callable): A dynamo compiler function, to be invoked to compile each subgraph.

        first_bucket_cap (int): Controls the size of the first bucket.  Should match DDP's first bucket cap.  DDP
            special-cases the first bucket size since it is sometimes optimal to start a small allreduce early.
    """

    def __init__(
        self,
        bucket_bytes_cap: int,
        backend_compile_fn,
        first_bucket_cap: Optional[int] = None,
    ):
        if first_bucket_cap is not None:
            self.first_bucket_cap = first_bucket_cap
        elif torch.distributed.is_available():
            # this constant comes from C10D lib which is not always built
            self.first_bucket_cap = torch.distributed._DEFAULT_FIRST_BUCKET_BYTES
        else:
            self.first_bucket_cap = bucket_bytes_cap

        self.bucket_bytes_cap = bucket_bytes_cap
        assert (
            self.first_bucket_cap <= self.bucket_bytes_cap
        ), "First bucket should be smaller/equal to other buckets to get comms warmed up ASAP"

        self.backend_compile_fn = backend_compile_fn

    def _ignore_parameter(self, parameter):
        return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored

    def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
        """
        Implements graph splitting, first determining a set of buckets by counting
        parameter sizes in reverse graph order, then invoking the user/backend compiler
        to compile each subgraph.  Finally, stitches compiled graphs into one graphmodule
        and returns its callable.
        """
        fake_mode = fake_mode_from_tensors(example_inputs)
        if fake_mode is None:
            fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()

        # 1: compute the partition map according to DDP bucket logic
        buckets = [Bucket()]  # (size, param_names)
        for node in reversed(gm.graph.nodes):
            if node.op in ("output", "placeholder"):
                continue

            if (
                buckets[0].size >= self.bucket_bytes_cap
                or len(buckets) == 1
                and buckets[0].size >= self.first_bucket_cap
            ):
                buckets.insert(0, Bucket())

            if node.op == "call_module":
                target = gm.get_submodule(node.target)
                for name, p in target.named_parameters():
                    param = target.get_parameter(name)
                    if p.requires_grad and not self._ignore_parameter(param):
                        buckets[0].size += p.untyped_storage().nbytes()
                        buckets[0].params.append(f"{node.target}_{name}")
                        buckets[0].param_ids.append(id(param))
            elif node.op == "get_attr":
                maybe_param = getattr(gm, node.target)
                if maybe_param.requires_grad and not self._ignore_parameter(
                    maybe_param
                ):
                    buckets[0].size += maybe_param.untyped_storage().nbytes()
                    buckets[0].params.append(node.target)
                    buckets[0].param_ids.append(id(maybe_param))

            # All nodes have to be mapped to a bucket, even if they don't have their own params
            # Ignored params still end up in buckets, we just don't count them towards the capacity
            buckets[0].nodes.append(node)

        if len(buckets) > 1 and buckets[0].size == 0:
            # we collected a small preamble graph with ops that don't include parameters, fuse it back
            buckets[1].nodes.extend(buckets[0].nodes)
            assert len(buckets[0].params) == 0, "Params should be empty if size is 0"
            del buckets[0]
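
        # Illustrative sketch of the bucketing above (hypothetical sizes, not from the source):
        # with bucket_bytes_cap == first_bucket_cap == 100 and four parameters of 40 bytes each
        # encountered in reverse graph order, the cap is checked *before* a node's params are
        # counted, so the current bucket grows to 120 bytes (40 + 40 + 40) before the next node
        # triggers a new bucket; a bucket may therefore overshoot the cap by up to one parameter.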

        # stash buckets for testing/debugging purposes
        self.buckets = buckets

        log.info(
            f"DDPOptimizer used bucket cap {self.bucket_bytes_cap} and produced the following buckets:"
        )
        pretty_print_buckets(buckets)

        if len(buckets) == 1:
            # bypass split/fuse logic if there is only one bucket
            return self.backend_compile_fn(gm, example_inputs)

        # 2: partition the graphmodule according to bucket capacity
        partition_map = {}
        for idx, b in enumerate(buckets):
            for node in b.nodes:
                partition_map[node] = idx

        split_gm = fx.passes.split_module.split_module(
            gm, None, lambda node: partition_map[node]
        )

        debug_str = (
            f"\n---orig graph---\n{gm.graph}\n"
            + f"\n---split graph---\n{split_gm.graph}\n"
        )
        for name, module in split_gm.named_modules():
            if "." not in name and len(name):
                # only print the submod graphs, not their children
                debug_str += f"\n---{name} graph---\n{module.graph}\n"
        debug_str += "\n---------------\n"
        log.debug(debug_str)

        # 3: compile each of the partitioned submodules using the user-provided compiler
        class SubmodCompiler(torch.fx.interpreter.Interpreter):
            def __init__(self, module, compiler):
                super().__init__(module)
                self.compiler = compiler

            def compile_submod(self, input_mod, args, kwargs):
                """
                Compile the submodule,
                using a wrapper to make sure its output is always a tuple,
                which is required by AotAutograd based compilers
                """
                assert len(kwargs) == 0, "We assume only args for these modules"

                class WrapperModule(torch.nn.Module):
                    def __init__(self, submod, unwrap_singleton_tuple):
                        super().__init__()
                        self.submod = submod
                        self.unwrap_singleton_tuple = unwrap_singleton_tuple

                    def forward(self, *args):
                        x = self.submod(*args)
                        # TODO(whc)
                        # for some reason the isinstance check is necessary if I split one node per submod
                        # - even though I supposedly wrapped the output in a tuple in those cases, the real
                        # compiled module was still returning a tensor
                        if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)):
                            return x[0]
                        return x
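
                # Illustrative example (hypothetical submodule, not from the source): if the split
                # submodule's FX output node has args like `(relu,)` -- a single value rather than a
                # tuple of outputs -- the submodule returns a bare Tensor when called.  The loop below
                # rewraps the output args as `((relu,),)` so the compiled artifact returns a 1-tuple,
                # as AOTAutograd-based compilers expect, and WrapperModule unwraps the singleton tuple
                # at runtime so the outer split graph's calling convention is unchanged.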

                unwrap_singleton_tuple = False
                for sn in input_mod.graph.nodes:
                    if sn.op == "output":
                        if not isinstance(sn.args[0], tuple):
                            unwrap_singleton_tuple = True
                            sn.args = (sn.args,)

                input_mod.recompile()
                input_mod.compile_subgraph_reason = GraphCompileReason(
                    "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
                    " Set `torch._dynamo.config.optimize_ddp = False` to disable.",
                    [
                        # it's close to useless to get a real stacktrace here, and quite verbose.
                        traceback.FrameSummary(__file__, 0, DDPOptimizer),
                    ],
                )
                wrapper = WrapperModule(
                    self.compiler(input_mod, args),
                    unwrap_singleton_tuple,
                )
                return wrapper

            # Note:
            #
            # The way distributed works today around fake tensors can be somewhat confusing.
            # Some of these codepaths are shared in both runtime, and compile time.  The presence
            # of a fake_mode, read off of fake tensor inputs, dictates how we will operate.
            #
            # A few things to keep in mind:
            #
            # 1) We invoke `compile_submod` with a real module.  The output of that gets stored
            # on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`.
            #
            # 2) When running a call_module targeted node, if we have a fake_mode, we fakify the
            # module we got from self.fetch_attr(n.target).  Regardless of fake_mode, we then execute it.
            #
            # 3) Fake tensors should always be around during compile time.
            #
            # 4) Fake tensors should never be around at runtime.
            #
            # 5) We end up with a compilation mode that takes a real submodule and fake tensors,
            # to match what aot_autograd expects.  See Note: [Fake Modules and AOTAutograd]
            def run_node(self, n: Node) -> Any:
                with self._set_current_node(n):
                    args, kwargs = self.fetch_args_kwargs_from_env(n)
                    new_args = []
                    assert fake_mode
                    for arg in args:
                        if isinstance(arg, torch.Tensor) and not isinstance(
                            arg, torch._subclasses.FakeTensor
                        ):
                            new_args.append(fake_mode.from_tensor(arg))
                        else:
                            new_args.append(arg)

                    log.debug(f"run_node {n.op}, {n.target} got args {args_str(args)}")
                    assert isinstance(args, tuple)
                    assert isinstance(kwargs, dict)

                    if n.op == "call_module":
                        real_mod = self.fetch_attr(n.target)
                        if fake_mode:
                            curr_submod = deepcopy_to_fake_tensor(real_mod, fake_mode)
                        else:
                            curr_submod = real_mod

                        log.debug(
                            f"\n---{n.target} graph---\n" + str(curr_submod.graph)
                        )

                        # When calling the compiler on the submod, inputs (new_args) are expected to
                        # be FakeTensors already since Dynamo would have made them FakeTensors in the
                        # non-DDP flow.  However, the parameters are _not_ expected to be FakeTensors,
                        # since this wrapping happens during compilation.
                        compiled_submod_real = self.compile_submod(
                            real_mod, new_args, kwargs
                        )

                        # We update the original (outer) graph with a call into the compiled module
                        # instead of the uncompiled one.
                        self.module.delete_submodule(n.target)
                        n.target = "compiled_" + n.target
                        self.module.add_submodule(n.target, compiled_submod_real)

                        # Finally, we have to produce inputs for use in compiling the next submodule,
                        # and these need to be FakeTensors, so we execute the module under fake_mode.
                        with fake_mode:
                            return curr_submod(*new_args, **kwargs)
                    else:
                        # placeholder or output nodes don't need to get compiled, just executed
                        return getattr(self, n.op)(n.target, new_args, kwargs)

        submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn)
        submod_compiler.run(*example_inputs)
        split_gm.recompile()

        log.debug("\n---final graph---\n" + str(split_gm.graph) + "\n---------------\n")
        return split_gm