distributed.py
  1. import copy
  2. import inspect
  3. import itertools
  4. import logging
  5. import os
  6. import sys
  7. import warnings
  8. import weakref
  9. from contextlib import contextmanager
  10. from dataclasses import dataclass, fields, is_dataclass
  11. from enum import Enum, auto
  12. from typing import Callable, Any, Type
  13. import torch
  14. import torch.distributed as dist
  15. from torch.autograd import Function, Variable
  16. from torch.distributed.algorithms.join import (
  17. Join,
  18. Joinable,
  19. JoinHook,
  20. )
  21. from torch.utils._pytree import tree_flatten, tree_unflatten
  22. RPC_AVAILABLE = False
  23. if dist.is_available():
  24. from torch.distributed.utils import (
  25. _verify_param_shape_across_processes,
  26. _sync_module_states,
  27. _to_kwargs,
  28. )
  29. from torch.distributed.distributed_c10d import ReduceOp, _get_default_group
  30. if torch.distributed.rpc.is_available():
  31. RPC_AVAILABLE = True
  32. from torch.distributed.rpc import RRef
  33. from torch._utils import _get_device_index
  34. from ..modules import Module
  35. from ._replicated_tensor_ddp_utils import _ddp_with_replicated_tensor_enabled
  36. from .scatter_gather import gather, scatter_kwargs # noqa: F401
  37. __all__ = ["DistributedDataParallel"]
  38. logger = logging.getLogger(__name__)
  39. def _tree_flatten_with_rref(output):
  40. output_is_rref = RPC_AVAILABLE and isinstance(output, RRef)
  41. if output_is_rref:
  42. output_tensor_list, treespec = tree_flatten(output.local_value())
  43. else:
  44. output_tensor_list, treespec = tree_flatten(output)
  45. # Need to return flattened tensors, spec to re-pack them, as well
  46. # as if the return type was actually an RRef to reconstruct.
  47. return output_tensor_list, treespec, output_is_rref
  48. def _tree_unflatten_with_rref(output, treespec, output_is_rref):
  49. output = tree_unflatten(output, treespec)
  50. if output_is_rref:
  51. output = RRef(output)
  52. return output
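# --- Illustrative sketch (not part of the upstream file) ---------------------
# A minimal example of the pytree round trip the two helpers above rely on:
# `tree_flatten` splits a nested container into a flat list of leaves plus a
# treespec, and `tree_unflatten` rebuilds the original structure from them.
# The function below is purely illustrative and is never called by DDP.
def _example_pytree_round_trip():
    import torch
    from torch.utils._pytree import tree_flatten, tree_unflatten
    output = {"logits": torch.randn(2, 4), "aux": [torch.randn(2), 3.0]}
    leaves, spec = tree_flatten(output)      # flat leaves + structure description
    rebuilt = tree_unflatten(leaves, spec)   # same nested structure as `output`
    assert torch.equal(rebuilt["logits"], output["logits"])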
  53. def _find_tensors(obj):
  54. r"""
  55. Recursively find all tensors contained in the specified object.
  56. """
  57. if RPC_AVAILABLE and isinstance(obj, RRef):
  58. # If the current node is the owner of the RRef, unwrap it and try to
  59. # find Tensors.
  60. # TODO: Expand to remote RRefs.
  61. if obj.is_owner():
  62. return _find_tensors(obj.local_value())
  63. if isinstance(obj, torch.Tensor):
  64. return [obj]
  65. if isinstance(obj, (list, tuple)):
  66. return itertools.chain(*map(_find_tensors, obj))
  67. if isinstance(obj, dict):
  68. return itertools.chain(*map(_find_tensors, obj.values()))
  69. if is_dataclass(obj):
  70. return itertools.chain(
  71. *map(_find_tensors, (getattr(obj, f.name) for f in fields(obj)))
  72. )
  73. return []
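# --- Illustrative sketch (not part of the upstream file) ---------------------
# `_find_tensors` walks lists, tuples, dicts and dataclasses and yields every
# tensor it finds; DDP uses it to discover which forward outputs participate
# in the autograd graph. A tiny, hypothetical usage:
def _example_find_tensors_usage():
    import torch
    nested = {"a": torch.ones(1), "b": [torch.zeros(2), "not a tensor"]}
    tensors = list(_find_tensors(nested))
    assert len(tensors) == 2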
  74. def _dump_DDP_relevant_env_vars():
  75. relevant_env_vars = [
  76. "RANK",
  77. "LOCAL_RANK",
  78. "WORLD_SIZE",
  79. "MASTER_PORT",
  80. "MASTER_ADDR",
  81. "CUDA_VISIBLE_DEVICES",
  82. "GLOO_SOCKET_IFNAME",
  83. "GLOO_DEVICE_TRANSPORT",
  84. "NCCL_SOCKET_IFNAME",
  85. "NCCL_BLOCKING_WAIT",
  86. "NCCL_DEBUG",
  87. "NCCL_DEBUG_SUBSYS",
  88. "NCCL_IB_DISABLE",
  89. # More NCCL env vars:
  90. "NCCL_P2P_DISABLE",
  91. "NCCL_P2P_LEVEL",
  92. "NCCL_SHM_DISABLE",
  93. "NCCL_SOCKET_NTHREADS",
  94. "NCCL_NSOCKS_PERTHREAD",
  95. "NCCL_BUFFSIZE",
  96. "NCCL_NTHREADS",
  97. "NCCL_RINGS",
  98. "NCCL_MAX_NCHANNELS",
  99. "NCCL_MIN_NCHANNELS",
  100. "NCCL_CHECKS_DISABLE",
  101. "NCCL_CHECK_POINTERS",
  102. "NCCL_LAUNCH_MODE",
  103. "NCCL_IB_HCA",
  104. "NCCL_IB_TIMEOUT",
  105. "NCCL_IB_RETRY_CNT",
  106. "NCCL_IB_GID_INDEX",
  107. "NCCL_IB_SL",
  108. "NCCL_IB_TC",
  109. "NCCL_IB_AR_THRESHOLD",
  110. "NCCL_IB_CUDA_SUPPORT",
  111. "NCCL_NET_GDR_LEVEL",
  112. "NCCL_NET_GDR_READ",
  113. "NCCL_SINGLE_RING_THRESHOLD",
  114. "NCCL_LL_THRESHOLD",
  115. "NCCL_TREE_THRESHOLD",
  116. "NCCL_ALGO",
  117. "NCCL_PROTO",
  118. "NCCL_IGNORE_CPU_AFFINITY",
  119. "NCCL_DEBUG_FILE",
  120. "NCCL_COLLNET_ENABLE",
  121. "NCCL_TOPO_FILE",
  122. "NCCL_TOPO_DUMP_FILE",
  123. "NCCL_ASYNC_ERROR_HANDLING",
  124. ]
  125. formatted_output = ""
  126. for var in relevant_env_vars:
  127. value = os.environ[var] if var in os.environ else "N/A"
  128. formatted_output += "env:%s=%s\n" % (var, value)
  129. print(formatted_output)
  130. class _BufferCommHookLocation(Enum):
  131. PRE_FORWARD = auto()
  132. POST_FORWARD = auto()
  133. @dataclass
  134. class _BufferCommHook:
  135. buffer_comm_hook: Callable
  136. buffer_comm_hook_state: Any
  137. buffer_comm_hook_location: _BufferCommHookLocation
138. # Add a DDPSink to run various functions when backward starts, such as
139. # queueing the callback of the outer-most backward/graph task;
140. # this helps ensure the callback is fired only after all gradients'
141. # computation is completed.
  142. class _DDPSink(Function):
  143. @staticmethod
  144. def forward(ctx, reducer, state_dict, *inputs):
  145. # set_materialize_grads(False) will ensure that None gradients stay as
  146. # None and are not filled with zeros.
  147. ctx.set_materialize_grads(False)
  148. ctx.reducer = reducer
  149. ctx.state_dict = state_dict
  150. ret = tuple(
  151. inp.clone() if isinstance(inp, torch.Tensor) else inp
  152. for inp in inputs
  153. )
  154. return ret
  155. @staticmethod
  156. def backward(ctx, *grad_outputs):
  157. state_dict = ctx.state_dict
  158. # Enqueue delay allreduce for static graph training on the first
  159. # iteration.
  160. if (
  161. ctx.state_dict["static_graph"]
  162. and ctx.state_dict["num_iterations"] == 1
  163. ):
  164. Variable._execution_engine.queue_callback( # type: ignore[call-arg,misc]
  165. ctx.reducer._delay_all_reduce
  166. )
  167. return (None, None, *grad_outputs)
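# --- Illustrative sketch (not part of the upstream file) ---------------------
# The autograd.Function pattern used by _DDPSink above, shown in isolation:
# `ctx.set_materialize_grads(False)` lets `backward` receive `None` for any
# output that did not take part in the loss, instead of a zero-filled tensor.
class _ExamplePassthroughSink(Function):
    @staticmethod
    def forward(ctx, *inputs):
        ctx.set_materialize_grads(False)
        return tuple(
            inp.clone() if isinstance(inp, torch.Tensor) else inp
            for inp in inputs
        )

    @staticmethod
    def backward(ctx, *grad_outputs):
        # grad_outputs[k] is None for every output unused in the loss.
        return grad_outputs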
  168. class _DDPJoinHook(JoinHook):
  169. def __init__(self, ddp, divide_by_initial_world_size):
  170. """
  171. Sets config variables for internal usage.
  172. """
  173. assert isinstance(ddp, DistributedDataParallel), (
  174. "DDP join hook requires passing in a DistributedDataParallel "
  175. "instance as the state"
  176. )
  177. assert ddp.logger is not None
  178. ddp.logger._set_uneven_input_join()
  179. self.ddp = ddp
  180. self.ddp._divide_by_initial_world_size = divide_by_initial_world_size
  181. super().__init__()
  182. def main_hook(self):
  183. """
  184. Shadows the DDP collective communication operations in the forward and
  185. backward passes.
  186. """
  187. ddp = self.ddp
  188. # Buckets are rebuilt only once during a training period
  189. ddp.reducer._rebuild_buckets()
  190. # Schedule a broadcast if we are syncing module buffers in the
  191. # forward pass
  192. # TODO: make DDP uneven inputs context manager support buffer
  193. # comm hook (https://github.com/pytorch/pytorch/issues/65436)
  194. ddp._check_and_sync_module_buffers()
  195. # Check if need to sync in the backward pass
  196. work = ddp._check_global_requires_backward_grad_sync(
  197. is_joined_rank=True
  198. )
  199. work.wait()
  200. should_sync_backwards = work.result()[0].item() != 0
  201. # Forward parameter sync is disabled in the next iteration if we
  202. # are skipping gradient sync this iteration, so set
  203. # `require_forward_param_sync` accordingly
  204. ddp.require_forward_param_sync = should_sync_backwards
  205. if not should_sync_backwards:
  206. return
  207. # Schedule one allreduce per gradient bucket to match the backward
  208. # pass allreduce
  209. ddp._match_all_reduce_for_bwd_pass()
  210. # Check if we need to allreduce locally unused parameters
  211. if ddp.find_unused_parameters:
  212. ddp._match_unused_params_allreduce()
  213. # Rebuilt parameters are pushed only once during a training period
  214. ddp.reducer._push_all_rebuilt_params()
  215. def post_hook(self, is_last_joiner: bool):
  216. """
  217. Syncs the final model to ensure that the model is the same across all
  218. processes.
  219. """
  220. self.ddp._sync_final_model(is_last_joiner)
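# --- Illustrative sketch (not part of the upstream file) ---------------------
# _DDPJoinHook backs DistributedDataParallel.join() (defined further below),
# the context manager for training with uneven inputs across ranks. A
# hypothetical training loop using it might look like this; `rank_local_loader`
# is a placeholder for a per-rank data loader of possibly different length.
def _example_uneven_inputs_training(ddp_model, rank_local_loader):
    with ddp_model.join():
        for batch in rank_local_loader:
            loss = ddp_model(batch).sum()
            loss.backward()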
  221. class DistributedDataParallel(Module, Joinable):
  222. r"""Implements distributed data parallelism that is based on
  223. ``torch.distributed`` package at the module level.
  224. This container provides data parallelism by synchronizing gradients
  225. across each model replica. The devices to synchronize across are
  226. specified by the input ``process_group``, which is the entire world
  227. by default. Note that ``DistributedDataParallel`` does not chunk or
  228. otherwise shard the input across participating GPUs; the user is
  229. responsible for defining how to do so, for example through the use
  230. of a :class:`DistributedSampler`.
  231. See also: :ref:`distributed-basics` and :ref:`cuda-nn-ddp-instead`.
  232. The same constraints on input as in :class:`torch.nn.DataParallel` apply.
233. Creation of this class requires ``torch.distributed`` to be already
  234. initialized, by calling :func:`torch.distributed.init_process_group`.
  235. ``DistributedDataParallel`` is proven to be significantly faster than
  236. :class:`torch.nn.DataParallel` for single-node multi-GPU data
  237. parallel training.
  238. To use ``DistributedDataParallel`` on a host with N GPUs, you should spawn
  239. up ``N`` processes, ensuring that each process exclusively works on a single
  240. GPU from 0 to N-1. This can be done by either setting
  241. ``CUDA_VISIBLE_DEVICES`` for every process or by calling:
  242. >>> # xdoctest: +SKIP("undefined variables")
  243. >>> torch.cuda.set_device(i)
244. where i is from 0 to N-1. In each process, you should refer to the following
  245. to construct this module:
  246. >>> # xdoctest: +SKIP("undefined variables")
  247. >>> torch.distributed.init_process_group(
  248. >>> backend='nccl', world_size=N, init_method='...'
  249. >>> )
  250. >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)
  251. In order to spawn up multiple processes per node, you can use either
  252. ``torch.distributed.launch`` or ``torch.multiprocessing.spawn``.
  253. .. note::
  254. Please refer to `PyTorch Distributed Overview <https://pytorch.org/tutorials/beginner/dist_overview.html>`__
  255. for a brief introduction to all features related to distributed training.
  256. .. note::
  257. ``DistributedDataParallel`` can be used in conjunction with
  258. :class:`torch.distributed.optim.ZeroRedundancyOptimizer` to reduce
  259. per-rank optimizer states memory footprint. Please refer to
  260. `ZeroRedundancyOptimizer recipe <https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html>`__
  261. for more details.
  262. .. note:: ``nccl`` backend is currently the fastest and highly recommended
  263. backend when using GPUs. This applies to both single-node and
  264. multi-node distributed training.
  265. .. note:: This module also supports mixed-precision distributed training.
266. This means that your model can have parameters of different types, such
267. as a mix of ``fp16`` and ``fp32``; gradient reduction on these
268. mixed types of parameters will work as expected.
  269. .. note:: If you use ``torch.save`` on one process to checkpoint the module,
  270. and ``torch.load`` on some other processes to recover it, make sure that
  271. ``map_location`` is configured properly for every process. Without
  272. ``map_location``, ``torch.load`` would recover the module to devices
  273. where the module was saved from.
  274. .. note:: When a model is trained on ``M`` nodes with ``batch=N``, the
  275. gradient will be ``M`` times smaller when compared to the same model
  276. trained on a single node with ``batch=M*N`` if the loss is summed (NOT
  277. averaged as usual) across instances in a batch (because the gradients
  278. between different nodes are averaged). You should take this into
  279. consideration when you want to obtain a mathematically equivalent
  280. training process compared to the local training counterpart. But in most
  281. cases, you can just treat a DistributedDataParallel wrapped model, a
  282. DataParallel wrapped model and an ordinary model on a single GPU as the
  283. same (E.g. using the same learning rate for equivalent batch size).
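As a purely illustrative arithmetic sketch: with ``M = 2`` nodes, a per-node
``batch = N``, and a summed (not averaged) loss, each rank computes a gradient
summed over its ``N`` local samples and DDP then averages those gradients
across the ``M`` ranks; a single process with ``batch = M*N`` would instead
sum over all ``M*N`` samples, producing a gradient ``M`` times larger.
Scaling the learning rate by ``M`` (or averaging the loss) restores the
mathematical equivalence described above.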
  284. .. note::
  285. Parameters are never broadcast between processes. The module performs
  286. an all-reduce step on gradients and assumes that they will be modified
  287. by the optimizer in all processes in the same way. Buffers
288. (e.g. BatchNorm stats) are broadcast from the module in the process of rank
  289. 0, to all other replicas in the system in every iteration.
  290. .. note::
  291. If you are using DistributedDataParallel in conjunction with the
  292. :ref:`distributed-rpc-framework`, you should always use
  293. :meth:`torch.distributed.autograd.backward` to compute gradients and
  294. :class:`torch.distributed.optim.DistributedOptimizer` for optimizing
  295. parameters.
  296. Example::
  297. >>> # xdoctest: +SKIP("undefined variables")
  298. >>> import torch.distributed.autograd as dist_autograd
  299. >>> from torch.nn.parallel import DistributedDataParallel as DDP
  300. >>> import torch
  301. >>> from torch import optim
  302. >>> from torch.distributed.optim import DistributedOptimizer
  303. >>> import torch.distributed.rpc as rpc
  304. >>> from torch.distributed.rpc import RRef
  305. >>>
  306. >>> t1 = torch.rand((3, 3), requires_grad=True)
  307. >>> t2 = torch.rand((3, 3), requires_grad=True)
  308. >>> rref = rpc.remote("worker1", torch.add, args=(t1, t2))
  309. >>> ddp_model = DDP(my_model)
  310. >>>
  311. >>> # Setup optimizer
  312. >>> optimizer_params = [rref]
  313. >>> for param in ddp_model.parameters():
  314. >>> optimizer_params.append(RRef(param))
  315. >>>
  316. >>> dist_optim = DistributedOptimizer(
  317. >>> optim.SGD,
  318. >>> optimizer_params,
  319. >>> lr=0.05,
  320. >>> )
  321. >>>
  322. >>> with dist_autograd.context() as context_id:
  323. >>> pred = ddp_model(rref.to_here())
  324. >>> loss = loss_func(pred, target)
  325. >>> dist_autograd.backward(context_id, [loss])
  326. >>> dist_optim.step(context_id)
  327. .. note::
  328. DistributedDataParallel currently offers limited support for gradient
  329. checkpointing with :meth:`torch.utils.checkpoint`. DDP will work as
  330. expected when there are no unused parameters in the model and each layer
  331. is checkpointed at most once (make sure you are not passing
  332. `find_unused_parameters=True` to DDP). We currently do not support the
333. case where a layer is checkpointed multiple times, or when there are unused
334. parameters in the checkpointed model.
  335. .. note::
  336. To let a non-DDP model load a state dict from a DDP model,
  337. :meth:`~torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`
  338. needs to be applied to strip the prefix "module." in the DDP state dict before loading.
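For illustration only (``CHECKPOINT_PATH`` and ``non_ddp_model`` below are
placeholder names, not part of any API), the stripping step might look like::
>>> # xdoctest: +SKIP("undefined variables")
>>> from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
>>> state_dict = torch.load(CHECKPOINT_PATH)
>>> consume_prefix_in_state_dict_if_present(state_dict, "module.")
>>> non_ddp_model.load_state_dict(state_dict)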
  339. .. warning::
  340. Constructor, forward method, and differentiation of the output (or a
  341. function of the output of this module) are distributed synchronization
  342. points. Take that into account in case different processes might be
  343. executing different code.
  344. .. warning::
  345. This module assumes all parameters are registered in the model by the
  346. time it is created. No parameters should be added nor removed later.
  347. Same applies to buffers.
  348. .. warning::
349. This module assumes that all parameters are registered in the model of each
350. distributed process in the same order. The module itself will
  351. conduct gradient ``allreduce`` following the reverse order of the
  352. registered parameters of the model. In other words, it is users'
  353. responsibility to ensure that each distributed process has the exact
  354. same model and thus the exact same parameter registration order.
  355. .. warning::
  356. This module allows parameters with non-rowmajor-contiguous strides.
  357. For example, your model may contain some parameters whose
  358. :class:`torch.memory_format` is ``torch.contiguous_format``
  359. and others whose format is ``torch.channels_last``. However,
  360. corresponding parameters in different processes must have the
  361. same strides.
  362. .. warning::
  363. This module doesn't work with :func:`torch.autograd.grad` (i.e. it will
  364. only work if gradients are to be accumulated in ``.grad`` attributes of
  365. parameters).
  366. .. warning::
  367. If you plan on using this module with a ``nccl`` backend or a ``gloo``
  368. backend (that uses Infiniband), together with a DataLoader that uses
  369. multiple workers, please change the multiprocessing start method to
  370. ``forkserver`` (Python 3 only) or ``spawn``. Unfortunately
  371. Gloo (that uses Infiniband) and NCCL2 are not fork safe, and you will
  372. likely experience deadlocks if you don't change this setting.
  373. .. warning::
  374. You should never try to change your model's parameters after wrapping
  375. up your model with ``DistributedDataParallel``. Because, when
  376. wrapping up your model with ``DistributedDataParallel``, the constructor
  377. of ``DistributedDataParallel`` will register the additional gradient
  378. reduction functions on all the parameters of the model itself at the
  379. time of construction. If you change the model's parameters afterwards,
  380. gradient reduction functions no longer match the correct set of
  381. parameters.
  382. .. warning::
  383. Using ``DistributedDataParallel`` in conjunction with the
  384. :ref:`distributed-rpc-framework` is experimental and subject to change.
  385. Args:
  386. module (Module): module to be parallelized
  387. device_ids (list of int or torch.device): CUDA devices.
  388. 1) For single-device modules, ``device_ids`` can
  389. contain exactly one device id, which represents the only
  390. CUDA device where the input module corresponding to this process resides.
  391. Alternatively, ``device_ids`` can also be ``None``.
  392. 2) For multi-device modules and CPU modules,
  393. ``device_ids`` must be ``None``.
  394. When ``device_ids`` is ``None`` for both cases,
  395. both the input data for the forward pass and the actual module
  396. must be placed on the correct device.
  397. (default: ``None``)
  398. output_device (int or torch.device): Device location of output for
  399. single-device CUDA modules. For multi-device modules and
  400. CPU modules, it must be ``None``, and the module itself
  401. dictates the output location. (default: ``device_ids[0]``
  402. for single-device modules)
  403. broadcast_buffers (bool): Flag that enables syncing (broadcasting)
  404. buffers of the module at beginning of the ``forward``
  405. function. (default: ``True``)
  406. process_group: The process group to be used for distributed data
  407. all-reduction. If ``None``, the default process group, which
  408. is created by :func:`torch.distributed.init_process_group`,
  409. will be used. (default: ``None``)
  410. bucket_cap_mb: ``DistributedDataParallel`` will bucket parameters into
  411. multiple buckets so that gradient reduction of each
  412. bucket can potentially overlap with backward computation.
  413. :attr:`bucket_cap_mb` controls the bucket size in
  414. MegaBytes (MB). (default: 25)
  415. find_unused_parameters (bool): Traverse the autograd graph from all
  416. tensors contained in the return value of the
  417. wrapped module's ``forward`` function. Parameters
  418. that don't receive gradients as part of this
  419. graph are preemptively marked as being ready to
  420. be reduced. In addition, parameters that may have
  421. been used in the wrapped module's ``forward``
  422. function but were not part of loss computation and
  423. thus would also not receive gradients are
  424. preemptively marked as ready to be reduced.
  425. (default: ``False``)
  426. check_reduction: This argument is deprecated.
  427. gradient_as_bucket_view (bool): When set to ``True``, gradients will be views
  428. pointing to different offsets of ``allreduce`` communication
  429. buckets. This can reduce peak memory usage, where the
  430. saved memory size will be equal to the total gradients
  431. size. Moreover, it avoids the overhead of copying between
  432. gradients and ``allreduce`` communication buckets. When
  433. gradients are views, ``detach_()`` cannot be called on the
  434. gradients. If hitting such errors, please fix it by
  435. referring to the :meth:`~torch.optim.Optimizer.zero_grad`
  436. function in ``torch/optim/optimizer.py`` as a solution.
  437. Note that gradients will be views after first iteration, so
  438. the peak memory saving should be checked after first iteration.
  439. static_graph (bool): When set to ``True``, DDP knows the trained graph is
  440. static. Static graph means 1) The set of used and unused
  441. parameters will not change during the whole training loop; in
  442. this case, it does not matter whether users set
  443. ``find_unused_parameters = True`` or not. 2) How the graph is trained
  444. will not change during the whole training loop (meaning there is
  445. no control flow depending on iterations).
446. When static_graph is set to be ``True``, DDP will support cases that
447. could not be supported in the past:
  448. 1) Reentrant backwards.
  449. 2) Activation checkpointing multiple times.
  450. 3) Activation checkpointing when model has unused parameters.
  451. 4) There are model parameters that are outside of forward function.
452. 5) Potentially improved performance when there are unused parameters,
453. as DDP will not search the graph in each iteration to detect unused
454. parameters when static_graph is set to be ``True``.
455. To check whether you can set static_graph to be ``True``, one way is to
456. check the ddp logging data at the end of your previous model training;
457. if ``ddp_logging_data.get("can_set_static_graph") == True``, you can
458. mostly set ``static_graph = True`` as well.
  459. Example::
  460. >>> # xdoctest: +SKIP("undefined variables")
  461. >>> model_DDP = torch.nn.parallel.DistributedDataParallel(model)
  462. >>> # Training loop
  463. >>> ...
  464. >>> ddp_logging_data = model_DDP._get_ddp_logging_data()
  465. >>> static_graph = ddp_logging_data.get("can_set_static_graph")
  466. Attributes:
  467. module (Module): the module to be parallelized.
  468. Example::
  469. >>> # xdoctest: +SKIP("undefined variables")
  470. >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
  471. >>> net = torch.nn.parallel.DistributedDataParallel(model)
  472. """
  473. # used to track whether the given thread is inside ddp forward for torchdynamo purposes
  474. _active_ddp_module = None
  475. def __init__(
  476. self,
  477. module,
  478. device_ids=None,
  479. output_device=None,
  480. dim=0,
  481. broadcast_buffers=True,
  482. process_group=None,
  483. bucket_cap_mb=25,
  484. find_unused_parameters=False,
  485. check_reduction=False,
  486. gradient_as_bucket_view=False,
  487. static_graph=False,
  488. ):
  489. super().__init__()
  490. Joinable.__init__(self)
  491. self.logger = None
  492. if hasattr(module, "_ddp_params_and_buffers_to_ignore"):
  493. self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore)
  494. else:
  495. self.parameters_to_ignore = set()
  496. self._module_parameters = [p for n, p in module.named_parameters() if n not in self.parameters_to_ignore]
  497. if not any((p.requires_grad for p in self._module_parameters)):
  498. self._log_and_throw(
  499. RuntimeError,
  500. "DistributedDataParallel is not needed when a module "
  501. "doesn't have any parameter that requires a gradient.",
  502. )
  503. if device_ids is not None and len(device_ids) > 1:
  504. self._log_and_throw(
  505. ValueError,
  506. "device_ids can only be None or contain a single element.",
  507. )
  508. self.is_multi_device_module = len({p.device for p in self._module_parameters}) > 1
  509. distinct_device_types = {p.device.type for p in self._module_parameters if p.device is not None}
  510. if len(distinct_device_types) != 1:
  511. self._log_and_throw(
  512. ValueError,
  513. "DistributedDataParallel's input module must be on "
  514. "the same type of devices, but input module parameters locate in {}.".format(
  515. distinct_device_types
  516. ),
  517. )
  518. self.device_type = list(distinct_device_types)[0]
  519. if (
  520. device_ids is None
  521. or len(device_ids) == 0 # For backward compatibility.
  522. or self.device_type == "cpu"
  523. or self.is_multi_device_module
  524. ):
  525. if device_ids or output_device:
  526. self._log_and_throw(
  527. ValueError,
  528. "DistributedDataParallel device_ids and output_device arguments "
  529. "only work with single-device/multiple-device GPU modules or CPU modules, "
  530. "but got device_ids {}, output_device {}, and module parameters {}.".format(
  531. device_ids,
  532. output_device,
  533. {p.device for p in self._module_parameters},
  534. ),
  535. )
  536. self.device_ids = None
  537. self.output_device = None
  538. else:
  539. self.device_ids = [_get_device_index(x, True) for x in device_ids]
  540. if output_device is None:
  541. output_device = device_ids[0]
  542. self.output_device = _get_device_index(output_device, True)
  543. if process_group is None:
  544. self.process_group = _get_default_group()
  545. else:
  546. self.process_group = process_group
  547. self.static_graph = False
  548. self.dim = dim
  549. self.module = module
  550. self.device = list(self._module_parameters)[0].device
  551. self.broadcast_buffers = broadcast_buffers
  552. self.find_unused_parameters = find_unused_parameters
  553. self.require_backward_grad_sync = True
  554. self.require_forward_param_sync = True
  555. self.gradient_as_bucket_view = gradient_as_bucket_view
  556. self._use_replicated_tensor_module = (
  557. _ddp_with_replicated_tensor_enabled()
  558. )
  559. self._build_replicated_tensor_module()
  560. if check_reduction:
  561. # This argument is no longer used since the reducer
  562. # will ensure reduction completes even if some parameters
  563. # do not receive gradients.
  564. warnings.warn(
  565. "The `check_reduction` argument in `DistributedDataParallel` "
  566. "module is deprecated. Please avoid using it."
  567. )
  568. # Check that a module does not have Uninitialized parameters
  569. for param in self._module_parameters:
  570. if isinstance(param, torch.nn.parameter.UninitializedParameter):
  571. self._log_and_throw(
  572. RuntimeError,
  573. "Modules with uninitialized parameters can't be used with `DistributedDataParallel`. "
  574. "Run a dummy forward pass to correctly initialize the modules",
  575. )
  576. # used for intra-node param sync and inter-node sync as well
  577. self.broadcast_bucket_size = int(250 * 1024 * 1024)
  578. # reduction bucket size
  579. self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
  580. # Whether to perform input tensor CPU to GPU copies on a side-stream
  581. self.use_side_stream_for_tensor_copies = (
  582. os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1"
  583. )
  584. # Build parameters for reducer.
  585. parameters, expect_sparse_gradient = self._build_params_for_reducer()
  586. # Verify model equivalence.
  587. _verify_param_shape_across_processes(self.process_group, parameters)
  588. # Sync params and buffers. Ensures all DDP models start off at the same value.
  589. _sync_module_states(
  590. module=self.module,
  591. process_group=self.process_group,
  592. broadcast_bucket_size=self.broadcast_bucket_size,
  593. src=0,
  594. params_and_buffers_to_ignore=self.parameters_to_ignore,
  595. )
596. # In debug mode, build a mapping of parameter index -> parameter name (FQN).
  597. param_to_name_mapping = self._build_debug_param_to_name_mapping(
  598. parameters
  599. )
  600. # Builds reducer.
  601. self._ddp_init_helper(
  602. parameters,
  603. expect_sparse_gradient,
  604. param_to_name_mapping,
  605. static_graph,
  606. )
  607. self._has_rebuilt_buckets = False
  608. if static_graph:
  609. self._set_static_graph()
  610. self._setup_in_backward_optimizers()
  611. def _setup_in_backward_optimizers(self):
  612. # Check if user has used apply_optim_in_backward to overlap optimizer
  613. # step + DDP backward. Current constraints:
  614. # 1. Only allreduce is supported at the moment, no custom communication.
  615. # 2. The reducer by default sets all grads for parameters DDP manages to
  616. # None after they have been applied by the optimizer. There is no support
  617. # for setting only some parameter grads to None, this must be done manually
  618. # by user (and DDP_OVERLAPPED_OPTIM_SET_GRADS_TO_NONE=0 needs to be set.)
  619. # If your use case requires some DDP managed parameters to run with
  620. # an in-backward optimizer and some with a traditional optimizer, please
  621. # ping https://github.com/pytorch/pytorch/issues/90052.
  622. # NOTE: we use self._module_parameters instead of .parameters() since
  623. # the former excludes ignored (non-DDP managed) parameters.
  624. if any(
  625. hasattr(p, '_in_backward_optimizers') for p in self._module_parameters
  626. ):
  627. # Remove hooks that apply_optim_in_backward had registered because
  628. # DDP customizes how optimizer is overlapped with backward due to
  629. # the allreduce.
  630. param_to_handle_map = dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map
  631. for p in self._module_parameters:
  632. for handle in param_to_handle_map.get(p, []):
  633. handle.remove()
  634. # Need a weakref to the reducer in order to run all_reduce.
  635. reducer_weakref = weakref.ref(self.reducer)
  636. # Note: importing in function, otherwise this will cause a circular
  637. # import.
  638. from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
  639. _apply_optim_in_backward_hook
  640. )
  641. self.register_comm_hook(
  642. (reducer_weakref, self.process_group),
  643. _apply_optim_in_backward_hook(
  644. gradient_is_bucket_view=self.gradient_as_bucket_view
  645. ),
  646. )
647. # TODO (rohan-varma): this is a workaround that allows users to
648. # disable the default behavior of DDP-managed parameters with an
649. # optimizer running in backward having their gradients all set to None.
650. # Currently, it is an "all or nothing" behavior where DDP will set
651. # no grads to None or all of them; relaxing this behavior will be
652. # done depending on use cases.
  653. if os.getenv("DDP_OVERLAPPED_OPTIM_SET_GRADS_TO_NONE", "1") != "0":
  654. warnings.warn(
  655. "DDP + apply_optim_in_backward will currently set all "
  656. "parameter gradients to None. If this is not the desired "
  657. "behavior, please set env variable "
  658. "DDP_OVERLAPPED_OPTIM_SET_GRADS_TO_NONE=0, and manually set"
  659. "gradients to None/zero as desired."
  660. )
  661. self.reducer._set_grads_to_none() # type: ignore[attr-defined]
  662. def _build_replicated_tensor_module(self):
  663. if self._use_replicated_tensor_module:
  664. # Create a module with ReplicatedTensor without copying tensors. Avoid
  665. # registering '_replicated_tensor_module' as a submodule by directly
  666. # adding to self.__dict__.
  667. from ._replicated_tensor_ddp_interop import _replicate_module
  668. self.__dict__["_replicated_tensor_module"] = _replicate_module(
  669. self.module, self.process_group
  670. )
  671. def _log_and_throw(self, err_type, err_msg):
  672. if self.logger is not None:
  673. self.logger.set_error_and_log(f"{str(err_type)}: {err_msg}")
  674. raise err_type(err_msg)
  675. def _ddp_init_helper(
  676. self,
  677. parameters,
  678. expect_sparse_gradient,
  679. param_to_name_mapping,
  680. static_graph,
  681. ):
  682. """
  683. Initialization helper function that does the following:
  684. (1) bucketing the parameters for reductions
  685. (2) resetting the bucketing states
  686. (3) registering the grad hooks
  687. (4) Logging construction-time DDP logging data
  688. (5) passing a handle of DDP to SyncBatchNorm Layer
  689. """
  690. self.num_iterations = 0
691. # Note: the parameter order is not the order in which they are actually used,
692. # especially in models with control flow.
693. #
694. # Because parameters are not presented in their real execution order,
695. # if a certain model happens to also
696. # 1) have other collective comm ops in its backward graph, and
697. # 2) have unused parameters on a subset of ranks in the world,
698. # bucketing could insert an ALL-REDUCE comm op too early on the ranks with unused parameters,
699. # unexpectedly matching up with other collective comm ops on other ranks.
700. #
701. # In order to handle this corner case, when the parameters are not in the real execution order,
702. # we don't do bucketing, so only one ALL-REDUCE is inserted after all the gradients
703. # of the whole graph are computed.
704. #
705. # Note that we only disable bucketing for the first iteration.
706. # After the first iteration, it's OK to rebuild buckets,
707. # because "bucket rebuild" bucketizes parameters based on their real execution order in the backward graph.
  708. # Can remove this branching once #73732 is landed.
  709. if static_graph is True or self.find_unused_parameters is False:
  710. bucket_size_limits = [sys.maxsize]
  711. else:
  712. bucket_size_limits = [
  713. dist._DEFAULT_FIRST_BUCKET_BYTES,
  714. self.bucket_bytes_cap,
  715. ]
  716. (
  717. bucket_indices,
  718. per_bucket_size_limits,
  719. ) = dist._compute_bucket_assignment_by_size(
  720. parameters,
  721. bucket_size_limits,
  722. expect_sparse_gradient,
  723. )
  724. # Note: reverse list of buckets because we want to approximate the
  725. # order in which their gradients are produced, and assume they
  726. # are used in the forward pass in the order they are defined.
  727. self.reducer = dist.Reducer(
  728. parameters,
  729. list(reversed(bucket_indices)),
  730. list(reversed(per_bucket_size_limits)),
  731. self.process_group,
  732. expect_sparse_gradient,
  733. # The bucket size limit is specified in the constructor.
  734. # Additionally, we allow for a single small bucket for parameters
  735. # that are defined first, such that their gradients don't spill into
  736. # a much larger bucket, adding unnecessary latency after gradient
  737. # computation finishes. Experiments showed 1MB is a reasonable value.
  738. self.bucket_bytes_cap,
  739. self.find_unused_parameters,
  740. self.gradient_as_bucket_view,
  741. param_to_name_mapping,
  742. # User can set dist._DEFAULT_FIRST_BUCKET_BYTES to tune DDP first
  743. # bucket.
  744. dist._DEFAULT_FIRST_BUCKET_BYTES,
  745. )
  746. self.logger = dist.Logger(self.reducer)
  747. # Set as a weak reference to avoid reference cycle between
  748. # logger and reducer.
  749. self.reducer.set_logger(self.logger)
  750. has_sync_bn = False
  751. for submodule in self.module.modules():
  752. if isinstance(submodule, torch.nn.SyncBatchNorm):
  753. has_sync_bn = True
  754. break
755. # Set logging data that can be obtained at construction time.
  756. self.logger.set_construction_data_and_log(
  757. self.module.__class__.__name__,
  758. [] if self.device_ids is None else self.device_ids,
  759. -1 if self.output_device is None else self.output_device,
  760. self.broadcast_buffers,
  761. has_sync_bn,
  762. static_graph,
  763. )
  764. # passing a handle to torch.nn.SyncBatchNorm layer
  765. self._passing_sync_batchnorm_handle(self.module)
  766. def __getstate__(self):
  767. self._check_default_group()
  768. attrs = copy.copy(self.__dict__)
  769. del attrs["process_group"]
  770. del attrs["reducer"]
  771. del attrs["logger"]
  772. if self._use_replicated_tensor_module:
  773. del attrs["_replicated_tensor_module"]
  774. return attrs
  775. def __setstate__(self, state):
  776. # If serializable, then the process group should be the default one
  777. self.process_group = _get_default_group()
  778. super().__setstate__(state)
  779. self._build_replicated_tensor_module()
  780. self.__dict__.setdefault("require_forward_param_sync", True)
  781. self.__dict__.setdefault("require_backward_grad_sync", True)
  782. parameters, expect_sparse_gradient = self._build_params_for_reducer()
783. # In debug mode, build a mapping of parameter index -> parameter name (FQN).
  784. param_to_name_mapping = self._build_debug_param_to_name_mapping(
  785. parameters
  786. )
  787. # Builds reducer.
  788. self._ddp_init_helper(
  789. parameters,
  790. expect_sparse_gradient,
  791. param_to_name_mapping,
  792. self.static_graph,
  793. )
  794. if self.static_graph:
  795. self.reducer._set_static_graph()
  796. assert self.logger is not None
  797. self.logger._set_static_graph()
  798. def _build_params_for_reducer(self):
  799. # Build tuple of (module, parameter) for all parameters that require grads.
  800. modules_and_parameters = [
  801. (module, parameter)
  802. for module_name, module in self.module.named_modules()
  803. for parameter in [
  804. param
  805. # Note that we access module.named_parameters instead of
  806. # parameters(module). parameters(module) is only needed in the
  807. # single-process multi device case, where it accesses replicated
  808. # parameters through _former_parameters.
  809. for param_name, param in module.named_parameters(recurse=False)
  810. if param.requires_grad
  811. and f"{module_name}.{param_name}"
  812. not in self.parameters_to_ignore
  813. ]
  814. ]
  815. # Deduplicate any parameters that might be shared across child modules.
  816. memo = set()
  817. modules_and_parameters = [
  818. # "p not in memo" is the deduplication check.
  819. # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed.
  820. (m, p)
  821. for m, p in modules_and_parameters
  822. if p not in memo and not memo.add(p) # type: ignore[func-returns-value]
  823. ]
  824. # Build list of parameters.
  825. parameters = [parameter for _, parameter in modules_and_parameters]
  826. # Checks if a module will produce a sparse gradient.
  827. def produces_sparse_gradient(module):
  828. if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)):
  829. return module.sparse
  830. return False
  831. # Build list of booleans indicating whether or not to expect sparse
  832. # gradients for the corresponding parameters.
  833. expect_sparse_gradient = [
  834. produces_sparse_gradient(module)
  835. for module, _ in modules_and_parameters
  836. ]
  837. self._assign_modules_buffers()
  838. return parameters, expect_sparse_gradient
  839. def _assign_modules_buffers(self):
  840. """
  841. Assigns module buffers to self.modules_buffers which are then used to
  842. broadcast across ranks when broadcast_buffers=True. Note that this
  843. must be called every time buffers need to be synced because buffers can
  844. be reassigned by user module,
  845. see https://github.com/pytorch/pytorch/issues/63916.
  846. """
  847. # Collect buffers for modules, filtering out buffers that should be ignored.
  848. named_module_buffers = [
  849. (buffer, buffer_name)
  850. for buffer_name, buffer in self.module.named_buffers()
  851. if buffer_name not in self.parameters_to_ignore
  852. ]
  853. self.modules_buffers = [
  854. buffer for (buffer, buffer_name) in named_module_buffers
  855. ]
  856. # Dict[str, tensor] representing module buffers not ignored by DDP.
  857. self.named_module_buffers = {
  858. buffer_name: buffer
  859. for (buffer, buffer_name) in named_module_buffers
  860. }
  861. def _build_debug_param_to_name_mapping(self, parameters):
  862. if dist.get_debug_level() == dist.DebugLevel.OFF:
  863. return {}
  864. param_to_param_index = {
  865. parameters[i]: i for i in range(len(parameters))
  866. }
  867. param_set = set(parameters)
  868. param_index_to_param_fqn = {}
  869. for module_name, module in self.module.named_modules():
  870. for param_name, param in module.named_parameters(recurse=False):
  871. fqn = f"{module_name}.{param_name}"
  872. # Bypass ignored parameters since those are not reduced by DDP
  873. # to begin with.
  874. if fqn not in self.parameters_to_ignore and param.requires_grad:
  875. if param not in param_set:
  876. self._log_and_throw(
  877. ValueError,
  878. f"Param with name {fqn} found in module parameters, but not DDP parameters."
  879. " This indicates a bug in DDP, please report an issue to PyTorch.",
  880. )
  881. param_index = param_to_param_index[param]
  882. param_index_to_param_fqn[param_index] = fqn
  883. # Ensure we covered all parameters
  884. if len(param_set) != len(param_index_to_param_fqn):
  885. self._log_and_throw(
  886. ValueError,
  887. (
  888. "Expected param to name mapping to cover all parameters, but"
  889. f" got conflicting lengths: {len(param_set)} vs "
  890. f"{len(param_index_to_param_fqn)}. This indicates a bug in DDP"
  891. ", please report an issue to PyTorch."
  892. ),
  893. )
  894. return param_index_to_param_fqn
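# --- Illustrative note (not part of the upstream file) -----------------------
# The mapping above is only populated when distributed debug mode is enabled,
# e.g. by launching the job with the environment variable
#     TORCH_DISTRIBUTED_DEBUG=INFO    (or DETAIL)
# so that dist.get_debug_level() returns something other than DebugLevel.OFF.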
  895. def _get_parameters(self, m, recurse=True):
  896. """
  897. Returns a generator of module parameters
  898. """
  899. def model_parameters(m):
  900. ps = (
  901. m._former_parameters.values()
  902. if hasattr(m, "_former_parameters")
  903. else m.parameters(recurse=False)
  904. )
  905. yield from ps
  906. for m in m.modules() if recurse else [m]:
  907. for p in model_parameters(m):
  908. yield p
  909. def _check_default_group(self):
  910. pickle_not_supported = False
  911. try:
  912. if self.process_group != _get_default_group():
  913. pickle_not_supported = True
  914. except RuntimeError:
  915. pickle_not_supported = True
  916. if pickle_not_supported:
  917. self._log_and_throw(
  918. RuntimeError,
  919. "DDP Pickling/Unpickling are only supported "
  920. "when using DDP with the default process "
  921. "group. That is, when you have called "
  922. "init_process_group and have not passed "
  923. "process_group argument to DDP constructor",
  924. )
  925. @contextmanager
  926. def no_sync(self):
  927. r"""
  928. A context manager to disable gradient synchronizations across DDP
  929. processes. Within this context, gradients will be accumulated on module
930. variables, which will later be synchronized in the first
931. forward-backward pass after exiting the context.
  932. Example::
  933. >>> # xdoctest: +SKIP("undefined variables")
  934. >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
  935. >>> with ddp.no_sync():
  936. >>> for input in inputs:
  937. >>> ddp(input).backward() # no synchronization, accumulate grads
  938. >>> ddp(another_input).backward() # synchronize grads
  939. .. warning::
  940. The forward pass should be included inside the context manager, or
  941. else gradients will still be synchronized.
  942. """
  943. old_require_backward_grad_sync = self.require_backward_grad_sync
  944. self.require_backward_grad_sync = False
  945. try:
  946. yield
  947. finally:
  948. self.require_backward_grad_sync = old_require_backward_grad_sync
  949. @classmethod
  950. def _get_active_ddp_module(cls):
  951. """
  952. TorchDynamo needs to know whether DDP is currently active, and access the DDP module in order to cooperatively optimize it.
  953. """
  954. return cls._active_ddp_module
  955. # note, this ctxmgr function is marked 'skip' in torchdynamo, so dynamo only kicks in
  956. # for the 'module_to_run' underneath
  957. # see torch._dynamo/eval_frame.py TorchPatcher.patch for more details
  958. @contextmanager
  959. def _inside_ddp_forward(self):
  960. DistributedDataParallel._active_ddp_module = self
  961. try:
  962. yield
  963. except Exception:
  964. raise
  965. finally:
  966. DistributedDataParallel._active_ddp_module = None
  967. def _run_ddp_forward(self, *inputs, **kwargs):
  968. module_to_run = (
  969. self._replicated_tensor_module
  970. if self._use_replicated_tensor_module
  971. else self.module
  972. )
  973. if self.device_ids:
  974. inputs, kwargs = _to_kwargs(
  975. inputs,
  976. kwargs,
  977. self.device_ids[0],
  978. self.use_side_stream_for_tensor_copies,
  979. )
  980. with self._inside_ddp_forward():
  981. return module_to_run(*inputs[0], **kwargs[0]) # type: ignore[index]
  982. else:
  983. with self._inside_ddp_forward():
  984. return module_to_run(*inputs, **kwargs)
  985. def forward(self, *inputs, **kwargs):
  986. with torch.autograd.profiler.record_function(
  987. "DistributedDataParallel.forward"
  988. ):
  989. if torch.is_grad_enabled() and self.require_backward_grad_sync:
  990. assert self.logger is not None
  991. self.logger.set_runtime_stats_and_log()
  992. self.num_iterations += 1
  993. self.reducer.prepare_for_forward()
  994. # Notify the join context that this process has not joined, if
  995. # needed
  996. work = Join.notify_join_context(self)
  997. if work:
  998. self.reducer._set_forward_pass_work_handle(
  999. work, self._divide_by_initial_world_size # type: ignore[arg-type]
  1000. )
1001. # Calling _rebuild_buckets before forward computation;
1002. # it may allocate new buckets before deallocating old buckets
1003. # inside _rebuild_buckets. To save peak memory usage,
1004. # call _rebuild_buckets before the peak memory usage increases
1005. # during forward computation.
1006. # This should be called only once during the whole training period.
  1007. if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
  1008. logger.info(
  1009. "Reducer buckets have been rebuilt in this iteration."
  1010. )
  1011. self._has_rebuilt_buckets = True
1012. # sync params according to the location (before/after forward) the user
1013. # specified as part of the hook, if a hook was specified.
  1014. if self._check_sync_bufs_pre_fwd():
  1015. self._sync_buffers()
  1016. if self._join_config.enable:
  1017. # Notify joined ranks whether they should sync in backwards pass or not.
  1018. self._check_global_requires_backward_grad_sync(
  1019. is_joined_rank=False
  1020. )
  1021. output = self._run_ddp_forward(*inputs, **kwargs)
1022. # sync params according to the location (before/after forward) the user
1023. # specified as part of the hook, if a hook was specified.
  1024. if self._check_sync_bufs_post_fwd():
  1025. self._sync_buffers()
  1026. if torch.is_grad_enabled() and self.require_backward_grad_sync:
  1027. self.require_forward_param_sync = True
  1028. # We'll return the output object verbatim since it is a freeform
  1029. # object. We need to find any tensors in this object, though,
  1030. # because we need to figure out which parameters were used during
  1031. # this forward pass, to ensure we short circuit reduction for any
  1032. # unused parameters. Only if `find_unused_parameters` is set.
  1033. if self.find_unused_parameters and not self.static_graph:
  1034. # Do not need to populate this for static graph.
  1035. self.reducer.prepare_for_backward(
  1036. list(_find_tensors(output))
  1037. )
  1038. else:
  1039. self.reducer.prepare_for_backward([])
  1040. else:
  1041. self.require_forward_param_sync = False
  1042. # TODO: DDPSink is currently enabled for unused parameter detection and
  1043. # static graph training for first iteration.
  1044. if (self.find_unused_parameters and not self.static_graph) or (
  1045. self.static_graph and self.num_iterations == 1
  1046. ):
  1047. state_dict = {
  1048. "static_graph": self.static_graph,
  1049. "num_iterations": self.num_iterations,
  1050. }
  1051. (
  1052. output_tensor_list,
  1053. treespec,
  1054. output_is_rref,
  1055. ) = _tree_flatten_with_rref(output)
  1056. output_placeholders = [None for _ in range(len(output_tensor_list))]
  1057. # Do not touch tensors that have no grad_fn, which can cause issues
  1058. # such as https://github.com/pytorch/pytorch/issues/60733
  1059. for i, output in enumerate(output_tensor_list):
  1060. if torch.is_tensor(output) and output.grad_fn is None:
  1061. output_placeholders[i] = output
1062. # When find_unused_parameters=True, this makes tensors which require grad
1063. # run through the DDPSink backward pass. When not all outputs are
1064. # used in the loss, the corresponding tensors receive an
1065. # undefined gradient, which the reducer then handles to ensure the
1066. # param.grad field is not touched and we don't error out.
  1067. passthrough_tensor_list = _DDPSink.apply(
  1068. self.reducer,
  1069. state_dict,
  1070. *output_tensor_list,
  1071. )
  1072. for i in range(len(output_placeholders)):
  1073. if output_placeholders[i] is None:
  1074. output_placeholders[i] = passthrough_tensor_list[i]
  1075. # Reconstruct output data structure.
  1076. output = _tree_unflatten_with_rref(
  1077. output_placeholders, treespec, output_is_rref
  1078. )
  1079. return output
  1080. def scatter(self, inputs, kwargs, device_ids):
  1081. return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
  1082. def to_kwargs(self, inputs, kwargs, device_id):
  1083. # Kept for BC
  1084. return _to_kwargs(
  1085. inputs, kwargs, device_id, self.use_side_stream_for_tensor_copies
  1086. )
  1087. def gather(self, outputs, output_device):
  1088. return gather(outputs, output_device, dim=self.dim)
  1089. def train(self, mode=True):
  1090. super().train(mode)
  1091. if self._use_replicated_tensor_module:
  1092. self._replicated_tensor_module.train(mode) # type: ignore[union-attr]
  1093. return self

    # When running in join mode, schedules an allreduce to notify joined ranks
    # of whether backwards pass synchronization will run this iteration or not.
    def _check_global_requires_backward_grad_sync(self, is_joined_rank):
        if not is_joined_rank and self.require_backward_grad_sync:
            requires_sync_tensor = torch.ones(1, device=self.device)
        else:
            requires_sync_tensor = torch.zeros(1, device=self.device)

        work = dist.all_reduce(
            requires_sync_tensor, group=self.process_group, async_op=True
        )
        return work

    # When running in join mode, checks and performs sync of module buffers if
    # the models have buffers that should be synchronized in the forward pass.
    def _check_and_sync_module_buffers(self):
        if self._check_sync_bufs_pre_fwd():
            authoritative_rank = self._find_common_rank(
                self._distributed_rank, False
            )
            self._sync_module_buffers(authoritative_rank)

    # When running in join mode, agrees upon a common rank and broadcasts model
    # parameters to all other ranks.
    def _sync_final_model(self, is_last_joiner):
        # Agree upon the process that will be the authoritative model copy.
        # The current rank is a candidate for being the authoritative copy if
        # is_last_joiner=True. We break ties via picking the larger rank.
        self._authoritative_rank = self._find_common_rank(
            self._distributed_rank, is_last_joiner
        )
        _sync_module_states(
            module=self.module,
            process_group=self.process_group,
            broadcast_bucket_size=self.broadcast_bucket_size,
            src=self._authoritative_rank,
            params_and_buffers_to_ignore=self.parameters_to_ignore,
        )

    # Schedule comm ops to match those scheduled in the reducer's backward
    # pass.
    def _match_all_reduce_for_bwd_pass(self):
        comm_work = []
        # Schedule comm in the same order as Reducer schedules them, i.e.
        # the order of the buckets. Retrieving the bucket order from the reducer
        # ensures that we keep the same order in join mode, such as when bucket
        # order is rebuilt dynamically.
        # Returns grad_buckets in order, but real tensors are substituted with
        # zero tensors of the same shape.
        grad_buckets = self.reducer._get_zeros_like_grad_buckets()
        for grad_bucket in grad_buckets:
            # Joined processes contribute zero gradient. In the case that
            # divide_by_initial_world_size=True, we divide grads by the static
            # world size, if not, the dividing factor is reduced by the number
            # of joined processes.
            work = self.reducer._run_comm_hook(grad_bucket)
            comm_work.append(work)
        for work in comm_work:
            work.wait()

    # Allreduces the used parameter mapping across ranks.
    def _match_unused_params_allreduce(self):
        locally_used_param_map = self.reducer._get_local_used_map()
        self.process_group.allreduce(locally_used_param_map)

    def join(
        self,
        divide_by_initial_world_size: bool = True,
        enable: bool = True,
        throw_on_early_termination: bool = False,
    ):
        r"""
        A context manager to be used in conjunction with an instance of
        :class:`torch.nn.parallel.DistributedDataParallel` to be
        able to train with uneven inputs across participating processes.

        This context manager will keep track of already-joined DDP processes,
        and "shadow" the forward and backward passes by inserting collective
        communication operations to match with the ones created by non-joined
        DDP processes. This will ensure each collective call has a corresponding
        call by already-joined DDP processes, preventing hangs or errors that
        would otherwise happen when training with uneven inputs across
        processes. Alternatively, if the flag ``throw_on_early_termination`` is
        specified to be ``True``, all trainers will throw an error once one rank
        runs out of inputs, allowing these errors to be caught and handled
        according to application logic.

        Once all DDP processes have joined, the context manager will broadcast
        the model corresponding to the last joined process to all processes to
        ensure the model is the same across all processes
        (which is guaranteed by DDP).

        To use this to enable training with uneven inputs across processes,
        simply wrap this context manager around your training loop. No further
        modifications to the model or data loading are required.

        .. warning::
            If the model or training loop this context manager is wrapped around
            has additional distributed collective operations, such as
            ``SyncBatchNorm`` in the model's forward pass, then the flag
            ``throw_on_early_termination`` must be enabled. This is because this
            context manager is not aware of non-DDP collective communication.
            This flag will cause all ranks to throw when any one rank
            exhausts inputs, allowing these errors to be caught and recovered
            from across all ranks.

        Args:
            divide_by_initial_world_size (bool): If ``True``, will divide
                gradients by the initial ``world_size`` DDP training was launched
                with. If ``False``, will compute the effective world size
                (number of ranks that have not depleted their inputs yet) and
                divide gradients by that during allreduce. Set
                ``divide_by_initial_world_size=True`` to ensure every input
                sample including the uneven inputs have equal weight in terms of
                how much they contribute to the global gradient. This is
                achieved by always dividing the gradient by the initial
                ``world_size`` even when we encounter uneven inputs. If you set
                this to ``False``, we divide the gradient by the remaining
                number of nodes. This ensures parity with training on a smaller
                ``world_size``, although it also means the uneven inputs would
                contribute more towards the global gradient. Typically, you
                would want to set this to ``True`` for cases where the last few
                inputs of your training job are uneven. In extreme cases, where
                there is a large discrepancy in the number of inputs, setting
                this to ``False`` might provide better results.
            enable (bool): Whether to enable uneven input detection or not. Pass
                in ``enable=False`` to disable in cases where you know that
                inputs are even across participating processes. Default is
                ``True``.
            throw_on_early_termination (bool): Whether to throw an error
                or continue training when at least one rank has exhausted
                inputs. If ``True``, will throw upon the first rank reaching end
                of data. If ``False``, will continue training with a smaller
                effective world size until all ranks are joined. Note that if
                this flag is specified, then the flag
                ``divide_by_initial_world_size`` would be ignored. Default
                is ``False``.

        Example::

            >>> # xdoctest: +SKIP("Distributed")
            >>> import torch
            >>> import torch.distributed as dist
            >>> import os
            >>> import torch.multiprocessing as mp
            >>> import torch.nn as nn
            >>> # On each spawned worker
            >>> def worker(rank):
            >>>     dist.init_process_group("nccl", rank=rank, world_size=2)
            >>>     torch.cuda.set_device(rank)
            >>>     model = nn.Linear(1, 1, bias=False).to(rank)
            >>>     model = torch.nn.parallel.DistributedDataParallel(
            >>>         model, device_ids=[rank], output_device=rank
            >>>     )
            >>>     # Rank 1 gets one more input than rank 0.
            >>>     inputs = [torch.tensor([1]).float() for _ in range(10 + rank)]
            >>>     with model.join():
            >>>         for _ in range(5):
            >>>             for inp in inputs:
            >>>                 loss = model(inp).sum()
            >>>                 loss.backward()
            >>>     # Without the join() API, the below synchronization will hang
            >>>     # blocking for rank 1's allreduce to complete.
            >>>     torch.cuda.synchronize(device=rank)
        """
        return Join(
            [self],
            enable,
            throw_on_early_termination,
            divide_by_initial_world_size=divide_by_initial_world_size,
        )

    def join_hook(
        self,
        **kwargs,
    ):
        r"""
        Returns the DDP join hook, which enables training on uneven inputs by
        shadowing the collective communications in the forward and backward
        passes.

        Arguments:
            kwargs (dict): a :class:`dict` containing any keyword arguments
                to modify the behavior of the join hook at run time; all
                :class:`Joinable` instances sharing the same join context
                manager are forwarded the same value for ``kwargs``.

        The hook supports the following keyword arguments:
            divide_by_initial_world_size (bool, optional):
                If ``True``, then gradients are divided by the initial world
                size that DDP was launched with.
                If ``False``, then gradients are divided by the effective world
                size (i.e. the number of non-joined processes), meaning that
                the uneven inputs contribute more toward the global gradient.
                Typically, this should be set to ``True`` if the degree of
                unevenness is small but can be set to ``False`` in extreme
                cases for possibly better results.
                Default is ``True``.
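
        Example::
            A minimal usage sketch, assuming an initialized process group, an
            existing DDP instance ``model``, and an iterable ``inputs`` of
            local batches. The keyword argument passed to :class:`Join` is
            forwarded to this join hook.

            >>> # xdoctest: +SKIP("Distributed")
            >>> from torch.distributed.algorithms.join import Join
            >>> # Join forwards divide_by_initial_world_size to join_hook().
            >>> with Join([model], divide_by_initial_world_size=False):
            >>>     for inp in inputs:
            >>>         model(inp).sum().backward()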
  1276. """
  1277. divide_by_initial_world_size = kwargs.get(
  1278. "divide_by_initial_world_size", True
  1279. )
  1280. return _DDPJoinHook(
  1281. self, divide_by_initial_world_size=divide_by_initial_world_size
  1282. )
  1283. @property
  1284. def join_device(self):
  1285. return self.device
  1286. @property
  1287. def join_process_group(self):
  1288. return self.process_group

    def _register_buffer_comm_hook(
        self,
        state,
        hook: Callable,
        comm_hook_location=_BufferCommHookLocation.POST_FORWARD,
    ):
        r"""
        Allows custom registration of hooks that define how buffers are
        synchronized across ranks. The hook takes in an optional state
        and is passed in a Dict[str, Tensor] corresponding to buffer names
        and the buffers, and can run arbitrary reductions on buffers as
        opposed to DDP's default broadcast from rank 0. This is useful, for
        example, if a counter needs to be summed or averaged across ranks
        every iteration.

        Args:
            state (Any): Optional state that is passed to the hook.
            hook (Callable): Callable with the following signature:
                ``hook(state: object, buffers: Dict[str, torch.Tensor]) ->
                Optional[List[torch.futures.Future[torch.Tensor]]]``
            comm_hook_location (_BufferCommHookLocation): Enum value indicating
                where to run the hook.
                _BufferCommHookLocation.PRE_FORWARD means that the
                hook will run _before_ the forward pass, and
                _BufferCommHookLocation.POST_FORWARD means that the
                hook will run _after_ the forward pass.

        NOTE: To maximize performance, users can return a
            List[torch.futures.Future] from their hook, and DDP will
            install and await these hooks appropriately at the end of
            the backward pass. This will ensure all buffers are
            synchronized by the end of the backward pass. If this
            setting is used, it is recommended to pass
            comm_hook_location=_BufferCommHookLocation.POST_FORWARD,
            which will trigger the hook after the forward pass.
            If _BufferCommHookLocation.PRE_FORWARD is used, users must
            ensure appropriate synchronization when manipulating GPU
            buffers in the forward pass.
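
        Example::
            A minimal sketch, assuming ``ddp`` is an already-constructed DDP
            instance and the default process group is initialized. The
            hypothetical ``allreduce_buffers_hook`` below averages all module
            buffers across ranks after each forward pass instead of
            broadcasting them from rank 0.

            >>> # xdoctest: +SKIP("Distributed")
            >>> def allreduce_buffers_hook(state, buffers):
            >>>     # Average each buffer across ranks and return the async
            >>>     # futures so DDP can await them at the end of backward.
            >>>     futs = []
            >>>     for buf in buffers.values():
            >>>         buf.div_(dist.get_world_size())
            >>>         futs.append(
            >>>             dist.all_reduce(buf, async_op=True).get_future()
            >>>         )
            >>>     return futs
            >>> ddp._register_buffer_comm_hook(None, allreduce_buffers_hook)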
  1324. """
  1325. assert callable(hook)
  1326. self.buffer_hook = _BufferCommHook(
  1327. buffer_comm_hook=hook,
  1328. buffer_comm_hook_state=state,
  1329. buffer_comm_hook_location=comm_hook_location,
  1330. )

    def register_comm_hook(self, state: object, hook: Callable):
        r"""
        Registers a communication hook, which gives users a flexible way to
        specify how DDP aggregates gradients across multiple workers.

        This hook is useful for researchers trying out new ideas. For
        example, this hook can be used to implement several algorithms like GossipGrad
        and gradient compression which involve different communication strategies for
        parameter syncs while running Distributed DataParallel training.

        Args:
            state (object): Passed to the hook to maintain any state information during the training process.
                Examples include error feedback in gradient compression,
                peers to communicate with next in GossipGrad, etc.

                It is locally stored by each worker
                and shared by all the gradient tensors on the worker.
            hook (Callable): Callable with the following signature:
                ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``:

                This function is called once the bucket is ready. The
                hook can perform whatever processing is needed and return
                a Future indicating completion of any async work (ex: allreduce).
                If the hook doesn't perform any communication, it still
                must return a completed Future. The Future should hold the
                new value of the grad bucket's tensors. Once a bucket is ready,
                the c10d reducer will call this hook, use the tensors returned
                by the Future, and copy gradients to the individual parameters.
                Note that the future's return type must be a single tensor.

                We also provide an API called ``get_future`` to retrieve a
                Future associated with the completion of ``c10d.ProcessGroup.Work``.
                ``get_future`` is currently supported for NCCL and also supported for most
                operations on GLOO and MPI, except for peer-to-peer operations (send/recv).

        .. warning ::
            Grad bucket's tensors will not be predivided by world_size. The user is
            responsible for dividing by the world_size in the case of operations like
            allreduce.

        .. warning ::
            DDP communication hook can only be registered once and should be registered
            before calling backward.

        .. warning ::
            The Future object that the hook returns should contain a single tensor
            that has the same shape as the tensors inside the grad bucket.

        .. warning ::
            ``get_future`` API supports NCCL, and partially GLOO and MPI backends (no support
            for peer-to-peer operations like send/recv) and will return a ``torch.futures.Future``.

        Example::
            Below is an example of a noop hook that returns the same tensor.

            >>> # xdoctest: +SKIP('undefined name')
            >>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
            >>>     fut = torch.futures.Future()
            >>>     fut.set_result(bucket.buffer())
            >>>     return fut
            >>> ddp.register_comm_hook(state=None, hook=noop)

        Example::
            Below is an example of a Parallel SGD algorithm where gradients are encoded before
            allreduce, and then decoded after allreduce.

            >>> # xdoctest: +SKIP('undefined name')
            >>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
            >>>     encoded_tensor = encode(bucket.buffer())  # encode gradients
            >>>     fut = torch.distributed.all_reduce(
            >>>         encoded_tensor, async_op=True
            >>>     ).get_future()
            >>>     # Define the then callback to decode.
            >>>     def decode_callback(fut):
            >>>         decoded_tensor = decode(fut.value()[0])  # decode gradients
            >>>         return decoded_tensor
            >>>     return fut.then(decode_callback)
            >>> ddp.register_comm_hook(state=None, hook=encode_and_decode)
        """
        self._check_comm_hook(hook)
        assert self.logger is not None
        self.logger._set_comm_hook_name(hook.__qualname__)
        dist._register_comm_hook(self.reducer, state, hook)

    def _register_builtin_comm_hook(self, comm_hook_type):
        r"""
        Registers a built-in communication hook that specifies how DDP
        aggregates gradients across multiple workers.

        The built-in hooks aim to provide efficient C++ implementations for certain hooks,
        which might not be as efficient if implemented in Python using a Python communication hook.

        Args:
            comm_hook_type (dist.BuiltinCommHookType): type of communication hook, such as ALLREDUCE, FP16_COMPRESS, etc.

        .. warning ::
            DDP communication hook can only be registered once and should be registered
            before calling backward.

        Example::
            Below is an example of a FP16 compression where gradients are
            compressed into 16-bit floating-point numbers before allreduce, and
            then decompressed after allreduce.

            >>> # xdoctest: +SKIP('undefined name')
            >>> ddp._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS)
        """
        assert self.logger is not None
        self.logger._set_comm_hook_name(str(comm_hook_type))
        dist._register_builtin_comm_hook(self.reducer, comm_hook_type)

    def _register_fused_optim(
        self, optim: Type, *args, optim_params=None, **kwargs
    ):
        r"""
        Registers an optimizer with DDP such that the optimization for a
        parameter will run immediately when that parameter's gradient is
        finished with reduction, instead of waiting for all parameters'
        gradients to finish reduction. This can result in a training speedup
        depending on your workload, since the optimizer can run while gradient
        reduction for other parameters is still ongoing. In addition, this has
        the potential to reduce peak memory consumption during training, as it
        only needs to load the per-parameter optimizer states of a single
        parameter at a time, instead of loading all per-parameter optimizer
        states at once.

        Args:
            optim (Type): a ``torch.optim.Optimizer`` class to be registered
                as a fused optimizer.
            *args (Sequence[Any]): Arguments to forward to `optim`.
            optim_params (Optional[Iterable[torch.Tensor]]): Set of parameters
                to optimize, similar to `params` argument of traditional `torch.optim`
                Optimizers. If this is omitted, all DDP model parameters will be
                optimized.
            **kwargs: (Dict[str, Any]): Keyword arguments to forward to `optim`.

        .. warning ::
            _register_fused_optim should only be called once on a DDP instance,
            and registering multiple fused optimizers for the same DDP model
            is not currently supported. Please ping
            https://github.com/pytorch/pytorch/issues/71595 if this is necessary
            for your use case.

        .. warning ::
            _register_fused_optim and register_comm_hook currently do not
            compose together, meaning that custom DDP communication hooks are
            not supported with overlapped optimizers. Please ping
            https://github.com/pytorch/pytorch/issues/71595 if this is necessary
            for your use case.

        .. warning ::
            Gradient accumulation and DDP `no_sync` are currently not supported
            with overlapped optimizer. Please ping
            https://github.com/pytorch/pytorch/issues/71595 if this is necessary
            for your use case.

        Example::

            >>> # xdoctest: +SKIP("No rendezvous handler")
            >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
            >>> net = torch.nn.parallel.DistributedDataParallel(model, pg)
            >>> lr = 1e-2
            >>> betas = (0.9, 0.99)
            >>> eps = 1e-6
            >>> net._register_fused_optim(torch.optim.Adam, lr, betas=betas, eps=eps)
            >>> # Example with subset of parameters
            >>> params_to_opt = [list(net.parameters())[0]]
            >>> net._register_fused_optim(
            ...     torch.optim.Adam, lr, optim_params=params_to_opt, betas=betas, eps=eps
            ... )
        """
        # Note: importing in function, otherwise this will cause a circular
        # import as optimizer_overlap module needs to import DistributedDataParallel.
        from torch.distributed.algorithms._optimizer_overlap import (
            _as_overlapped_optim,
        )

        overlapped_optim = _as_overlapped_optim(
            optim, optim_params, *args, **kwargs
        )
        try:
            overlapped_optim.register_ddp(self)
        except NotImplementedError as e:
            raise RuntimeError(
                f"{optim} does not support overlapped DDP. Please file an issue to PyTorch or the respective owner of {optim}."
            ) from e

    def _distributed_broadcast_coalesced(
        self, tensors, buffer_size, authoritative_rank=0
    ):
        dist._broadcast_coalesced(
            self.process_group, tensors, buffer_size, authoritative_rank
        )

    def _check_sync_bufs_post_fwd(self):
        return (
            self.will_sync_module_buffers()
            and hasattr(self, "buffer_hook")
            and self.buffer_hook.buffer_comm_hook_location
            == _BufferCommHookLocation.POST_FORWARD
        )

    def _check_sync_bufs_pre_fwd(self):
        return self.will_sync_module_buffers() and (
            not hasattr(self, "buffer_hook")
            or self.buffer_hook.buffer_comm_hook_location
            == _BufferCommHookLocation.PRE_FORWARD
        )

    def will_sync_module_buffers(self):
        return (
            self.require_forward_param_sync
            and self.broadcast_buffers
            and len(self.modules_buffers) > 0
        )

    def _find_common_rank(self, input_rank, rank_cond):
        # -1 indicates that this rank is not under consideration to be the
        # common_rank
        rank_to_use = torch.tensor(
            [input_rank if rank_cond else -1],
            device=self.device,
        )
        dist.all_reduce(rank_to_use, op=ReduceOp.MAX, group=self.process_group)
        if rank_to_use.item() == -1:
            self._log_and_throw(
                ValueError,
                "BUG! Expected rank_cond to be true for at least one process."
                " This indicates a bug in PyTorch, please report an issue.",
            )
        return rank_to_use.item()

    def _sync_buffers(self):
        with torch.no_grad():
            # module buffer sync
            # Synchronize buffers across processes.
            # If we are running DDP with the join manager, we have to agree
            # upon a rank to sync module buffers from, since rank 0 may
            # already have been joined and have stale module buffers.
            if self._join_config.enable:
                authoritative_rank = self._find_common_rank(
                    self._distributed_rank, True
                )
            else:
                # The process with rank 0 is considered the authoritative copy.
                authoritative_rank = 0
            # Update self.modules_buffers in case any buffers were
            # reassigned.
            self._assign_modules_buffers()
            self._sync_module_buffers(authoritative_rank)

    def _sync_module_buffers(self, authoritative_rank):
        if not hasattr(self, "buffer_hook"):
            self._default_broadcast_coalesced(
                authoritative_rank=authoritative_rank
            )
        else:
            hook = self.buffer_hook.buffer_comm_hook
            state = self.buffer_hook.buffer_comm_hook_state
            futs = hook(state, self.named_module_buffers)
            if futs is not None:
                self.reducer._install_post_backward_futures(futs)

    def _default_broadcast_coalesced(
        self, bufs=None, bucket_size=None, authoritative_rank=0
    ):
        """
        Broadcasts buffers from rank 0 to the rest of the workers. If ``bufs``
        or ``bucket_size`` is None, the default values ``self.modules_buffers``
        and ``self.broadcast_bucket_size`` are used instead.
        """
        if bufs is None:
            bufs = self.modules_buffers
        if bucket_size is None:
            bucket_size = self.broadcast_bucket_size

        self._distributed_broadcast_coalesced(
            bufs, bucket_size, authoritative_rank
        )

    def _passing_sync_batchnorm_handle(self, module):
        for layer in module.modules():
            if isinstance(layer, torch.nn.modules.SyncBatchNorm):
                if self.device_type == "cpu":
                    self._log_and_throw(
                        ValueError,
                        "SyncBatchNorm layers only work with GPU modules",
                    )

    def _check_comm_hook(self, hook):
        if not callable(hook):
            self._log_and_throw(
                TypeError, "Communication hook must be callable."
            )

        sig = inspect.signature(hook)
        if (
            sig.parameters["bucket"].annotation != inspect._empty
            and sig.parameters["bucket"].annotation != dist.GradBucket
        ):
            self._log_and_throw(
                ValueError,
                "Communication hook: bucket annotation should be dist.GradBucket.",
            )

        if (
            sig.return_annotation != inspect._empty
            and sig.return_annotation != torch.futures.Future[torch.Tensor]
        ):
            self._log_and_throw(
                ValueError,
                "Communication hook: return annotation should be torch.futures.Future[torch.Tensor].",
            )

        if hook.__name__ in [
            "bf16_compress_hook",
            "bf16_compress_wrapper_hook",
        ] and (
            (torch.version.cuda is None and torch.version.hip is None)
            or (
                torch.version.cuda is not None
                and int(torch.version.cuda.split(".")[0]) < 11
            )
            or not dist.is_available()
            or not dist.is_nccl_available()
            or torch.cuda.nccl.version() < (2, 10)
        ):
            self._log_and_throw(
                TypeError,
                "BF16 all reduce communication hook requires CUDA 11+ and NCCL 2.10+.",
            )

    @property
    def _distributed_rank(self):
        return dist.get_rank(self.process_group)

    @staticmethod
    def _set_params_and_buffers_to_ignore_for_model(
        module, params_and_buffers_to_ignore
    ):
        """
        Sets parameters and buffers to be ignored by DDP. Expected format for
        parameters is the fully qualified name: {module_name}.{param_name}, and
        similarly, {module_name}.{buffer_name} for buffers. For example:

            params_to_ignore = []
            # NB: model here is vanilla PyTorch module, not yet wrapped with DDP.
            for module_name, module in model.named_modules():
                for param_name, param in module.named_parameters(recurse=False):
                    if should_ignore(param):
                        # Create expected format
                        fqn = f"{module_name}.{param_name}"
                        params_to_ignore.append(fqn)
            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
                model,
                params_to_ignore
            )
        """
        # This is a workaround to set parameters and buffers DDP should ignore
        # during synchronization. It will be removed when the API is finalized
        # as part of addressing https://github.com/pytorch/pytorch/issues/43690.
        module._ddp_params_and_buffers_to_ignore = params_and_buffers_to_ignore
        for name, param in module.named_parameters():
            if name in params_and_buffers_to_ignore:
                param._ddp_ignored = True
        for name, buffer in module.named_buffers():
            if name in params_and_buffers_to_ignore:
                buffer._ddp_ignored = True

    def _get_ddp_logging_data(self):
        r"""
        This interface can be called after DistributedDataParallel() is
        constructed. It returns a dictionary of logging data and can help
        with debugging and analysis. The logging data includes DistributedDataParallel
        constructor input parameters, some internal states of DistributedDataParallel
        and performance metrics. Simply print the dictionary and see what
        these metrics are.

        This is a prototype interface and subject to change in the future.
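
        Example::
            A minimal sketch, assuming ``ddp`` is an already-constructed
            DistributedDataParallel instance.

            >>> # xdoctest: +SKIP("Distributed")
            >>> logging_data = ddp._get_ddp_logging_data()
            >>> # Inspect constructor arguments, internal state, and perf metrics.
            >>> print(logging_data)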
  1662. """
  1663. assert self.logger is not None
  1664. ddp_logging_data = self.logger._get_ddp_logging_data()
  1665. return {**ddp_logging_data.strs_map, **ddp_logging_data.ints_map}

    def _set_ddp_runtime_logging_sample_rate(self, sample_rate):
        r"""
        This interface allows users to set the sample_rate for collecting
        runtime stats. Runtime stats are recorded for each of the first 10
        iterations; after that, they are recorded once every ``sample_rate``
        training iterations. By default (``kDDPRuntimeLoggingSampleRate=100``),
        runtime stats are recorded once every 100 training iterations after
        the first 10.

        This is a prototype interface and subject to change in the future.
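
        Example::
            A minimal sketch, assuming ``ddp`` is an already-constructed
            DistributedDataParallel instance: after the first 10 iterations,
            record runtime stats once every 20 iterations instead of the
            default 100.

            >>> # xdoctest: +SKIP("Distributed")
            >>> ddp._set_ddp_runtime_logging_sample_rate(20)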
  1676. """
  1677. if sample_rate < 1:
  1678. self._log_and_throw(
  1679. ValueError,
  1680. "DDP runtime logging sample rate should be equal or greater than 1",
  1681. )
  1682. self.reducer._set_ddp_runtime_logging_sample_rate(sample_rate)

    def _set_static_graph(self):
        """
        It is recommended to set static graph in the DDP constructor, which will
        call this private API internally.
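
        Example::
            A minimal sketch of the recommended path, assuming an initialized
            process group and a local ``model``: pass ``static_graph=True`` to
            the constructor rather than calling this private API directly.

            >>> # xdoctest: +SKIP("Distributed")
            >>> ddp = torch.nn.parallel.DistributedDataParallel(
            >>>     model, static_graph=True
            >>> )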
  1687. """
  1688. # If self.static_graph has been set, no need to set it again
  1689. if self.static_graph:
  1690. warnings.warn(
  1691. "You've set static_graph to be True, no need to set it again."
  1692. )
  1693. return
  1694. self.static_graph = True
  1695. self.reducer._set_static_graph()
  1696. assert self.logger is not None
  1697. self.logger._set_static_graph()
  1698. if self.find_unused_parameters:
  1699. warnings.warn(
  1700. "You passed find_unused_parameters=true to DistributedDataParallel, "
  1701. "`_set_static_graph` will detect unused parameters automatically, so "
  1702. "you do not need to set find_unused_parameters=true, just be sure these "
  1703. "unused parameters will not change during training loop while calling "
  1704. "`_set_static_graph`."
  1705. )