_ddp.py

import copy
import inspect
import itertools
import logging
import os
import sys
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from enum import auto, Enum
from typing import Any, Callable, Optional, Type

import torch
import torch.distributed as dist
from torch.autograd import Function, Variable
from torch.utils._pytree import tree_flatten, tree_unflatten

if dist.is_available():
    from torch.distributed.distributed_c10d import _get_default_group, ReduceOp
    from torch.distributed.utils import (
        _sync_module_states,
        _to_kwargs,
        _verify_param_shape_across_processes,
    )

from torch._utils import _get_device_index
from torch.nn.modules import Module
from torch.nn.parallel.scatter_gather import gather, scatter_kwargs

__all__ = ["DistributedDataParallel"]

logger = logging.getLogger(__name__)


def _find_tensors(obj):
    r"""
    Recursively find all tensors contained in the specified object.
    """
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, (list, tuple)):
        return itertools.chain(*map(_find_tensors, obj))
    if isinstance(obj, dict):
        return itertools.chain(*map(_find_tensors, obj.values()))
    return []


class _BufferCommHookLocation(Enum):
    PRE_FORWARD = auto()
    POST_FORWARD = auto()


@dataclass
class _BufferCommHook:
    buffer_comm_hook: Callable
    buffer_comm_hook_state: Any
    buffer_comm_hook_location: _BufferCommHookLocation


# _DDPSink runs various functions when backward starts, such as queueing the
# callback of the outermost backward/graph task; this helps ensure the callback
# fires only after all gradient computation has completed.
class _DDPSink(Function):
    @staticmethod
    def forward(ctx, reducer, state_dict, *inputs):
        # set_materialize_grads(False) will ensure that None gradients stay as
        # None and are not filled with zeros.
        ctx.set_materialize_grads(False)
        ctx.reducer = reducer
        ctx.state_dict = state_dict
        ret = tuple(
            inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs
        )
        return ret

    @staticmethod
    def backward(ctx, *grad_outputs):
        state_dict = ctx.state_dict
        # Enqueue delay allreduce for static graph training on the first
        # iteration.
        if state_dict["static_graph"] and state_dict["num_iterations"] == 1:
            Variable._execution_engine.queue_callback(ctx.reducer._delay_all_reduce)  # type: ignore[call-arg,misc]
        return (None, None, *grad_outputs)
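

# A minimal sketch of the callback-queueing pattern used above, for illustration
# only (the toy ``_Sink`` class and the ``print`` stand in for the reducer's
# ``_delay_all_reduce``; assumes an autograd graph is being executed):
#
#     class _Sink(torch.autograd.Function):
#         @staticmethod
#         def forward(ctx, x):
#             return x.clone()
#
#         @staticmethod
#         def backward(ctx, grad):
#             # Runs once the whole backward pass for this graph has finished.
#             Variable._execution_engine.queue_callback(lambda: print("backward done"))
#             return grad
#
#     y = _Sink.apply(torch.ones(2, requires_grad=True)).sum()
#     y.backward()  # prints "backward done" after all gradients are computed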


class DistributedDataParallel(Module):
    # used to track whether the given thread is inside ddp forward for torchdynamo purposes
    _active_ddp_module = None

    def __init__(
        self,
        module,
        device_ids=None,
        output_device=None,
        dim=0,
        broadcast_buffers=True,
        process_group=None,
        bucket_cap_mb=25,
        find_unused_parameters=False,
        gradient_as_bucket_view=False,
        static_graph=False,
    ):
        super().__init__()
        self.logger: Optional[dist.Logger] = None
        if not any(p.requires_grad for p in module.parameters()):
            self._log_and_throw(
                RuntimeError,
                "DistributedDataParallel is not needed when a module "
                "doesn't have any parameter that requires a gradient.",
            )
        if device_ids is not None and len(device_ids) > 1:
            self._log_and_throw(
                ValueError,
                "device_ids can only be None or contain a single element.",
            )
        self.is_multi_device_module = len({p.device for p in module.parameters()}) > 1
        distinct_device_types = {p.device.type for p in module.parameters()}
        if len(distinct_device_types) != 1:
            self._log_and_throw(
                ValueError,
                "DistributedDataParallel's input module must be on "
                "the same type of devices, but input module parameters locate in {}.".format(
                    distinct_device_types
                ),
            )
        self.device_type = list(distinct_device_types)[0]
        if (
            device_ids is None
            or len(device_ids) == 0  # For backward compatibility.
            or self.device_type == "cpu"
            or self.is_multi_device_module
        ):
            if device_ids or output_device:
                self._log_and_throw(
                    ValueError,
                    "DistributedDataParallel device_ids and output_device arguments "
                    "only work with single-device/multiple-device GPU modules or CPU modules, "
                    "but got device_ids {}, output_device {}, and module parameters {}.".format(
                        device_ids,
                        output_device,
                        {p.device for p in module.parameters()},
                    ),
                )
            self.device_ids = None
            self.output_device = None
        else:
            self.device_ids = [_get_device_index(x, True) for x in device_ids]
            if output_device is None:
                output_device = device_ids[0]
            self.output_device = _get_device_index(output_device, True)
        if process_group is None:
            self.process_group = _get_default_group()
        else:
            self.process_group = process_group
        self.static_graph = False
        self.dim = dim
        self.module = module
        self.device = list(self.module.parameters())[0].device
        self.broadcast_buffers = broadcast_buffers
        self.find_unused_parameters = find_unused_parameters
        self.require_backward_grad_sync = True
        self.require_forward_param_sync = True
        self.gradient_as_bucket_view = gradient_as_bucket_view
        if hasattr(module, "_ddp_params_and_buffers_to_ignore"):
            self.parameters_to_ignore = module._ddp_params_and_buffers_to_ignore
        else:
            self.parameters_to_ignore = []
        # Check that a module does not have Uninitialized parameters
        for param in module.parameters():
            if isinstance(param, torch.nn.parameter.UninitializedParameter):
                self._log_and_throw(
                    RuntimeError,
                    "Modules with uninitialized parameters can't be used with `DistributedDataParallel`. "
                    "Run a dummy forward pass to correctly initialize the modules",
                )
        # used for intra-node param sync and inter-node sync as well
        self.broadcast_bucket_size = int(250 * 1024 * 1024)
        # reduction bucket size
        self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
        # Whether to perform input tensor CPU to GPU copies on a side-stream
        self.use_side_stream_for_tensor_copies = (
            os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1"
        )
        # Build parameters for reducer.
        parameters, expect_sparse_gradient = self._build_params_for_reducer()
        # Verify model equivalence.
        _verify_param_shape_across_processes(self.process_group, parameters)
        # Sync params and buffers. Ensures all DDP models start off at the same value.
        _sync_module_states(
            module=self.module,
            process_group=self.process_group,
            broadcast_bucket_size=self.broadcast_bucket_size,
            src=0,
            params_and_buffers_to_ignore=self.parameters_to_ignore,
        )
        # In debug mode, build a mapping of parameter index -> parameter.
        param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)
        # Builds reducer.
        self._ddp_init_helper(
            parameters,
            expect_sparse_gradient,
            param_to_name_mapping,
            static_graph,
        )
        self._has_rebuilt_buckets = False
        if static_graph:
            self._set_static_graph()

    def _log_and_throw(self, err_type, err_msg):
        if self.logger is not None:
            self.logger.set_error_and_log(f"{str(err_type)}: {err_msg}")
        raise err_type(err_msg)

    def _ddp_init_helper(
        self,
        parameters,
        expect_sparse_gradient,
        param_to_name_mapping,
        static_graph,
    ):
        """
        Initialization helper function that does the following:
        (1) bucketing the parameters for reductions
        (2) resetting the bucketing states
        (3) registering the grad hooks
        (4) logging construction-time DDP logging data
        (5) passing a handle of DDP to SyncBatchNorm layers
        """
        self.num_iterations = 0
        # Note that the parameter order is not necessarily the order in which
        # the parameters are used, especially in models with control flow.
        #
        # Because the parameters are not presented in the real execution order,
        # if a model also
        # 1) has other collective comm ops in its backward graph, or
        # 2) has parameters that are unused on a subset of ranks,
        # bucketing could insert an ALL-REDUCE comm op too early on the ranks
        # with unused parameters, unexpectedly matching up with other collective
        # comm ops on other ranks.
        #
        # To handle this corner case, when the parameters are not in the real
        # execution order, we do not bucket them, so only one ALL-REDUCE is
        # inserted after all the gradients of the whole graph are computed.
        #
        # Note that bucketing is only disabled for the first iteration.
        # After the first iteration, it is OK to rebuild buckets, because
        # "bucket rebuild" bucketizes parameters based on their real execution
        # order in the backward graph.
        # Can remove this branching once #73732 is landed.
        if static_graph is True or self.find_unused_parameters is False:
            bucket_size_limits = [sys.maxsize]
        else:
            bucket_size_limits = [
                dist._DEFAULT_FIRST_BUCKET_BYTES,
                self.bucket_bytes_cap,
            ]
        (
            bucket_indices,
            per_bucket_size_limits,
        ) = dist._compute_bucket_assignment_by_size(
            parameters,
            bucket_size_limits,
            expect_sparse_gradient,
        )
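
        # A small sketch of how the size limits above drive bucket assignment
        # (illustration only; calls the same private helper with toy parameters
        # and assumes the c10d extension is available):
        #
        #     params = [torch.nn.Parameter(torch.randn(1024, 1024)) for _ in range(4)]
        #     indices, limits = dist._compute_bucket_assignment_by_size(
        #         params, [1024 * 1024, 4 * 1024 * 1024], [False] * len(params)
        #     )
        #     # `indices` groups parameter positions into buckets; the first
        #     # bucket honors the smaller first-bucket limit, later buckets the
        #     # larger cap.
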
        # Note: reverse list of buckets because we want to approximate the
        # order in which their gradients are produced, and assume they
        # are used in the forward pass in the order they are defined.
        self.reducer = dist.Reducer(
            parameters,
            list(reversed(bucket_indices)),
            list(reversed(per_bucket_size_limits)),
            self.process_group,
            expect_sparse_gradient,
            # The bucket size limit is specified in the constructor.
            # Additionally, we allow for a single small bucket for parameters
            # that are defined first, such that their gradients don't spill into
            # a much larger bucket, adding unnecessary latency after gradient
            # computation finishes. Experiments showed 1MB is a reasonable value.
            self.bucket_bytes_cap,
            self.find_unused_parameters,
            self.gradient_as_bucket_view,
            param_to_name_mapping,
            # User can set dist._DEFAULT_FIRST_BUCKET_BYTES to tune DDP first
            # bucket.
            dist._DEFAULT_FIRST_BUCKET_BYTES,
        )
        self.logger = dist.Logger(self.reducer)
        # Set as a weak reference to avoid reference cycle between
        # logger and reducer.
        self.reducer.set_logger(self.logger)
        has_sync_bn = False
        for submodule in self.module.modules():
            if isinstance(submodule, torch.nn.SyncBatchNorm):
                has_sync_bn = True
                break
        # Set logging data that can be got during construction time.
        self.logger.set_construction_data_and_log(
            self.module.__class__.__name__,
            [] if self.device_ids is None else self.device_ids,
            -1 if self.output_device is None else self.output_device,
            self.broadcast_buffers,
            has_sync_bn,
            static_graph,
        )
        # passing a handle to torch.nn.SyncBatchNorm layer
        self._passing_sync_batchnorm_handle(self.module)

    def __getstate__(self):
        self._check_default_group()
        attrs = copy.copy(self.__dict__)
        del attrs["process_group"]
        del attrs["reducer"]
        del attrs["logger"]
        return attrs

    def __setstate__(self, state):
        # If serializable, then the process group should be the default one
        self.process_group = _get_default_group()
        super().__setstate__(state)
        self.__dict__.setdefault("require_forward_param_sync", True)
        self.__dict__.setdefault("require_backward_grad_sync", True)
        parameters, expect_sparse_gradient = self._build_params_for_reducer()
        # In debug mode, build a mapping of parameter index -> parameter.
        param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)
        # Builds reducer.
        self._ddp_init_helper(
            parameters,
            expect_sparse_gradient,
            param_to_name_mapping,
            self.static_graph,
        )
        if self.static_graph:
            self.reducer._set_static_graph()
            assert self.logger is not None
            self.logger._set_static_graph()

    def _build_params_for_reducer(self):
        # Build tuple of (module, parameter) for all parameters that require grads.
        modules_and_parameters = [
            (module, parameter)
            for module_name, module in self.module.named_modules()
            for parameter in [
                param
                # Note that we access module.named_parameters instead of
                # parameters(module). parameters(module) is only needed in the
                # single-process multi device case, where it accesses replicated
                # parameters through _former_parameters.
                for param_name, param in module.named_parameters(recurse=False)
                if param.requires_grad
                and f"{module_name}.{param_name}" not in self.parameters_to_ignore
            ]
        ]

        # Deduplicate any parameters that might be shared across child modules.
        memo = set()
        modules_and_parameters = [
            # "p not in memo" is the deduplication check.
            # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed.
            (m, p)
            for m, p in modules_and_parameters
            if p not in memo and not memo.add(p)  # type: ignore[func-returns-value]
        ]

        # Build list of parameters.
        parameters = [parameter for _, parameter in modules_and_parameters]

        # Checks if a module will produce a sparse gradient.
        def produces_sparse_gradient(module):
            if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)):
                return module.sparse
            return False

        # Build list of booleans indicating whether or not to expect sparse
        # gradients for the corresponding parameters.
        expect_sparse_gradient = [
            produces_sparse_gradient(module) for module, _ in modules_and_parameters
        ]

        self._assign_modules_buffers()

        return parameters, expect_sparse_gradient

    def _assign_modules_buffers(self):
        """
        Assigns module buffers to self.modules_buffers which are then used to
        broadcast across ranks when broadcast_buffers=True. Note that this
        must be called every time buffers need to be synced because buffers can
        be reassigned by user module,
        see https://github.com/pytorch/pytorch/issues/63916.
        """
        # Collect buffers for modules, filtering out buffers that should be ignored.
        named_module_buffers = [
            (buffer, buffer_name)
            for buffer_name, buffer in self.module.named_buffers()
            if buffer_name not in self.parameters_to_ignore
        ]
        self.modules_buffers = [
            buffer for (buffer, buffer_name) in named_module_buffers
        ]
        # Dict[str, tensor] representing module buffers not ignored by DDP.
        self.named_module_buffers = {
            buffer_name: buffer for (buffer, buffer_name) in named_module_buffers
        }

    def _build_debug_param_to_name_mapping(self, parameters):
        if dist.get_debug_level() == dist.DebugLevel.OFF:
            return {}

        param_to_param_index = {parameters[i]: i for i in range(len(parameters))}
        param_set = set(parameters)
        param_index_to_param_fqn = {}
        for module_name, module in self.module.named_modules():
            for param_name, param in module.named_parameters(recurse=False):
                fqn = f"{module_name}.{param_name}"
                # Bypass ignored parameters since those are not reduced by DDP
                # to begin with.
                if fqn not in self.parameters_to_ignore and param.requires_grad:
                    if param not in param_set:
                        self._log_and_throw(
                            ValueError,
                            f"Param with name {fqn} found in module parameters, but not DDP parameters."
                            " This indicates a bug in DDP, please report an issue to PyTorch.",
                        )
                    param_index = param_to_param_index[param]
                    param_index_to_param_fqn[param_index] = fqn

        # Ensure we covered all parameters
        if len(param_set) != len(param_index_to_param_fqn):
            self._log_and_throw(
                ValueError,
                (
                    "Expected param to name mapping to cover all parameters, but"
                    f" got conflicting lengths: {len(param_set)} vs "
                    f"{len(param_index_to_param_fqn)}. This indicates a bug in DDP"
                    ", please report an issue to PyTorch."
                ),
            )

        return param_index_to_param_fqn

    def _get_parameters(self, m, recurse=True):
        """
        Returns a generator of module parameters
        """

        def model_parameters(m):
            ps = (
                m._former_parameters.values()
                if hasattr(m, "_former_parameters")
                else m.parameters(recurse=False)
            )
            yield from ps

        for m in m.modules() if recurse else [m]:
            for p in model_parameters(m):
                yield p

    def _check_default_group(self):
        pickle_not_supported = False
        try:
            if self.process_group != _get_default_group():
                pickle_not_supported = True
        except RuntimeError:
            pickle_not_supported = True

        if pickle_not_supported:
            self._log_and_throw(
                RuntimeError,
                "DDP Pickling/Unpickling are only supported "
                "when using DDP with the default process "
                "group. That is, when you have called "
                "init_process_group and have not passed "
                "process_group argument to DDP constructor",
            )

    @contextmanager
    def no_sync(self):
        r"""
        A context manager to disable gradient synchronizations across DDP
        processes. Within this context, gradients will be accumulated on module
        variables, which will later be synchronized in the first
        forward-backward pass after exiting the context.

        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
            >>> with ddp.no_sync():
            >>>     for input in inputs:
            >>>         ddp(input).backward()  # no synchronization, accumulate grads
            >>> ddp(another_input).backward()  # synchronize grads
        """
        old_require_backward_grad_sync = self.require_backward_grad_sync
        self.require_backward_grad_sync = False
        try:
            yield
        finally:
            self.require_backward_grad_sync = old_require_backward_grad_sync
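
    # A minimal gradient-accumulation sketch built on ``no_sync`` (illustration
    # only; assumes ``import contextlib`` and that ``ddp``, ``loss_fn``, and
    # ``micro_batches`` are defined by the caller):
    #
    #     for step, batch in enumerate(micro_batches):
    #         is_last = step == len(micro_batches) - 1
    #         # Skip the allreduce on all but the last micro-batch.
    #         ctx = contextlib.nullcontext() if is_last else ddp.no_sync()
    #         with ctx:
    #             loss_fn(ddp(batch)).backward()
    #     # Gradients were allreduced only during the final backward pass.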

    @classmethod
    def _get_active_ddp_module(cls):
        """
        TorchDynamo needs to know whether DDP is currently active, and access
        the DDP module in order to cooperatively optimize it.
        """
        return cls._active_ddp_module

    # Note: this ctxmgr function is marked 'skip' in torchdynamo, so dynamo only
    # kicks in for the 'module_to_run' underneath.
    # See torchdynamo/eval_frame.py TorchPatcher.patch for more details.
    @contextmanager
    def _inside_ddp_forward(self):
        DistributedDataParallel._active_ddp_module = self
        try:
            yield
        except Exception:
            raise
        finally:
            DistributedDataParallel._active_ddp_module = None

    def pre_forward(self):
        with torch.autograd.profiler.record_function(
            "DistributedDataParallel.pre_forward"
        ):
            if torch.is_grad_enabled() and self.require_backward_grad_sync:
                assert self.logger is not None
                self.logger.set_runtime_stats_and_log()
                self.num_iterations += 1
                self.reducer.prepare_for_forward()

            # Call _rebuild_buckets before the forward computation: it may
            # allocate new buckets before deallocating old ones, so calling it
            # before peak memory usage grows during the forward pass keeps peak
            # memory lower. Buckets are only rebuilt once during the whole
            # training period.
            if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
                logger.info("Reducer buckets have been rebuilt in this iteration.")
                self._has_rebuilt_buckets = True

            # Sync buffers according to the location (before/after forward) the
            # user specified as part of a hook, if one was registered.
            if self._check_sync_bufs_pre_fwd():
                self._sync_buffers()

    def post_forward(self, output):
        with torch.autograd.profiler.record_function(
            "DistributedDataParallel.post_forward"
        ):
            # Sync buffers according to the location (before/after forward) the
            # user specified as part of a hook, if one was registered.
            if self._check_sync_bufs_post_fwd():
                self._sync_buffers()

            if torch.is_grad_enabled() and self.require_backward_grad_sync:
                self.require_forward_param_sync = True
                # We'll return the output object verbatim since it is a freeform
                # object. We need to find any tensors in this object, though,
                # because we need to figure out which parameters were used during
                # this forward pass, to ensure we short circuit reduction for any
                # unused parameters. Only if `find_unused_parameters` is set.
                if self.find_unused_parameters and not self.static_graph:
                    # Do not need to populate this for static graph.
                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                else:
                    self.reducer.prepare_for_backward([])
            else:
                self.require_forward_param_sync = False

            # TODO: DDPSink is currently enabled for unused parameter detection and
            # static graph training for first iteration.
            if (self.find_unused_parameters and not self.static_graph) or (
                self.static_graph and self.num_iterations == 1
            ):
                state_dict = {
                    "static_graph": self.static_graph,
                    "num_iterations": self.num_iterations,
                }

                output_tensor_list, treespec = tree_flatten(output)
                output_placeholders = [None for _ in range(len(output_tensor_list))]
                # Do not touch tensors that have no grad_fn, which can cause issues
                # such as https://github.com/pytorch/pytorch/issues/60733
                for i, output in enumerate(output_tensor_list):
                    if torch.is_tensor(output) and output.grad_fn is None:
                        output_placeholders[i] = output

                # When find_unused_parameters=True, this makes tensors which
                # require grad run through the DDPSink backward pass. When not
                # all outputs are used in the loss, the corresponding tensors
                # receive undefined gradients, which the reducer then handles to
                # ensure the param.grad field is not touched and we don't error out.
                passthrough_tensor_list = _DDPSink.apply(
                    self.reducer,
                    state_dict,
                    *output_tensor_list,
                )
                for i in range(len(output_placeholders)):
                    if output_placeholders[i] is None:
                        output_placeholders[i] = passthrough_tensor_list[i]

                # Reconstruct output data structure.
                output = tree_unflatten(output_placeholders, treespec)
        return output
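
    # Sketch of when ``find_unused_parameters=True`` matters (illustration only;
    # ``Gated`` is a hypothetical module): a forward pass that skips a submodule
    # produces no gradient for that submodule's parameters, so the reducer must
    # be told not to wait for them:
    #
    #     class Gated(torch.nn.Module):
    #         def __init__(self):
    #             super().__init__()
    #             self.a = torch.nn.Linear(8, 8)
    #             self.b = torch.nn.Linear(8, 8)  # unused on some iterations
    #
    #         def forward(self, x, use_b=False):
    #             return self.b(self.a(x)) if use_b else self.a(x)
    #
    #     ddp = DistributedDataParallel(Gated(), find_unused_parameters=True)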

    def forward(self, *inputs, **kwargs):
        self.pre_forward()
        with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
            if self.device_ids:
                inputs, kwargs = _to_kwargs(
                    inputs,
                    kwargs,
                    self.device_ids[0],
                    self.use_side_stream_for_tensor_copies,
                )
                with self._inside_ddp_forward():
                    output = self.module(*inputs[0], **kwargs[0])  # type: ignore[index]
            else:
                with self._inside_ddp_forward():
                    output = self.module(*inputs, **kwargs)
        output = self.post_forward(output)
        return output
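
    # End-to-end usage sketch of the forward/backward flow above (illustration
    # only; assumes torch.distributed has been initialized via
    # ``init_process_group`` and that ``MyModel``, ``dataset``, and ``rank`` are
    # provided by the training script):
    #
    #     model = MyModel().to(rank)
    #     ddp = DistributedDataParallel(model, device_ids=[rank])
    #     opt = torch.optim.SGD(ddp.parameters(), lr=0.1)
    #     for x, y in dataset:
    #         opt.zero_grad()
    #         loss = torch.nn.functional.cross_entropy(ddp(x), y)
    #         loss.backward()  # gradients are allreduced bucket by bucket
    #         opt.step()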

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def to_kwargs(self, inputs, kwargs, device_id):
        # Kept for BC
        return _to_kwargs(
            inputs, kwargs, device_id, self.use_side_stream_for_tensor_copies
        )

    def gather(self, outputs, output_device):
        return gather(outputs, output_device, dim=self.dim)

    def train(self, mode=True):
        super().train(mode)
        return self

    # When running in join mode, schedules an allreduce to notify joined ranks
    # of whether backwards pass synchronization will run this iteration or not.
    def _check_global_requires_backward_grad_sync(self, is_joined_rank):
        if not is_joined_rank and self.require_backward_grad_sync:
            requires_sync_tensor = torch.ones(1, device=self.device)
        else:
            requires_sync_tensor = torch.zeros(1, device=self.device)

        work = dist.all_reduce(
            requires_sync_tensor, group=self.process_group, async_op=True
        )
        return work

    # When running in join mode, checks and performs sync of module buffers if
    # the models have buffers that should be synchronized in the forward pass.
    def _check_and_sync_module_buffers(self):
        if self._check_sync_bufs_pre_fwd():
            authoritative_rank = self._find_common_rank(self._distributed_rank, False)
            self._sync_module_buffers(authoritative_rank)

    # When running in join mode, agrees upon a common rank and broadcasts model
    # parameters to all other ranks.
    def _sync_final_model(self, is_last_joiner):
        # Agree upon the process that will be the authoritative model copy.
        # The current rank is a candidate for being the authoritative copy if
        # is_last_joiner=True. We break ties via picking the larger rank.
        self._authoritative_rank = self._find_common_rank(
            self._distributed_rank, is_last_joiner
        )
        _sync_module_states(
            module=self.module,
            process_group=self.process_group,
            broadcast_bucket_size=self.broadcast_bucket_size,
            src=self._authoritative_rank,
            params_and_buffers_to_ignore=self.parameters_to_ignore,
        )

    # Schedule comm ops to match those scheduled in the reducer's backward
    # pass.
    def _match_all_reduce_for_bwd_pass(self):
        comm_work = []
        # Schedule comm in the same order as Reducer schedules them, i.e.
        # the order of the buckets. Retrieving the bucket order from the reducer
        # ensures that we keep the same order in join mode, such as when bucket
        # order is rebuilt dynamically.

        # Returns grad_buckets in order, but real tensors are substituted with
        # zero tensors of the same shape.
        grad_buckets = self.reducer._get_zeros_like_grad_buckets()
        for grad_bucket in grad_buckets:
            # Joined processes contribute zero gradient. In the case that
            # divide_by_initial_world_size=True, we divide grads by the static
            # world size, if not, the dividing factor is reduced by the number
            # of joined processes.
            work = self.reducer._run_comm_hook(grad_bucket)
            comm_work.append(work)
        for work in comm_work:
            work.wait()

    # Allreduces the used parameter mapping across ranks.
    def _match_unused_params_allreduce(self):
        locally_used_param_map = self.reducer._get_local_used_map()
        self.process_group.allreduce(locally_used_param_map)

    def _register_buffer_comm_hook(
        self,
        state,
        hook: Callable,
        comm_hook_location=_BufferCommHookLocation.POST_FORWARD,
    ):
        r"""
        Allows custom registration of hooks that define how buffers are
        synchronized across ranks. The hook takes in an optional state
        and is passed in a Dict[str, Tensor] corresponding to buffer names
        and the buffers, and can run arbitrary reductions on buffers as
        opposed to DDP's default broadcast from rank 0. This is useful, for
        example, if a counter needs to be summed or averaged across ranks
        every iteration.

        Args:
            state (Any): Optional state that is passed to the hook.
            hook (Callable): Callable with the following signature:
                ``hook(state: object, buffers: Dict[str, torch.Tensor])
                -> Optional[List[torch.futures.Future[torch.Tensor]]]``
            comm_hook_location (_BufferCommHookLocation): Enum value indicating
                where to run the hook.
                _BufferCommHookLocation.PRE_FORWARD means that the
                hook will run _before_ the forward pass, and
                _BufferCommHookLocation.POST_FORWARD means that the
                hook will run _after_ the forward pass.

            NOTE: To maximize performance, users can return a
            List[torch.futures.Future] from their hook, and DDP will
            install and await these hooks appropriately at the end of
            the backward pass. This will ensure all buffers are
            synchronized by the end of the backward pass. If this
            setting is used, it is recommended to pass
            comm_hook_location=_BufferCommHookLocation.POST_FORWARD,
            which will trigger the hook after the forward pass.
            If _BufferCommHookLocation.PRE_FORWARD is used, users must
            ensure appropriate synchronization when manipulating GPU
            buffers in the forward pass.
        """
        assert callable(hook)
        self.buffer_hook = _BufferCommHook(
            buffer_comm_hook=hook,
            buffer_comm_hook_state=state,
            buffer_comm_hook_location=comm_hook_location,
        )
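
    # A minimal sketch of a buffer comm hook that averages all buffers across
    # ranks instead of broadcasting from rank 0 (illustration only; the hook
    # name is hypothetical and a process group is assumed to be initialized):
    #
    #     def _allreduce_buffers_hook(state, named_buffers):
    #         futs = []
    #         world_size = dist.get_world_size()
    #         for _, buf in named_buffers.items():
    #             fut = dist.all_reduce(buf, async_op=True).get_future()
    #             futs.append(fut.then(lambda f: f.value()[0].div_(world_size)))
    #         return futs
    #
    #     ddp._register_buffer_comm_hook(None, _allreduce_buffers_hook)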

    def register_comm_hook(self, state: object, hook: Callable):
        r"""
        Registers a communication hook which is an enhancement that provides a
        flexible hook to users where they can specify how DDP aggregates gradients
        across multiple workers.

        This hook would be very useful for researchers to try out new ideas. For
        example, this hook can be used to implement several algorithms like GossipGrad
        and gradient compression which involve different communication strategies for
        parameter syncs while running Distributed DataParallel training.

        Args:
            state (object): Passed to the hook to maintain any state information during the training process.
                Examples include error feedback in gradient compression,
                peers to communicate with next in GossipGrad, etc.
                It is locally stored by each worker
                and shared by all the gradient tensors on the worker.
            hook (Callable): Callable with the following signature:
                ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``:

                This function is called once the bucket is ready. The
                hook can perform whatever processing is needed and return
                a Future indicating completion of any async work (ex: allreduce).
                If the hook doesn't perform any communication, it still
                must return a completed Future. The Future should hold the
                new value of the grad bucket's tensors. Once a bucket is ready,
                the c10d reducer calls this hook, uses the tensors returned
                by the Future, and copies grads to individual parameters.
                Note that the future's return type must be a single tensor.

                We also provide an API called ``get_future`` to retrieve a
                Future associated with the completion of ``c10d.ProcessGroup.Work``.
                ``get_future`` is currently supported for NCCL and also supported for most
                operations on GLOO and MPI, except for peer to peer operations (send/recv).

        .. warning ::
            The grad bucket's tensors will not be predivided by world_size. The user is
            responsible for dividing by the world_size in case of operations like allreduce.

        .. warning ::
            DDP communication hook can only be registered once and should be registered
            before calling backward.

        .. warning ::
            The Future object that the hook returns should contain a single tensor
            that has the same shape as the tensors inside the grad bucket.

        .. warning ::
            ``get_future`` API supports NCCL, and partially GLOO and MPI backends (no support
            for peer-to-peer operations like send/recv) and will return a ``torch.futures.Future``.

        Example::
            Below is an example of a noop hook that returns the same tensor.

            >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
            >>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
            >>>     fut = torch.futures.Future()
            >>>     fut.set_result(bucket.buffer())
            >>>     return fut
            >>> # xdoctest: +SKIP('undefined name')
            >>> ddp.register_comm_hook(state=None, hook=noop)

        Example::
            Below is an example of a Parallel SGD algorithm where gradients are encoded before
            allreduce, and then decoded after allreduce.

            >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
            >>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
            >>>     encoded_tensor = encode(bucket.buffer())  # encode gradients
            >>>     fut = torch.distributed.all_reduce(encoded_tensor).get_future()
            >>>     # Define the then callback to decode.
            >>>     def decode_callback(fut):
            >>>         decoded_tensor = decode(fut.value()[0])  # decode gradients
            >>>         return decoded_tensor
            >>>     return fut.then(decode_callback)
            >>> # xdoctest: +SKIP('undefined name')
            >>> ddp.register_comm_hook(state=None, hook=encode_and_decode)
        """
        self._check_comm_hook(hook)
        assert self.logger is not None
        self.logger._set_comm_hook_name(hook.__qualname__)
        dist._register_comm_hook(self.reducer, state, hook)

    def _register_builtin_comm_hook(self, comm_hook_type):
        r"""
        Registers a built-in communication hook that specifies how DDP
        aggregates gradients across multiple workers.
        The built-in hooks aim to provide efficient C++ implementations for certain hooks,
        which might not be as efficient if implemented in Python using a Python communication hook.

        Args:
            comm_hook_type (dist.BuiltinCommHookType): type of communication hook, such as ALLREDUCE, FP16_COMPRESS, etc.

        .. warning ::
            DDP communication hook can only be registered once and should be registered
            before calling backward.

        Example::
            Below is an example of a FP16 compression where gradients are
            compressed into 16-bit floating-point numbers before allreduce, and
            then decompressed after allreduce.

            >>> # xdoctest: +SKIP('undefined name')
            >>> ddp._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS)
        """
        assert self.logger is not None
        self.logger._set_comm_hook_name(str(comm_hook_type))
        dist._register_builtin_comm_hook(self.reducer, comm_hook_type)

    def _register_fused_optim(self, optim: Type, *args, optim_params=None, **kwargs):
        r"""
        Registers an optimizer with DDP such that the optimization for a
        parameter will run immediately when that parameter's gradient is
        finished with reduction, instead of waiting for all parameters'
        gradients to finish reduction. This can result in a training speedup
        depending on your workload since the optimizer can run while gradient
        reduction for other parameters is still ongoing. In addition, this has
        the potential to reduce peak memory consumption during training, as it
        only needs to load the per-parameter optimizer states of a single
        parameter at a time, instead of loading all per-parameter optimizer
        states at once.

        Args:
            optim (Type): a ``torch.optim.Optimizer`` class to be registered
                as a fused optimizer.
            *args (Sequence[Any]): Arguments to forward to ``optim``.
            optim_params (Optional[Iterable[torch.Tensor]]): Set of parameters
                to optimize, similar to the ``params`` argument of traditional
                ``torch.optim`` optimizers. If this is omitted, all DDP model
                parameters will be optimized.
            **kwargs (Dict[str, Any]): Keyword arguments to forward to ``optim``.

        .. warning ::
            _register_fused_optim should only be called once on a DDP instance,
            and registering multiple fused optimizers for the same DDP model
            is not currently supported. Please ping
            https://github.com/pytorch/pytorch/issues/71595 if this is necessary
            for your use case.

        .. warning ::
            _register_fused_optim and register_comm_hook currently do not
            compose together, meaning that custom DDP communication hooks are
            not supported with overlapped optimizers. Please ping
            https://github.com/pytorch/pytorch/issues/71595 if this is necessary
            for your use case.

        .. warning ::
            Gradient accumulation and DDP `no_sync` are currently not supported
            with overlapped optimizer. Please ping
            https://github.com/pytorch/pytorch/issues/71595 if this is necessary
            for your use case.

        Example::

            >>> # xdoctest: +SKIP("No rendezvous handler")
            >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
            >>> net = torch.nn.parallel.DistributedDataParallel(model, pg)
            >>> lr = 1e-2
            >>> betas = (0.9, 0.99)
            >>> eps = 1e-6
            >>> net._register_fused_optim(torch.optim.Adam, lr, betas=betas, eps=eps)
            >>> # Example with subset of parameters
            >>> params_to_opt = [list(net.parameters())[0]]
            >>> net._register_fused_optim(
            ...     torch.optim.Adam, lr, optim_params=params_to_opt, betas=betas, eps=eps
            ... )
        """
        # Note: importing in function, otherwise this will cause a circular
        # import as optimizer_overlap module needs to import DistributedDataParallel.
        from torch.distributed.algorithms._optimizer_overlap import _as_overlapped_optim

        overlapped_optim = _as_overlapped_optim(optim, optim_params, *args, **kwargs)
        try:
            overlapped_optim.register_ddp(self)
        except NotImplementedError as e:
            raise RuntimeError(
                f"{optim} does not support overlapped DDP. Please file an issue to PyTorch or the respective owner of {optim}."
            ) from e

    def _distributed_broadcast_coalesced(
        self, tensors, buffer_size, authoritative_rank=0
    ):
        dist._broadcast_coalesced(
            self.process_group, tensors, buffer_size, authoritative_rank
        )

    def _check_sync_bufs_post_fwd(self):
        return (
            self.will_sync_module_buffers()
            and hasattr(self, "buffer_hook")
            and self.buffer_hook.buffer_comm_hook_location
            == _BufferCommHookLocation.POST_FORWARD
        )

    def _check_sync_bufs_pre_fwd(self):
        return self.will_sync_module_buffers() and (
            not hasattr(self, "buffer_hook")
            or self.buffer_hook.buffer_comm_hook_location
            == _BufferCommHookLocation.PRE_FORWARD
        )

    def will_sync_module_buffers(self):
        return (
            self.require_forward_param_sync
            and self.broadcast_buffers
            and len(self.modules_buffers) > 0
        )

    def _find_common_rank(self, input_rank, rank_cond):
        # -1 indicates that this rank is not under consideration to be the
        # common_rank
        rank_to_use = torch.tensor(
            [input_rank if rank_cond else -1],
            device=self.device,
        )
        dist.all_reduce(rank_to_use, op=ReduceOp.MAX, group=self.process_group)
        if rank_to_use.item() == -1:
            self._log_and_throw(
                ValueError,
                "BUG! Expected rank_cond to be true for at least one process."
                " This indicates a bug in PyTorch, please report an issue.",
            )
        return rank_to_use.item()

    def _sync_buffers(self):
        with torch.no_grad():
            # Synchronize module buffers across processes.
            # The process with rank 0 is considered the authoritative copy.
            authoritative_rank = 0
            # Update self.modules_buffers in case any buffers were
            # reassigned.
            self._assign_modules_buffers()
            self._sync_module_buffers(authoritative_rank)

    def _sync_module_buffers(self, authoritative_rank):
        if not hasattr(self, "buffer_hook"):
            self._default_broadcast_coalesced(authoritative_rank=authoritative_rank)
        else:
            hook = self.buffer_hook.buffer_comm_hook
            state = self.buffer_hook.buffer_comm_hook_state
            futs = hook(state, self.named_module_buffers)
            if futs is not None:
                self.reducer._install_post_backward_futures(futs)

    def _default_broadcast_coalesced(
        self, bufs=None, bucket_size=None, authoritative_rank=0
    ):
        """
        Broadcasts buffers from rank 0 to the rest of the workers. If ``bufs`` or
        ``bucket_size`` is None, the defaults ``self.modules_buffers`` and
        ``self.broadcast_bucket_size`` are used instead.
        """
        if bufs is None:
            bufs = self.modules_buffers
        if bucket_size is None:
            bucket_size = self.broadcast_bucket_size

        self._distributed_broadcast_coalesced(bufs, bucket_size, authoritative_rank)

    def _passing_sync_batchnorm_handle(self, module):
        for layer in module.modules():
            if isinstance(layer, torch.nn.modules.SyncBatchNorm):
                if self.device_type == "cpu":
                    self._log_and_throw(
                        ValueError,
                        "SyncBatchNorm layers only work with GPU modules",
                    )

    def _check_comm_hook(self, hook):
        if not callable(hook):
            self._log_and_throw(TypeError, "Communication hook must be callable.")

        sig = inspect.signature(hook)
        if (
            sig.parameters["bucket"].annotation != inspect._empty
            and sig.parameters["bucket"].annotation != dist.GradBucket
        ):
            self._log_and_throw(
                ValueError,
                "Communication hook: bucket annotation should be dist.GradBucket.",
            )

        if (
            sig.return_annotation != inspect._empty
            and sig.return_annotation != torch.futures.Future[torch.Tensor]
        ):
            self._log_and_throw(
                ValueError,
                "Communication hook: return annotation should be torch.futures.Future[torch.Tensor].",
            )

        if hook.__name__ in ["bf16_compress_hook", "bf16_compress_wrapper_hook"] and (
            (torch.version.cuda is None and torch.version.hip is None)
            or (
                torch.version.cuda is not None
                and int(torch.version.cuda.split(".")[0]) < 11
            )
            or not dist.is_available()
            or not dist.is_nccl_available()
            or torch.cuda.nccl.version() < (2, 10)
        ):
            self._log_and_throw(
                TypeError,
                "BF16 all reduce communication hook requires CUDA 11+ and NCCL 2.10+.",
            )

    @property
    def _distributed_rank(self):
        return dist.get_rank(self.process_group)

    @staticmethod
    def _set_params_and_buffers_to_ignore_for_model(
        module, params_and_buffers_to_ignore
    ):
        """
        Sets parameters and buffers to be ignored by DDP. Expected format for
        parameters is the fully qualified name: {module_name}.{param_name}, and
        similarly, {module_name}.{buffer_name} for buffers. For example::

            params_to_ignore = []
            # NB: model here is vanilla PyTorch module, not yet wrapped with DDP.
            for module_name, module in model.named_modules():
                for param_name, param in module.named_parameters(recurse=False):
                    if should_ignore(param):
                        # Create expected format
                        fqn = f"{module_name}.{param_name}"
                        params_to_ignore.append(fqn)
            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
                model,
                params_to_ignore
            )
        """
        # This is a workaround to set parameters and buffers DDP should ignore
        # during synchronization. It will be removed when the API is finalized
        # as part of addressing https://github.com/pytorch/pytorch/issues/43690.
        module._ddp_params_and_buffers_to_ignore = params_and_buffers_to_ignore

    def _get_ddp_logging_data(self):
        r"""
        This interface can be called after DistributedDataParallel() is
        constructed. It returns a dictionary of logging data that can help
        with debugging and analysis. The logging data includes
        DistributedDataParallel constructor input parameters, some internal
        states of DistributedDataParallel, and performance metrics. Simply
        print the dictionary to see what these metrics are.
        This is a prototype interface and subject to change in the future.
        """
        assert self.logger is not None
        ddp_logging_data = self.logger._get_ddp_logging_data()
        return {**ddp_logging_data.strs_map, **ddp_logging_data.ints_map}
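
    # Quick way to inspect the logging data described above (illustration only;
    # assumes ``ddp`` is a constructed DistributedDataParallel instance):
    #
    #     for key, value in ddp._get_ddp_logging_data().items():
    #         print(f"{key}: {value}")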

    def _set_ddp_runtime_logging_sample_rate(self, sample_rate):
        r"""
        This interface allows users to set the sample rate for collecting
        runtime stats. Runtime stats are recorded for the first 10 iterations;
        after that, they are recorded once every ``sample_rate`` training
        iterations. The default sample rate is
        ``kDDPRuntimeLoggingSampleRate=100``.
        This is a prototype interface and subject to change in the future.
        """
        if sample_rate < 1:
            self._log_and_throw(
                ValueError,
                "DDP runtime logging sample rate should be equal or greater than 1",
            )
        self.reducer._set_ddp_runtime_logging_sample_rate(sample_rate)

    def _set_static_graph(self):
        """
        It is recommended to set static graph in the DDP constructor, which will
        call this private API internally.
        """
        # If self.static_graph has been set, no need to set it again
        if self.static_graph:
            warnings.warn(
                "You've set static_graph to be True, no need to set it again."
            )
            return
        self.static_graph = True
        self.reducer._set_static_graph()
        assert self.logger is not None
        self.logger._set_static_graph()
        if self.find_unused_parameters:
            warnings.warn(
                "You passed find_unused_parameters=True to DistributedDataParallel, "
                "but `_set_static_graph` will detect unused parameters automatically, "
                "so you do not need to set find_unused_parameters=True. Just be sure "
                "these unused parameters will not change during the training loop "
                "while calling `_set_static_graph`."
            )