# join.py

import warnings
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Any, List, NamedTuple, Optional, Type

import torch
import torch.distributed as dist

__all__ = ['JoinHook', 'Joinable', 'Join']

class JoinHook:
    r"""
    This defines a join hook, which provides two entry points in the join
    context manager: a main hook, which is called repeatedly while there
    exists a non-joined process, and a post-hook, which is called once all
    processes have joined.

    To implement a join hook for the generic join context manager, define a
    class that inherits from :class:`JoinHook` and override ``main_hook()``
    and ``post_hook()`` as appropriate.
    """
    def main_hook(self) -> None:
        r"""
        This hook is called repeatedly while there exists a non-joined
        process to shadow collective communications in one training
        iteration (i.e. in one forward pass, backward pass, and optimizer
        step).
        """
        ...

    def post_hook(self, is_last_joiner: bool) -> None:
        r"""
        This hook is called after all processes have joined. It is passed an
        additional ``bool`` argument ``is_last_joiner``, which indicates if
        the rank is one of the last to join.

        Arguments:
            is_last_joiner (bool): ``True`` if the rank is one of the last
                to join; ``False`` otherwise.
        """
        ...

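# Editorial sketch (not part of the original module): a minimal join hook
# that shadows one all-reduce per training iteration. The class name
# ``_AllreduceJoinHook`` and its ``joinable`` argument are assumptions for
# illustration only; real hooks (e.g. DistributedDataParallel's) shadow
# whatever collectives their iteration actually issues.
class _AllreduceJoinHook(JoinHook):
    def __init__(self, joinable):
        self.joinable = joinable

    def main_hook(self) -> None:
        # Match the single all-reduce that non-joined processes perform each
        # iteration so that they do not block waiting on this joined rank.
        t = torch.zeros(1, device=self.joinable.join_device)
        dist.all_reduce(t, group=self.joinable.join_process_group)

    def post_hook(self, is_last_joiner: bool) -> None:
        # Nothing to finalize in this minimal sketch.
        ...
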
class Joinable(ABC):
    r"""
    This defines an abstract base class for joinable classes. A joinable
    class (inheriting from :class:`Joinable`) should implement
    :meth:`join_hook`, which returns a :class:`JoinHook` instance, in
    addition to :meth:`join_device` and :meth:`join_process_group` that
    return device and process group information, respectively.
    """
    @abstractmethod
    def __init__(self):
        super().__init__()
        self._join_config = _JoinConfig.construct_disabled_join_config()

    @abstractmethod
    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Returns a :class:`JoinHook` instance for the given :class:`Joinable`.

        Arguments:
            kwargs (dict): a :class:`dict` containing any keyword arguments
                to modify the behavior of the join hook at run time; all
                :class:`Joinable` instances sharing the same join context
                manager are forwarded the same value for ``kwargs``.
        """
        ...

    @property
    @abstractmethod
    def join_device(self) -> torch.device:
        r"""
        Returns the device from which to perform collective communications
        needed by the join context manager implementation itself.
        """
        ...

    @property
    @abstractmethod
    def join_process_group(self) -> Any:
        r"""
        Returns the process group for the collective communications needed
        by the join context manager itself.
        """
        ...

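# Editorial sketch (not part of the original module): a minimal hypothetical
# ``Joinable`` whose per-iteration "work" is a single all-reduce. The class
# name ``_ExampleJoinable`` and its ``step()`` method are illustrative
# assumptions; it reuses the ``_AllreduceJoinHook`` sketch above.
class _ExampleJoinable(Joinable):
    def __init__(self, device: torch.device, process_group):
        super().__init__()  # required: initializes ``_join_config``
        self._device = device
        self._process_group = process_group

    def step(self) -> None:
        # Notify the context manager before this iteration's collectives.
        Join.notify_join_context(self)
        t = torch.ones(1, device=self._device)
        dist.all_reduce(t, group=self._process_group)

    def join_hook(self, **kwargs) -> JoinHook:
        return _AllreduceJoinHook(self)

    @property
    def join_device(self) -> torch.device:
        return self._device

    @property
    def join_process_group(self) -> Any:
        return self._process_group
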
class _JoinConfig(NamedTuple):
    r"""
    This includes all fields needed from a :class:`Joinable` instance for
    the join context manager side.
    """
    enable: bool
    throw_on_early_termination: bool
    is_first_joinable: bool

    @staticmethod
    def construct_disabled_join_config():
        r"""
        Returns a :class:`_JoinConfig` instance indicating that join-related
        logic should be disabled, e.g. if the caller is not in a join
        context manager.
        """
        return _JoinConfig(
            enable=False,
            throw_on_early_termination=False,
            is_first_joinable=False
        )

class Join:
    r"""
    This class defines the generic join context manager, which allows custom
    hooks to be called after a process joins. These hooks should shadow the
    collective communications of non-joined processes to prevent hanging and
    erroring and to ensure algorithmic correctness. Refer to
    :class:`JoinHook` for details about the hook definition.

    .. warning::
        The context manager requires each participating :class:`Joinable` to
        call the method :meth:`notify_join_context()` before its own
        per-iteration collective communications to ensure correctness.

    .. warning::
        The context manager requires that all :class:`Joinable` objects
        share the same process group (via :meth:`join_process_group`). If
        there are multiple :class:`Joinable` objects, then the
        :meth:`join_device` of the first is used. The process group and
        device information is used for checking for non-joined processes and
        for notifying processes to throw an exception if
        ``throw_on_early_termination`` is enabled, both of which use an
        all-reduce.

    Arguments:
        joinables (List[Joinable]): a list of the participating
            :class:`Joinable` s; their hooks are iterated over in the given
            order.
        enable (bool): a flag enabling uneven input detection; setting to
            ``False`` disables the context manager's functionality and
            should only be set when the user knows the inputs will not be
            uneven (default: ``True``).
        throw_on_early_termination (bool): a flag controlling whether to
            throw an exception upon detecting uneven inputs (default:
            ``False``).

    Example::

        >>> import os
        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.multiprocessing as mp
        >>> # xdoctest: +SKIP
        >>> from torch.nn.parallel import DistributedDataParallel as DDP
        >>> from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
        >>> from torch.distributed.algorithms.join import Join
        >>>
        >>> # On each spawned worker
        >>> def worker(rank):
        >>>     dist.init_process_group("nccl", rank=rank, world_size=2)
        >>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
        >>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
        >>>     # Rank 1 gets one more input than rank 0
        >>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
        >>>     with Join([model, optim]):
        >>>         for input in inputs:
        >>>             loss = model(input).sum()
        >>>             loss.backward()
        >>>             optim.step()
        >>> # All ranks reach here without hanging/erroring
    """
    def __init__(
        self,
        joinables: List[Joinable],
        enable: bool = True,
        throw_on_early_termination: bool = False,
        **kwargs,
    ):
        if len(joinables) == 0:
            raise ValueError("The join context manager requires at least one joinable")
        self._joinables = joinables
        self._join_hooks = [joinable.join_hook(**kwargs) for joinable in self._joinables]
        self._enable = enable
        self._throw_on_early_termination = throw_on_early_termination
        self._set_joinable_configs()
        self._extract_dist_info()

    def _set_joinable_configs(self) -> None:
        r"""
        Sets the :class:`_JoinConfig` of each participating
        :class:`Joinable`.
        """
        assert len(self._joinables) > 0
        is_first_joinable = True
        for joinable in self._joinables:
            joinable._join_config = _JoinConfig(
                enable=self._enable,
                throw_on_early_termination=self._throw_on_early_termination,
                is_first_joinable=is_first_joinable
            )
            is_first_joinable = False

    def _extract_dist_info(self) -> None:
        r"""
        Extracts the process group and device information from the
        joinables. If there are multiple joinables, then the context manager
        uses the first specified device.

        Preconditions:
            ``self._joinables`` is not ``None`` and is non-empty.

        Raises:
            ValueError
                If there are multiple conflicting ``process_group``
                attributes among the ``Joinable`` objects.
        """
        process_group = None
        device = None
        for joinable in self._joinables:
            if process_group is None:
                process_group = joinable.join_process_group
            elif process_group != joinable.join_process_group:
                raise ValueError("Using join context manager with multiple process groups")
            if device is None:
                device = joinable.join_device
        self._process_group = process_group
        self._rank = dist.get_rank(self._process_group)
        self._device = device

    def __enter__(self):
        # Entering the context is a no-op; all join logic runs in
        # ``__exit__``.
        ...

    def __exit__(
        self,
        type: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[TracebackType]
    ):
        r"""
        Repeatedly runs the main hooks until all processes join; then, runs
        the post-hooks.

        Raises:
            RuntimeError
                If ``throw_on_early_termination=True``.
        """
        if not self._enable or type:
            return  # propagate the exception directly if one was raised
        all_procs_joined = False
        is_last_joiner = True
        i = 0
        WARN_THRESHOLD = 1000
        warnings.simplefilter("once")
        while not all_procs_joined:
            if i > WARN_THRESHOLD:
                warnings.warn(
                    "Detected uneven input skew of greater than "
                    f"{WARN_THRESHOLD}. This means that rank "
                    f"{self._rank} has at least {WARN_THRESHOLD} "
                    "fewer inputs than other currently-active ranks. "
                    "This level of skew could lead to performance "
                    "degradation during training."
                )
            # Shadow the all-reduce in non-joined processes
            num_nonjoined_procs = self._get_num_nonjoined_procs()
            if num_nonjoined_procs == 0:
                all_procs_joined = True
            else:
                if self._throw_on_early_termination:
                    self._notify_procs_to_terminate()
                # Run main hooks
                for join_hook in self._join_hooks:
                    join_hook.main_hook()
                is_last_joiner = False
                i += 1
        # Run post-hooks
        for join_hook in self._join_hooks:
            join_hook.post_hook(is_last_joiner)

    def _get_num_nonjoined_procs(self):
        r"""
        Returns the number of non-joined processes by shadowing an
        all-reduce in the non-joined processes.
        """
        # Joined processes contribute 0 here, while each non-joined process
        # contributes 1 via the all-reduce in ``notify_join_context()``, so
        # the result counts the non-joined processes.
        num_nonjoined_procs = torch.zeros(1, device=self._device)
        dist.all_reduce(num_nonjoined_procs, group=self._process_group)
        return num_nonjoined_procs.item()

    def _notify_procs_to_terminate(self):
        r"""
        Schedules an all-reduce to notify non-joined processes to terminate
        and raises a ``RuntimeError`` indicating that the current process
        has exhausted its inputs.
        """
        ones = torch.ones(1, device=self._device)
        dist.all_reduce(ones, group=self._process_group)
        raise RuntimeError(f"Rank {self._rank} exhausted all inputs.")

    @staticmethod
    def notify_join_context(joinable: Joinable):
        r"""
        Notifies the join context manager that the calling process has not
        yet joined; then, if ``throw_on_early_termination=True``, checks if
        uneven inputs have been detected (i.e. if one process has already
        joined) and throws an exception if so.

        This method should be called from a :class:`Joinable` object before
        its per-iteration collective communications. For example, this
        should be called at the beginning of the forward pass in
        :class:`DistributedDataParallel`.

        Only the first :class:`Joinable` object passed into the context
        manager performs the collective communications in this method, and
        for the others, this method is vacuous.

        Arguments:
            joinable (Joinable): the :class:`Joinable` object calling this
                method.

        Returns:
            An async work handle for the all-reduce meant to notify the
            context manager that the process has not yet joined if
            ``joinable`` is the first one passed into the context manager;
            ``None`` otherwise.
        """
        assert hasattr(joinable, "_join_config"), \
            f"Check that the {type(joinable)} constructor calls the " \
            "``Joinable`` constructor"
        join_config = joinable._join_config
        # First joinable is responsible for the collective communications
        if not join_config.is_first_joinable or not join_config.enable:
            return None
        device = joinable.join_device
        process_group = joinable.join_process_group
        # Schedule an all-reduce to indicate that the caller has not yet joined
        ones = torch.ones(1, device=device)
        work = dist.all_reduce(ones, group=process_group, async_op=True)
        if join_config.throw_on_early_termination:
            # Check if uneven inputs have been detected
            zeros = torch.zeros(1, device=device)
            dist.all_reduce(zeros, group=process_group)
            should_throw = zeros.item()
            if should_throw:
                raise RuntimeError(
                    "Detected at least one rank that exhausted inputs. "
                    "Throwing across all ranks."
                )
        return work
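
# Editorial usage sketch (not part of the original module): wires the
# hypothetical ``_ExampleJoinable`` above into the context manager. The
# function name ``_example_worker`` and the "gloo" backend are illustrative
# assumptions; launch with e.g. ``torch.multiprocessing.spawn``.
def _example_worker(rank: int, world_size: int) -> None:
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    joinable = _ExampleJoinable(torch.device("cpu"), dist.group.WORLD)
    with Join([joinable]):
        # Uneven inputs: higher ranks run more iterations than rank 0.
        for _ in range(10 + rank):
            joinable.step()
    # All ranks reach here: joined ranks shadow the collectives of ranks
    # that still have inputs, so no process hangs.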