# distributed_c10d.py

import itertools
import collections.abc
import contextlib
import functools
import io
import logging
import os
import pickle
import time
import warnings
from collections import namedtuple
from datetime import timedelta
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch._C._distributed_c10d import (
    AllreduceCoalescedOptions,
    AllreduceOptions,
    AllToAllOptions,
    _DistributedBackendOptions,
    BarrierOptions,
    BroadcastOptions,
    GatherOptions,
    PrefixStore,
    ProcessGroup,
    ReduceOp,
    ReduceOptions,
    ReduceScatterOptions,
    ScatterOptions,
    Store,
    DebugLevel,
    get_debug_level,
    Work
)
from torch.autograd.profiler import record_function

from .constants import default_pg_timeout
from .c10d_error_logger import _get_or_create_logger
from .rendezvous import register_rendezvous_handler, rendezvous  # noqa: F401

__all__ = [
    'Backend', 'BackendConfig', 'GroupMember', 'P2POp', 'all_gather', 'all_gather_coalesced',
    'all_gather_multigpu', 'all_gather_object', 'all_reduce',
    'all_reduce_coalesced', 'all_reduce_multigpu', 'all_to_all',
    'all_to_all_single', 'barrier', 'batch_isend_irecv', 'broadcast',
    'broadcast_multigpu', 'broadcast_object_list', 'destroy_process_group',
    'dist_backend', 'gather', 'gather_object', 'get_backend_config', 'get_backend', 'get_rank',
    'get_world_size', 'group', 'init_process_group', 'irecv',
    'is_gloo_available', 'is_initialized', 'is_mpi_available',
    'is_nccl_available', 'is_torchelastic_launched', 'is_ucc_available',
    'isend', 'monitored_barrier', 'new_group', 'new_subgroups',
    'new_subgroups_by_enumeration', 'recv', 'reduce', 'reduce_multigpu',
    'reduce_scatter', 'reduce_scatter_multigpu', 'scatter',
    'scatter_object_list', 'send', 'supports_complex',
    'AllreduceCoalescedOptions', 'AllreduceOptions', 'AllToAllOptions',
    'BarrierOptions', 'BroadcastOptions', 'GatherOptions', 'PrefixStore',
    'ProcessGroup', 'ReduceOp', 'ReduceOptions', 'ReduceScatterOptions',
    'ScatterOptions', 'Store', 'DebugLevel', 'get_debug_level', 'Work',
    'default_pg_timeout', 'get_group_rank', 'get_global_rank', 'get_process_group_ranks',
    'reduce_op', 'all_gather_into_tensor', 'reduce_scatter_tensor', 'exception_handler'
]

_MPI_AVAILABLE = True
_NCCL_AVAILABLE = True
_GLOO_AVAILABLE = True
_UCC_AVAILABLE = True

_pickler = pickle.Pickler
_unpickler = pickle.Unpickler


# Change __module__ of all imported types from torch._C._distributed_c10d that are public
def _export_c_types():
    _public_types_to_change_module = [
        AllreduceCoalescedOptions,
        AllreduceOptions,
        AllToAllOptions,
        BarrierOptions,
        BroadcastOptions,
        GatherOptions,
        PrefixStore,
        ProcessGroup,
        ReduceOp,
        ReduceOptions,
        ReduceScatterOptions,
        ScatterOptions,
        Store,
        DebugLevel,
        get_debug_level,
        Work
    ]
    for type in _public_types_to_change_module:
        type.__module__ = "torch.distributed.distributed_c10d"

_export_c_types()

try:
    from torch._C._distributed_c10d import ProcessGroupMPI
    ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d"
    __all__ += ["ProcessGroupMPI"]
except ImportError:
    _MPI_AVAILABLE = False

try:
    from torch._C._distributed_c10d import ProcessGroupNCCL
    ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d"
    __all__ += ["ProcessGroupNCCL"]
except ImportError:
    _NCCL_AVAILABLE = False

try:
    from torch._C._distributed_c10d import ProcessGroupGloo
    from torch._C._distributed_c10d import _ProcessGroupWrapper
    ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d"
    __all__ += ["ProcessGroupGloo"]
except ImportError:
    _GLOO_AVAILABLE = False

try:
    from torch._C._distributed_c10d import ProcessGroupUCC
    ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d"
    __all__ += ["ProcessGroupUCC"]
except ImportError:
    _UCC_AVAILABLE = False

logger = logging.getLogger(__name__)

global _c10d_error_logger
_c10d_error_logger = _get_or_create_logger()

PG_WRAPPER_STORE_PREFIX = "pg_wrapper"


# Some reduce ops are not supported by complex numbers and will result in an error.
# We currently provide complex support to the distributed API by viewing
# complex tensors as real (torch.view_as_real), meaning that calling
# these unsupported ops will return garbage values rather than error out.
# (e.g. max(2+3i, 3+2i) = 3+3i)
# We'd like calls to unsupported ops to error out accordingly,
# rather than returning garbage values.
def supports_complex(reduceOp: ReduceOp) -> bool:
    denyList = [
        ReduceOp.MAX,
        ReduceOp.MIN,
        ReduceOp.PRODUCT,
        ReduceOp.BAND,
        ReduceOp.BOR,
        ReduceOp.BXOR,
    ]
    return reduceOp not in denyList
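
# A hedged usage sketch (kept as a comment so nothing runs at import time):
# callers can consult supports_complex() to decide whether a complex tensor can
# safely be viewed as real before a collective. Only names from this module are
# used; nothing hypothetical is assumed.
#
#     >>> import torch.distributed as dist
#     >>> dist.supports_complex(dist.ReduceOp.SUM)
#     True
#     >>> dist.supports_complex(dist.ReduceOp.MAX)   # max is undefined for complex values
#     False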
class Backend:
    """
    An enum-like class of available backends: GLOO, NCCL, UCC, MPI, and other registered
    backends.

    The values of this class are lowercase strings, e.g., ``"gloo"``. They can
    be accessed as attributes, e.g., ``Backend.NCCL``.

    This class can be directly called to parse the string, e.g.,
    ``Backend(backend_str)`` will check if ``backend_str`` is valid, and
    return the parsed lowercase string if so. It also accepts uppercase strings,
    e.g., ``Backend("GLOO")`` returns ``"gloo"``.

    .. note:: The entry ``Backend.UNDEFINED`` is present but only used as
              the initial value of some fields. Users should neither use it directly
              nor assume its existence.
    """
    UNDEFINED = "undefined"
    GLOO = "gloo"
    NCCL = "nccl"
    UCC = "ucc"
    MPI = "mpi"

    _BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"])
    _plugins: Dict[str, _BackendPlugin] = {}

    backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI]

    def __new__(cls, name: str):
        if not isinstance(name, str):
            raise ValueError("Backend name must be a string, but got: {}".format(name))
        value = getattr(Backend, name.upper(), Backend.UNDEFINED)
        if value != Backend.GLOO and value != Backend.NCCL and value != Backend.UCC and value != Backend.MPI:
            value = name.lower()
        return value

    @classmethod
    def register_backend(cls, name, func, extended_api=False):
        """
        Registers a new backend with the given name and instantiating function.

        This class method is used by 3rd party ``ProcessGroup`` extensions to
        register new backends.

        Args:
            name (str): Backend name of the ``ProcessGroup`` extension. It
                        should match the one in ``init_process_group()``.
            func (function): Function handler that instantiates the backend.
                             The function should be implemented in the backend
                             extension and takes four arguments, including
                             ``store``, ``rank``, ``world_size``, and ``timeout``.
            extended_api (bool, optional): Whether the backend supports extended argument structure.
                                           Default: ``False``. If set to ``True``, the backend
                                           will get an instance of ``c10d::DistributedBackendOptions``, and
                                           a process group options object as defined by the backend implementation.

        .. note:: This support of 3rd party backends is experimental and subject to change.
        """
        # Allow the UCC plugin if PyTorch is not built with native support.
        # TODO: remove this exception once the UCC plugin is fully deprecated.
        if (name != Backend.UCC or (name == Backend.UCC and is_ucc_available())):
            assert not hasattr(Backend, name.upper()), (
                f"{name.upper()} c10d backend already exists"
            )
            assert name.upper() not in Backend._plugins, (
                f"{name.upper()} c10d backend creator function already exists"
            )

        setattr(Backend, name.upper(), name.upper())
        Backend.backend_list.append(name.lower())
        Backend._plugins[name.upper()] = Backend._BackendPlugin(func, extended_api)
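
# A hedged sketch of how a third-party extension might register itself.
# ``"dummy"`` and ``_create_dummy_backend`` are hypothetical names, not part of
# this module; a real plugin would construct and return its own backend object.
#
#     >>> import torch.distributed as dist
#     >>> def _create_dummy_backend(store, rank, world_size, timeout):
#     ...     # a real extension builds and returns its ProcessGroup/Backend here
#     ...     raise NotImplementedError
#     >>> dist.Backend.register_backend("dummy", _create_dummy_backend)
#     >>> dist.Backend.DUMMY
#     'DUMMY'
#     >>> "dummy" in dist.Backend.backend_list
#     True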
class BackendConfig:

    def __init__(self, backend: Union[str, Backend]):
        self.device_backend_map: Dict[torch.device, Backend] = {}

        # error check to make sure the config string is valid

        # Cases for when backend is a single string (without device types)
        if backend == Backend.UNDEFINED:
            # default config when backend is not specified
            self.device_backend_map = {
                "cpu": Backend.GLOO,
                "cuda": Backend.NCCL,
            }
        elif backend.lower() in Backend.backend_list:
            # backend applies to all devices (e.g. "NCCL", "GLOO", "UCC", "MPI", "custom_backend")
            backend_val = Backend(backend)
            self.device_backend_map = {
                "cpu": backend_val,
                "cuda": backend_val,
            }
        else:
            # custom backend string in format of "{device_type1}:{backend1},{device_type2}:{backend2}"
            # TODO
            pass

        required_devices = ["cpu", "cuda"]
        for device in required_devices:
            assert device in self.device_backend_map

    def __repr__(self):
        # string with all the device:backend pairs separated by commas
        return ",".join(f"{device}:{backend}" for device, backend in self.device_backend_map.items())

    def get_device_backend_map(self):
        return self.device_backend_map
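
# A hedged sketch of what BackendConfig resolves to. With an unspecified
# backend it falls back to gloo for CPU tensors and nccl for CUDA tensors,
# while a single backend name is applied to both device types.
#
#     >>> BackendConfig(Backend("undefined"))
#     cpu:gloo,cuda:nccl
#     >>> BackendConfig(Backend("gloo")).get_device_backend_map()
#     {'cpu': 'gloo', 'cuda': 'gloo'}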

# `_backend`, `dist_backend`, and `reduce_op` are here to maintain backward
# compatibility with the pre-c10d distributed package.
# TODO: remove them when users are ready to take a hard dependency on PyTorch 1.
_backend: str = Backend.UNDEFINED
dist_backend = Backend


class _reduce_op:
    r"""
    Deprecated enum-like class for reduction operations: ``SUM``, ``PRODUCT``,
    ``MIN``, and ``MAX``.

    Use :class:`~torch.distributed.ReduceOp` instead.
    """

    def __init__(self):
        # __members__ is a dict storing key-value pairs for enum classes
        for k, v in ReduceOp.RedOpType.__members__.items():
            setattr(self, k, v)
        self.__members__ = ReduceOp.RedOpType.__members__

    def __getattribute__(self, key):
        warnings.warn(
            "torch.distributed.reduce_op is deprecated, please use "
            "torch.distributed.ReduceOp instead"
        )
        return object.__getattribute__(self, key)

reduce_op = _reduce_op()

# DO NOT USE THESE FIELDS DIRECTLY.
# Use them through the _world object to make sure the _world override mechanism is respected.
_pg_map: Dict[ProcessGroup, Tuple[str, Optional[Store]]] = {}
_pg_names: Dict[ProcessGroup, str] = {}
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}
# For a pg, it is a map from ProcessGroup to BackendConfig
_pg_backend_config: Dict[ProcessGroup, str] = {}
_group_count = 0


class _World:
    """
    Container class for c10d process group state.
    This is used during registration and lookup of PG state.

    .. warning:: This is an experimental API intended to expose the inner workings
       of c10d and is subject to change.
    """

    def __init__(self):
        self._default_pg = None

    @property
    def default_pg(self):
        """
        The default ProcessGroup includes all ranks of the cluster.
        This is used by c10d APIs when a ProcessGroup is needed but None is provided.
        """
        return self._default_pg

    @default_pg.setter
    def default_pg(self, value):
        self._default_pg = value

    @property
    def pg_map(self) -> Dict[ProcessGroup, Tuple[str, Optional[Store]]]:
        """
        Cached process groups
        For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
        For MPI pg, it is a map from ProcessGroup to (Backend, None)

        TODO don't expose the map, expose fine grained ops
        """
        global _pg_map
        return _pg_map

    @property
    def pg_names(self) -> Dict[ProcessGroup, str]:
        """
        Process group's names, map from ProcessGroup to str.

        TODO don't expose the map, expose fine grained ops
        """
        global _pg_names
        return _pg_names

    @property
    def pg_group_ranks(self) -> Dict[ProcessGroup, Dict[int, int]]:
        """
        Process group's global rank to local rank mapping

        TODO don't expose the map, expose fine grained ops
        """
        global _pg_group_ranks
        return _pg_group_ranks

    @property
    def pg_backend_config(self) -> Dict[ProcessGroup, str]:
        """
        Process group's backend config

        TODO don't expose the map, expose fine grained ops
        """
        global _pg_backend_config
        return _pg_backend_config

    @property
    def group_count(self) -> int:
        """
        Process group count for default naming.

        TODO don't expose group_count, use something else instead
        """
        global _group_count
        return _group_count

    @group_count.setter
    def group_count(self, value):
        """
        Count is used when computing the name of ProcessGroups when using global synchronization.
        """
        global _group_count
        _group_count = value


_world = _World()
"""Holds the singleton instance of ``_World`` used by c10d. Experimental extension point to override it."""

class _WorldMeta(type):
    """
    Meta class of ``group`` and ``GroupMember`` so they
    can have the class property ``WORLD``.
    """

    # Points to the default PG once initialized.
    @property
    def WORLD(cls) -> Optional[ProcessGroup]:
        return _world.default_pg

    @WORLD.setter
    def WORLD(cls, pg: Optional[ProcessGroup]):
        _world.default_pg = pg


class group(metaclass=_WorldMeta):
    pass


class GroupMember(metaclass=_WorldMeta):
    NON_GROUP_MEMBER = object()


# Default process group state
_default_pg_init_method = None

STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"


def _get_pg_device(group: ProcessGroup):
    """
    Returns the device to use with ``group``.
    This is CUDA for NCCL and CPU for everything else.
    """
    if _check_for_nccl_backend(group):
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")


def _store_based_barrier(rank, store, timeout):
    """
    Barrier based on store which is used for synchronizing processes after
    ``init_process_group`` or ``new_group``. Intended to be used only with
    those two methods and is not a generic alternative to ``barrier()``.
    """
    store_key = "{}:{}".format(STORE_BASED_BARRIER_PREFIX, _world.group_count)
    store.add(store_key, 1)
    logger.info("Added key: {} to store for rank: {}".format(store_key, rank))

    # Now wait for all workers to check in with the store.
    world_size = get_world_size()
    # Use 'add' instead of 'get' since for some store implementations 'add'
    # doesn't work well with 'get'. Ideally the store implementations should
    # be fixed, but for backward compatibility reasons it is risky to change
    # the store implementations. Once we completely migrate away from these
    # legacy stores, we can use 'get' here instead.
    worker_count = store.add(store_key, 0)
    start = time.time()
    log_time = time.time()
    while worker_count != world_size:
        time.sleep(0.01)
        worker_count = store.add(store_key, 0)

        # Print status periodically to keep track.
        if timedelta(seconds=(time.time() - log_time)) > timedelta(seconds=10):
            logger.info(
                "Waiting in store based barrier to initialize process group for "
                "rank: {}, key: {} (world_size={}, worker_count={}, timeout={})".format(
                    rank, store_key, world_size, worker_count, timeout
                )
            )
            log_time = time.time()

        if timedelta(seconds=(time.time() - start)) > timeout:
            raise RuntimeError(
                "Timed out initializing process group in store based barrier on "
                "rank: {}, for key: {} (world_size={}, worker_count={}, timeout={})".format(
                    rank, store_key, world_size, worker_count, timeout
                )
            )

    logger.info(
        f"Rank {rank}: Completed store-based barrier for key:{store_key} with {world_size} nodes."
    )

def _rank_not_in_group(group: ProcessGroup):
    """
    Helper that checks if the current process's rank is not in a given group.
    """
    if group is None:
        return False
    return group == GroupMember.NON_GROUP_MEMBER


def _warn_not_in_group(op_name):
    global_rank = -1 if GroupMember.WORLD is None else GroupMember.WORLD.rank()
    warnings.warn(
        f"Running {op_name} on global rank {global_rank} which does not "
        "belong to the given group."
    )


def get_group_rank(group: ProcessGroup, global_rank: int) -> int:
    """
    Translate a global rank into a group rank.

    ``global_rank`` must be part of ``group`` otherwise this raises RuntimeError.

    Args:
        group (ProcessGroup): ProcessGroup to find the relative rank.
        global_rank (int): Global rank to query.

    Returns:
        Group rank of ``global_rank`` relative to ``group``

    N.B. calling this function on the default process group returns identity
    """
    if group is GroupMember.WORLD:
        return global_rank
    if group not in _world.pg_group_ranks:
        raise RuntimeError(f"Group {group} is not registered, please create group with torch.distributed.new_group API")
    group_ranks = _world.pg_group_ranks[group]
    if global_rank not in group_ranks:
        raise RuntimeError(f"Global rank {global_rank} is not part of group {group}")
    return group_ranks[global_rank]


def get_global_rank(group: ProcessGroup, group_rank: int) -> int:
    """
    Translate a group rank into a global rank.

    ``group_rank`` must be part of ``group`` otherwise this raises RuntimeError.

    Args:
        group (ProcessGroup): ProcessGroup to find the global rank from.
        group_rank (int): Group rank to query.

    Returns:
        Global rank of ``group_rank`` relative to ``group``

    N.B. calling this function on the default process group returns identity
    """
    if group is GroupMember.WORLD:
        return group_rank
    if group not in _world.pg_group_ranks:
        raise RuntimeError(f"Group {group} is not registered, please create group with torch.distributed.new_group API")
    for rank, grp_rank in _world.pg_group_ranks[group].items():
        if grp_rank == group_rank:
            return rank
    raise RuntimeError(f"Group rank {group_rank} is not part of group {group}")


# TODO: remove this once the ecosystem moves away from it.
def _get_global_rank(group, rank):
    """
    This method is deprecated, please use get_global_rank.
    """
    warnings.warn(
        "torch.distributed.distributed_c10d._get_global_rank is deprecated "
        "please use torch.distributed.distributed_c10d.get_global_rank instead"
    )
    return get_global_rank(group, rank)


def get_process_group_ranks(group: ProcessGroup):
    """
    Get all ranks associated with ``group``.

    Args:
        group (ProcessGroup): ProcessGroup to get all ranks from.

    Returns:
        List of global ranks ordered by group rank.
    """
    return list(_world.pg_group_ranks[group].keys())
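
# A hedged sketch of translating between group and global ranks. It assumes an
# already initialized default group of four ranks and a subgroup built with
# new_group([1, 3]) (defined later in this module); the queries are issued on a
# member of the subgroup, e.g. global rank 1 or 3.
#
#     >>> import torch.distributed as dist
#     >>> subgroup = dist.new_group(ranks=[1, 3])
#     >>> dist.get_process_group_ranks(subgroup)      # global ranks, ordered by group rank
#     [1, 3]
#     >>> dist.get_group_rank(subgroup, global_rank=3)
#     1
#     >>> dist.get_global_rank(subgroup, group_rank=0)
#     1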

def _get_group_size(group):
    """
    Helper that gets a given group's world size.
    """
    if group is GroupMember.WORLD or group is None:
        default_pg = _get_default_group()
        return default_pg.size()
    return group.size()


def _check_single_tensor(param, param_name):
    """
    Helper to check that the parameter ``param_name`` is a single tensor.
    """
    if not isinstance(param, torch.Tensor):
        raise RuntimeError(
            "Invalid function argument. Expected parameter `{}` "
            "to be of type torch.Tensor.".format(param_name)
        )


def _check_tensor_list(param, param_name):
    """
    Helper to check that the parameter ``param_name`` is a list of tensors.
    """
    if not isinstance(param, list) or not all(
        isinstance(p, torch.Tensor) for p in param
    ):
        raise RuntimeError(
            "Invalid function argument. Expected parameter `{}` "
            "to be of type List[torch.Tensor].".format(param_name)
        )


def _as_iterable(obj) -> collections.abc.Iterable:
    return obj if isinstance(obj, list) else (obj,)


def _ensure_all_tensors_same_dtype(*tensors) -> None:
    last_dtype = None
    for tensor in itertools.chain(*map(_as_iterable, tensors)):
        tensor_dtype = tensor.dtype
        # Mixing complex and its element type is allowed
        if tensor_dtype.is_complex:
            tensor_dtype = torch.float32 if tensor_dtype == torch.complex64 else torch.complex128

        if last_dtype is None:
            last_dtype = tensor_dtype
        else:
            if last_dtype != tensor_dtype:
                raise RuntimeError(
                    "Invalid usage of tensors with different dtypes. "
                    f"Found {last_dtype} and {tensor.dtype}"
                )


def _check_op(op):
    """
    Helper to check that the ``op`` is either isend or irecv.
    """
    if op not in [isend, irecv]:
        raise RuntimeError(
            "Invalid ``op``. Expected ``op`` "
            "to be of type ``torch.distributed.isend`` or "
            "``torch.distributed.irecv``."
        )


def _check_p2p_op_list(p2p_op_list):
    """
    Helper to check that the ``p2p_op_list`` is a list of P2POp instances and
    all ops use the same group.
    """
    if not isinstance(p2p_op_list, list) or not all(
        isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list
    ):
        raise RuntimeError(
            "Invalid ``p2p_op_list``. Each op is expected "
            "to be of type ``torch.distributed.P2POp``."
        )

    group = p2p_op_list[0].group
    if not all(group == p2p_op.group for p2p_op in p2p_op_list):
        raise RuntimeError("All ops need to use the same group.")


def is_mpi_available() -> bool:
    """
    Checks if the MPI backend is available.
    """
    return _MPI_AVAILABLE


def is_nccl_available() -> bool:
    """
    Checks if the NCCL backend is available.
    """
    return _NCCL_AVAILABLE


def is_gloo_available() -> bool:
    """
    Checks if the Gloo backend is available.
    """
    return _GLOO_AVAILABLE


def is_ucc_available() -> bool:
    """
    Checks if the UCC backend is available.
    """
    return _UCC_AVAILABLE


def is_initialized() -> bool:
    """
    Checks if the default process group has been initialized.
    """
    return GroupMember.WORLD is not None


def is_torchelastic_launched() -> bool:
    """
    Checks whether this process was launched with ``torch.distributed.elastic``
    (aka torchelastic). The existence of the ``TORCHELASTIC_RUN_ID`` environment
    variable is used as a proxy to determine whether the current process
    was launched with torchelastic. This is a reasonable proxy since
    ``TORCHELASTIC_RUN_ID`` maps to the rendezvous id, which is always a
    non-null value indicating the job id for peer discovery purposes.
    """
    return os.getenv("TORCHELASTIC_RUN_ID") is not None
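
# A hedged sketch of branching on the launcher. is_torchelastic_launched() only
# inspects the TORCHELASTIC_RUN_ID environment variable, so it can be checked
# before init_process_group(); the LOCAL_RANK fallback below reflects how
# torchrun sets up workers and is an assumption about the launch script, not
# something this module mandates.
#
#     >>> import os
#     >>> import torch.distributed as dist
#     >>> if dist.is_torchelastic_launched():
#     ...     local_rank = int(os.environ["LOCAL_RANK"])   # set by torchrun
#     ... else:
#     ...     local_rank = 0                                # e.g. single-process debugging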

def _get_default_group():
    """
    Get the default process group created by init_process_group.
    """
    if not is_initialized():
        raise RuntimeError(
            "Default process group has not been initialized, "
            "please make sure to call init_process_group."
        )
    return GroupMember.WORLD


def _get_default_store():
    """
    Get the default store created by init_process_group.
    """
    if not is_initialized():
        raise RuntimeError(
            "Default process group has not been initialized, "
            "please make sure to call init_process_group."
        )
    default_pg = _get_default_group()
    _, default_store = _world.pg_map[default_pg]
    return default_store


def _update_default_pg(pg):
    _world.default_pg = pg


def get_backend_config(group: Optional[ProcessGroup] = None) -> str:
    if group is None:
        pg = _get_default_group()
    else:
        pg = group
    if _rank_not_in_group(pg):
        raise RuntimeError("Invalid process group specified")
    backend_config = _world.pg_backend_config.get(pg)
    assert backend_config is not None
    return str(backend_config)


def get_backend(group: Optional[ProcessGroup] = None) -> str:
    """
    Returns the backend of the given process group.

    Args:
        group (ProcessGroup, optional): The process group to work on. The
            default is the general main process group. If another specific group
            is specified, the calling process must be part of :attr:`group`.

    Returns:
        The backend of the given process group as a lower case string.
    """
    if group is None:
        pg = _get_default_group()
    else:
        pg = group
    if _rank_not_in_group(pg):
        raise RuntimeError("Invalid process group specified")
    pg_store = _world.pg_map.get(pg, None)
    assert pg_store is not None
    return pg_store[0]
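
# A hedged sketch contrasting get_backend() and get_backend_config(). The exact
# return values depend on how the job was initialized; the output below assumes
# the default group was created with backend="nccl".
#
#     >>> import torch.distributed as dist
#     >>> dist.get_backend()          # backend of the default group, lower case
#     'nccl'
#     >>> dist.get_backend_config()   # per-device mapping for the same group
#     'cpu:nccl,cuda:nccl'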

def init_process_group(
    backend: Union[str, Backend] = None,
    init_method: Optional[str] = None,
    timeout: timedelta = default_pg_timeout,
    world_size: int = -1,
    rank: int = -1,
    store: Optional[Store] = None,
    group_name: str = "",
    pg_options: Optional[Any] = None,
):
    """
    Initializes the default distributed process group, and this will also
    initialize the distributed package.

    There are 2 main ways to initialize a process group:
        1. Specify ``store``, ``rank``, and ``world_size`` explicitly.
        2. Specify ``init_method`` (a URL string) which indicates where/how
           to discover peers. Optionally specify ``rank`` and ``world_size``,
           or encode all required parameters in the URL and omit them.

    If neither is specified, ``init_method`` is assumed to be "env://".

    Args:
        backend (str or Backend, optional): The backend to use. Depending on
            build-time configurations, valid values include ``mpi``, ``gloo``,
            ``nccl``, and ``ucc``. If the backend is not provided, then both a ``gloo``
            and ``nccl`` backend will be created, see notes below for how multiple
            backends are managed. This field can be given as a lowercase string
            (e.g., ``"gloo"``), which can also be accessed via
            :class:`Backend` attributes (e.g., ``Backend.GLOO``). If using
            multiple processes per machine with ``nccl`` backend, each process
            must have exclusive access to every GPU it uses, as sharing GPUs
            between processes can result in deadlocks. ``ucc`` backend is
            experimental.
        init_method (str, optional): URL specifying how to initialize the
                                     process group. Default is "env://" if no
                                     ``init_method`` or ``store`` is specified.
                                     Mutually exclusive with ``store``.
        world_size (int, optional): Number of processes participating in
                                    the job. Required if ``store`` is specified.
        rank (int, optional): Rank of the current process (it should be a
                              number between 0 and ``world_size``-1).
                              Required if ``store`` is specified.
        store(Store, optional): Key/value store accessible to all workers, used
                                to exchange connection/address information.
                                Mutually exclusive with ``init_method``.
        timeout (timedelta, optional): Timeout for operations executed against
            the process group. Default value equals 30 minutes.
            This is applicable for the ``gloo`` backend. For ``nccl``, this is
            applicable only if the environment variable ``NCCL_BLOCKING_WAIT``
            or ``NCCL_ASYNC_ERROR_HANDLING`` is set to 1. When
            ``NCCL_BLOCKING_WAIT`` is set, this is the duration for which the
            process will block and wait for collectives to complete before
            throwing an exception. When ``NCCL_ASYNC_ERROR_HANDLING`` is set,
            this is the duration after which collectives will be aborted
            asynchronously and the process will crash. ``NCCL_BLOCKING_WAIT``
            will provide errors to the user which can be caught and handled,
            but due to its blocking nature, it has a performance overhead. On
            the other hand, ``NCCL_ASYNC_ERROR_HANDLING`` has very little
            performance overhead, but crashes the process on errors. This is
            done since CUDA execution is async and it is no longer safe to
            continue executing user code since failed async NCCL operations
            might result in subsequent CUDA operations running on corrupted
            data. Only one of these two environment variables should be set.
            For ``ucc``, blocking wait is supported similar to NCCL. However,
            async error handling is done differently since with UCC we have
            a progress thread and not a watch-dog thread.
        group_name (str, optional, deprecated): Group name.
        pg_options (ProcessGroupOptions, optional): process group options
            specifying what additional options need to be passed in during
            the construction of specific process groups. As of now, the only
            option we support is ``ProcessGroupNCCL.Options`` for the ``nccl``
            backend; ``is_high_priority_stream`` can be specified so that
            the nccl backend can pick up high priority cuda streams when
            there are compute kernels waiting.

    .. note:: To enable ``backend == Backend.MPI``, PyTorch needs to be built from source
        on a system that supports MPI.

    .. note:: Support for multiple backends is experimental. Currently when no backend is
        specified, both ``gloo`` and ``nccl`` backends will be created. The ``gloo`` backend
        will be used for collectives with CPU tensors and the ``nccl`` backend will be used
        for collectives with CUDA tensors.
    """
    global _world

    global _backend
    global _default_pg_init_method

    if not isinstance(timeout, timedelta):
        raise RuntimeError(
            "Expected timeout argument to be of type datetime.timedelta"
        )

    if GroupMember.WORLD is not None:
        raise RuntimeError("trying to initialize the default process group twice!")

    assert (store is None) or (
        init_method is None
    ), "Cannot specify both init_method and store."

    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    if backend:
        backend = Backend(backend)
    else:
        backend = Backend("undefined")

    if backend == Backend.MPI:
        if world_size != -1 or rank != -1:
            warnings.warn(
                "For MPI backend, world_size ({}) and rank ({}) "
                "are ignored since they are assigned by the "
                "MPI runtime.".format(world_size, rank)
            )

        default_pg = _new_process_group_helper(
            -1, -1, [], backend, None, group_name=group_name, timeout=timeout
        )
        _update_default_pg(default_pg)
    else:
        # backward compatible API
        if store is None:
            rendezvous_iterator = rendezvous(
                init_method, rank, world_size, timeout=timeout
            )
            store, rank, world_size = next(rendezvous_iterator)
            store.set_timeout(timeout)

            # Use a PrefixStore to avoid accidental overrides of keys used by
            # different systems (e.g. RPC) in case the store is multi-tenant.
            store = PrefixStore("default_pg", store)

        default_pg = _new_process_group_helper(
            world_size,
            rank,
            [],
            backend,
            store,
            pg_options=pg_options,
            group_name=group_name,
            timeout=timeout,
        )
        _update_default_pg(default_pg)

    _world.pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())}  # type: ignore[attr-defined, index]
    _backend = _world.pg_map[GroupMember.WORLD][0]  # type: ignore[index]
    _default_pg_init_method = init_method

    # barrier at the end to ensure that once we return from this method, all
    # process groups including global variables are updated correctly on all
    # ranks.
    if backend == Backend.MPI:
        # MPI backend doesn't use store.
        barrier()
    else:
        # Use store based barrier here since barrier() used a bunch of
        # default devices and messes up NCCL internal state.
        _store_based_barrier(rank, store, timeout)
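
# A hedged end-to-end sketch of the two documented initialization paths. The
# MASTER_ADDR/MASTER_PORT values and the TCPStore port below are illustrative
# placeholders; in a real job the launcher (e.g. torchrun) provides them.
#
#     >>> import os
#     >>> import torch.distributed as dist
#     >>> # 1) env:// initialization (rank/world_size read from the environment)
#     >>> os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
#     >>> os.environ.setdefault("MASTER_PORT", "29500")
#     >>> dist.init_process_group("gloo", rank=0, world_size=1)
#     >>>
#     >>> # 2) explicit store initialization (mutually exclusive with init_method)
#     >>> # store = dist.TCPStore("127.0.0.1", 29501, world_size=1, is_master=True)
#     >>> # dist.init_process_group("gloo", store=store, rank=0, world_size=1)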

def _new_process_group_helper(
    group_size,
    group_rank,
    global_ranks_in_group,
    backend,
    store,
    pg_options=None,
    group_name=None,
    timeout=default_pg_timeout,
):
    """
    Create a new distributed process group.

    This function must be called by ALL processes in the global group, even if
    the calling process is not part of the newly created group. In that case,
    this function returns GroupMember.NON_GROUP_MEMBER.

    This function is called with ``global_ranks_in_group == []`` for the default group.
    """
    global _world

    if not group_name:
        group_name = str(_world.group_count)
        _world.group_count = _world.group_count + 1

    if group_name in _world.pg_names.values():
        raise RuntimeError(
            "The specified group name has already been "
            "created, please use a different group name"
        )

    if not isinstance(timeout, timedelta):
        raise RuntimeError(
            "Expected timeout argument to be of type datetime.timedelta"
        )

    # The list of group ranks is empty if we're creating the default group.
    is_default_group = len(global_ranks_in_group) == 0

    # If this is a subgroup (which means group_ranks is specified),
    # we check if the current process is a member of the new group.
    if not is_default_group:
        global_rank = _get_default_group().rank()
        if global_rank not in global_ranks_in_group:
            return GroupMember.NON_GROUP_MEMBER

    prefix_store = PrefixStore(f"{group_name}/", store)
    base_pg_options = ProcessGroup.Options(backend=str(backend))
    base_pg_options._timeout = timeout
    pg: ProcessGroup = ProcessGroup(prefix_store, group_rank, group_size, base_pg_options)

    backend_config = BackendConfig(backend)
    for device, backend_str in backend_config.get_device_backend_map().items():
        # Use the group name as prefix in the default store, such that
        # a single store can be reused by multiple groups.
        backend_prefix_store = PrefixStore(f"{device}/", prefix_store)

        if backend_str == Backend.MPI:
            if not is_mpi_available():
                raise RuntimeError(
                    "Distributed package doesn't have MPI built in."
                    " MPI is only included if you build PyTorch from"
                    " source on a host that has MPI installed."
                )
            backend_class = ProcessGroupMPI.create(global_ranks_in_group)
            backend_type = ProcessGroup.BackendType.MPI
            if not backend_class:
                return GroupMember.NON_GROUP_MEMBER
        elif backend_str == Backend.GLOO:
            # TODO: remove this check after lazy initialization is supported
            # if pg_options is not None:
            #     raise RuntimeError("GLOO options not supported")
            backend_class = ProcessGroupGloo(backend_prefix_store, group_rank, group_size, timeout=timeout)
            backend_type = ProcessGroup.BackendType.GLOO
        elif backend_str == Backend.NCCL:
            if not is_nccl_available():
                raise RuntimeError("Distributed package doesn't have NCCL built in")
            if pg_options is not None:
                assert isinstance(
                    pg_options, ProcessGroupNCCL.Options
                ), "Expected pg_options argument to be of type ProcessGroupNCCL.Options"
            else:
                # default pg_options for NCCL
                pg_options = ProcessGroupNCCL.Options()
                pg_options.is_high_priority_stream = False
            pg_options._timeout = timeout
            backend_class = ProcessGroupNCCL(backend_prefix_store, group_rank, group_size, pg_options)
            backend_type = ProcessGroup.BackendType.NCCL
        elif backend_str == Backend.UCC and is_ucc_available():
            # TODO: once UCC plugin is fully deprecated, remove
            # is_ucc_available() from above elif-condition and raise
            # RuntimeError if is_ucc_available() returns false.
            backend_class = ProcessGroupUCC(backend_prefix_store, group_rank, group_size, timeout=timeout)
            backend_type = ProcessGroup.BackendType.UCC
        else:
            assert backend_str.upper() in Backend._plugins, (
                f"Unknown c10d backend type {backend_str.upper()}"
            )

            backend_plugin = Backend._plugins[backend_str.upper()]
            creator_fn = backend_plugin.creator_fn
            extended_api = backend_plugin.extended_api
            backend_type = ProcessGroup.BackendType.CUSTOM

            if not extended_api:
                backend_class = creator_fn(backend_prefix_store, group_rank, group_size, timeout)
            else:
                dist_backend_opts = _DistributedBackendOptions()
                dist_backend_opts.store = backend_prefix_store
                dist_backend_opts.group_rank = group_rank
                dist_backend_opts.group_size = group_size
                dist_backend_opts.timeout = timeout
                dist_backend_opts.group_id = group_name
                dist_backend_opts.global_ranks_in_group = global_ranks_in_group

                backend_class = creator_fn(dist_backend_opts, pg_options)

        # Set sequence numbers for gloo and nccl backends.
        if backend_str in [Backend.GLOO, Backend.NCCL]:
            backend_class._set_sequence_number_for_group()

        # If the type is a subclass of ProcessGroup then return this process group immediately
        # TODO: This defaults to the old behavior for PythonProcessGroups which overwrites the
        # ProcessGroup instance
        if issubclass(type(backend_class), ProcessGroup):
            pg = backend_class
            break

        # Process group wrapper initialization for supported PGs when TORCH_DISTRIBUTED_DEBUG is set
        if backend_str in [Backend.GLOO, Backend.NCCL, Backend.UCC]:
            # In debug mode and if GLOO is available, wrap in a wrapper PG that
            # enables enhanced collective checking for debuggability.
            if get_debug_level() == DebugLevel.DETAIL:
                if not _GLOO_AVAILABLE:
                    logger.info(
                        """TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but
                        GLOO is not available. Build with Gloo to
                        create a wrapper process group in debug mode
                        to aid collective desynchronization debugging."""
                    )
                else:
                    backend_class = _create_process_group_wrapper(
                        wrapped_pg=backend_class,
                        store_prefix=group_name,
                        store=backend_prefix_store,
                        rank=group_rank,
                        world_size=group_size,
                        timeout=timeout,
                    )

        # only create single backend pg when backend is set to gloo, nccl, mpi, and ucc
        if backend in [Backend.GLOO, Backend.NCCL, Backend.UCC, Backend.MPI]:
            for device in backend_config.get_device_backend_map().keys():
                pg._register_backend(torch.device(device), backend_type, backend_class)

            # break out of outer loop to not create any more backends
            break
        else:
            pg._register_backend(torch.device(device), backend_type, backend_class)

    # update global state
    _world.pg_map[pg] = (backend, prefix_store)
    _world.pg_names[pg] = group_name
    _world.pg_backend_config[pg] = str(backend_config)

    return pg

def destroy_process_group(group: Optional[ProcessGroup] = None):
    """
    Destroy a given process group, and deinitialize the distributed package.

    Args:
        group (ProcessGroup, optional): The process group to be destroyed, if
                                        group.WORLD is given, all process
                                        groups including the default one will
                                        be destroyed.
    """
    global _world

    if group == GroupMember.NON_GROUP_MEMBER:
        return

    if group is None:
        pg = GroupMember.WORLD
    else:
        pg = group

    assert pg is not None
    if _world.pg_map.get(pg, None) is None:
        raise RuntimeError("Invalid process group specified")

    if group is None or group == GroupMember.WORLD:
        _update_default_pg(None)
        _world.pg_map.clear()
        _world.pg_names.clear()
        _world.pg_group_ranks.clear()
        _world.pg_backend_config.clear()
        # when process group doesn't have an explicit name (only WORLD (default)
        # process group can have an explicit name), we use global _world.group_count
        # to generate the name. We need to reset the counter on destruction to
        # allow consistent value to be generated when we re-create process
        # groups after some trainers recover from failure
        #
        # We only reset this when WORLD is being destroyed because if this
        # process group is in good state, we aren't dealing with failures.
        _world.group_count = 0
    else:
        del _world.pg_map[pg]
        del _world.pg_names[pg]
        del _world.pg_group_ranks[pg]
        del _world.pg_backend_config[pg]
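
# A hedged lifecycle sketch: destroying a subgroup only removes that group's
# state, while destroying with no argument (or with the default group) clears
# every registered group. Assumes a job that was already initialized as above.
#
#     >>> import torch.distributed as dist
#     >>> subgroup = dist.new_group(ranks=[0, 1])
#     >>> dist.destroy_process_group(subgroup)   # only the subgroup is removed
#     >>> dist.destroy_process_group()           # default group and all remaining state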
  952. def get_rank(group: Optional[ProcessGroup] = None) -> int:
  953. """
  954. Returns the rank of the current process in the provided ``group`` or the
  955. default group if none was provided.
  956. Rank is a unique identifier assigned to each process within a distributed
  957. process group. They are always consecutive integers ranging from 0 to
  958. ``world_size``.
  959. Args:
  960. group (ProcessGroup, optional): The process group to work on. If None,
  961. the default process group will be used.
  962. Returns:
  963. The rank of the process group
  964. -1, if not part of the group
  965. """
  966. if _rank_not_in_group(group):
  967. return -1
  968. default_pg = _get_default_group()
  969. if group is None or group is GroupMember.WORLD:
  970. return default_pg.rank()
  971. return get_group_rank(group, default_pg.rank())
  972. def get_world_size(group: Optional[ProcessGroup] = None) -> int:
  973. """
  974. Returns the number of processes in the current process group
  975. Args:
  976. group (ProcessGroup, optional): The process group to work on. If None,
  977. the default process group will be used.
  978. Returns:
  979. The world size of the process group
  980. -1, if not part of the group
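Example (a minimal sketch; assumes a 2-rank default group)::
>>> # xdoctest: +SKIP("need process group init")
>>> dist.get_world_size()          # 2 for the default group in this sketch
>>> subgroup = dist.new_group([0])
>>> dist.get_world_size(subgroup)  # 1 on rank 0; -1 on rank 1, which is not a member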
  981. """
  982. if _rank_not_in_group(group):
  983. return -1
  984. return _get_group_size(group)
  985. def isend(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0) -> Work:
  986. """
  987. Sends a tensor asynchronously.
  988. .. warning::
  989. Modifying ``tensor`` before the request completes causes undefined
  990. behavior.
  991. Args:
  992. tensor (Tensor): Tensor to send.
  993. dst (int): Destination rank.
  994. group (ProcessGroup, optional): The process group to work on. If None,
  995. the default process group will be used.
  996. tag (int, optional): Tag to match send with remote recv
  997. Returns:
  998. A distributed request object.
  999. None, if not part of the group
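Example (a minimal sketch; assumes a 2-rank default group is already
initialized, with rank 0 sending to rank 1)::
>>> # xdoctest: +SKIP("need process group init")
>>> tensor = torch.zeros(2)
>>> if dist.get_rank() == 0:
>>>     tensor += 1
>>>     req = dist.isend(tensor, dst=1)
>>> else:
>>>     req = dist.irecv(tensor, src=0)
>>> req.wait()
>>> tensor
tensor([1., 1.])  # on both ranks once the request completes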
  1000. """
  1001. _check_single_tensor(tensor, "tensor")
  1002. if _rank_not_in_group(group):
  1003. _warn_not_in_group("isend")
  1004. return
  1005. if group is None or group is GroupMember.WORLD:
  1006. default_pg = _get_default_group()
  1007. return default_pg.send([tensor], dst, tag)
  1008. else:
  1009. group_dst_rank = get_group_rank(group, dst)
  1010. return group.send([tensor], group_dst_rank, tag)
  1011. def irecv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: int = 0) -> Work:
  1012. """
  1013. Receives a tensor asynchronously.
  1014. Args:
  1015. tensor (Tensor): Tensor to fill with received data.
  1016. src (int, optional): Source rank. Will receive from any
  1017. process if unspecified.
  1018. group (ProcessGroup, optional): The process group to work on. If None,
  1019. the default process group will be used.
  1020. tag (int, optional): Tag to match recv with remote send
  1021. Returns:
  1022. A distributed request object.
  1023. None, if not part of the group
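Example (a minimal sketch; assumes a 2-rank default group where rank 1
receives from an unspecified source)::
>>> # xdoctest: +SKIP("need process group init")
>>> tensor = torch.zeros(2)
>>> if dist.get_rank() == 0:
>>>     dist.isend(tensor + 1, dst=1).wait()
>>> else:
>>>     req = dist.irecv(tensor)  # src=None: accept a message from any rank
>>>     req.wait()                # tensor now holds [1., 1.] on rank 1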
  1024. """
  1025. _check_single_tensor(tensor, "tensor")
  1026. if _rank_not_in_group(group):
  1027. _warn_not_in_group("irecv")
  1028. return
  1029. if group is None or group is GroupMember.WORLD:
  1030. pg = _get_default_group()
  1031. else:
  1032. pg = group
  1033. if src is None:
  1034. return pg.recv_anysource([tensor], tag)
  1035. else:
  1036. if pg is GroupMember.WORLD:
  1037. return pg.recv([tensor], src, tag)
  1038. else:
  1039. group_src_rank = get_group_rank(pg, src)
  1040. return pg.recv([tensor], group_src_rank, tag)
  1041. def send(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0) -> None:
  1042. """
  1043. Sends a tensor synchronously.
  1044. Args:
  1045. tensor (Tensor): Tensor to send.
  1046. dst (int): Destination rank. Destination rank should not be the same
  1047. as the rank of the current process.
  1048. group (ProcessGroup, optional): The process group to work on. If None,
  1049. the default process group will be used.
  1050. tag (int, optional): Tag to match send with remote recv
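Example (a minimal sketch; assumes a 2-rank default group)::
>>> # xdoctest: +SKIP("need process group init")
>>> tensor = torch.arange(2)
>>> if dist.get_rank() == 0:
>>>     dist.send(tensor + 1, dst=1)  # blocks until the send completes
>>> else:
>>>     dist.recv(tensor, src=0)      # blocks until the data arrives
>>>     tensor                        # tensor([1, 2]) on rank 1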
  1051. """
  1052. if get_rank() == dst:
  1053. raise ValueError(
  1054. "Invalid destination rank: destination rank should not be the same as "
  1055. "the rank of the current process."
  1056. )
  1057. _check_single_tensor(tensor, "tensor")
  1058. if _rank_not_in_group(group):
  1059. _warn_not_in_group("send")
  1060. return
  1061. if group is None or group is GroupMember.WORLD:
  1062. default_pg = _get_default_group()
  1063. default_pg.send([tensor], dst, tag).wait()
  1064. else:
  1065. group_dst_rank = get_group_rank(group, dst)
  1066. group.send([tensor], group_dst_rank, tag).wait()
  1067. def recv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: int = 0) -> int:
  1068. """
  1069. Receives a tensor synchronously.
  1070. Args:
  1071. tensor (Tensor): Tensor to fill with received data.
  1072. src (int, optional): Source rank. Will receive from any
  1073. process if unspecified.
  1074. group (ProcessGroup, optional): The process group to work on. If None,
  1075. the default process group will be used.
  1076. tag (int, optional): Tag to match recv with remote send
  1077. Returns:
  1078. Sender rank
  1079. -1, if not part of the group
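Example (a minimal sketch; assumes a 2-rank default group where the
receiver does not specify a source)::
>>> # xdoctest: +SKIP("need process group init")
>>> tensor = torch.zeros(2)
>>> if dist.get_rank() == 0:
>>>     dist.send(tensor + 1, dst=1)
>>> else:
>>>     sender = dist.recv(tensor)  # src=None: accept any sender
>>>     sender                      # 0, the rank the data came from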
  1080. """
  1081. _check_single_tensor(tensor, "tensor")
  1082. if _rank_not_in_group(group):
  1083. _warn_not_in_group("recv")
  1084. return -1
  1085. if group is None:
  1086. pg = _get_default_group()
  1087. else:
  1088. pg = group
  1089. if src is None:
  1090. work = pg.recv_anysource([tensor], tag)
  1091. work.wait()
  1092. src_rank = work._source_rank()
  1093. if group is None or group is GroupMember.WORLD:
  1094. return src_rank
  1095. else:
  1096. return get_global_rank(pg, src_rank)
  1097. else:
  1098. if group is None or group is GroupMember.WORLD:
  1099. pg.recv([tensor], src, tag).wait()
  1100. else:
  1101. group_src_rank = get_group_rank(pg, src)
  1102. pg.recv([tensor], group_src_rank, tag).wait()
  1103. return src
  1104. class P2POp:
  1105. """
  1106. A class to build point-to-point operations for ``batch_isend_irecv``.
This class builds the type of P2P operation, communication buffer, peer rank,
process group, and tag. Instances of this class will be passed to
``batch_isend_irecv`` for point-to-point communications.
  1110. Args:
  1111. op (Callable): A function to send data to or receive data from a peer process.
  1112. The type of ``op`` is either ``torch.distributed.isend`` or
  1113. ``torch.distributed.irecv``.
  1114. tensor (Tensor): Tensor to send or receive.
  1115. peer (int): Destination or source rank.
  1116. group (ProcessGroup, optional): The process group to work on. If None,
  1117. the default process group will be used.
  1118. tag (int, optional): Tag to match send with recv.
  1119. """
  1120. def __init__(self, op, tensor, peer, group=None, tag=0):
  1121. self.op = op
  1122. self.tensor = tensor
  1123. self.peer = peer
  1124. self.group = group
  1125. self.tag = tag
  1126. def __new__(cls, op, tensor, peer, group=None, tag=0):
  1127. _check_op(op)
  1128. _check_single_tensor(tensor, "tensor")
  1129. return object.__new__(cls)
  1130. @contextlib.contextmanager
  1131. def _coalescing_manager(group, device, reqs):
  1132. if group is None:
  1133. group = _get_default_group()
  1134. group._start_coalescing(device)
  1135. try:
  1136. yield
  1137. finally:
  1138. group._end_coalescing(device, reqs)
  1139. def batch_isend_irecv(p2p_op_list):
  1140. """
  1141. Send or Receive a batch of tensors asynchronously and return a list of requests.
  1142. Process each of the operations in ``p2p_op_list`` and return the corresponding
  1143. requests. NCCL, Gloo, and UCC backend are currently supported.
  1144. Args:
p2p_op_list: A list of point-to-point operations (the type of each operator is
``torch.distributed.P2POp``). The order of the isend/irecv in the list
matters and it needs to match the corresponding isend/irecv on the
remote end.
  1149. Returns:
  1150. A list of distributed request objects returned by calling the corresponding
  1151. op in the op_list.
  1152. Examples:
  1153. >>> # xdoctest: +SKIP("no rank")
  1154. >>> send_tensor = torch.arange(2) + 2 * rank
  1155. >>> recv_tensor = torch.randn(2)
  1156. >>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1)%world_size)
  1157. >>> recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank - 1 + world_size)%world_size)
  1158. >>> reqs = batch_isend_irecv([send_op, recv_op])
  1159. >>> for req in reqs:
  1160. >>> req.wait()
  1161. >>> recv_tensor
  1162. tensor([2, 3]) # Rank 0
  1163. tensor([0, 1]) # Rank 1
  1164. .. note:: Note that when this API is used with the NCCL PG backend, users must set
  1165. the current GPU device with `torch.cuda.set_device`, otherwise it will
  1166. lead to unexpected hang issues.
  1167. In addition, if this API is the first collective call in the ``group``
  1168. passed to ``dist.P2POp``, all ranks of the ``group`` must participate in
  1169. this API call; otherwise, the behavior is undefined. If this API call is
  1170. not the first collective call in the ``group``, batched P2P operations
  1171. involving only a subset of ranks of the ``group`` are allowed.
  1172. """
  1173. _check_p2p_op_list(p2p_op_list)
  1174. group = p2p_op_list[0].group
  1175. device = p2p_op_list[0].tensor.device
  1176. reqs = []
  1177. with _coalescing_manager(group, device, reqs):
  1178. for p2p_op in p2p_op_list:
  1179. op = p2p_op.op
  1180. tensor = p2p_op.tensor
  1181. peer = p2p_op.peer
  1182. curr_group = p2p_op.group
  1183. tag = p2p_op.tag
  1184. ret = op(tensor, peer, curr_group, tag)
  1185. if ret is not None:
  1186. reqs.append(ret)
  1187. return reqs
  1188. def exception_handler(func):
  1189. @functools.wraps(func)
  1190. def wrapper(*args, **kwargs):
  1191. try:
  1192. return func(*args, **kwargs)
  1193. except Exception as error:
  1194. if is_initialized():
  1195. error_msg_dict = {
  1196. "func_name": f"{func.__name__}",
  1197. "args": f"{args}, {kwargs}",
  1198. "backend": f"{get_backend(kwargs.get('group'))}",
  1199. "world_size": f"{get_world_size(kwargs.get('group'))}",
  1200. "global_rank": f"{get_rank()}",
  1201. "local_rank": f"{get_rank(kwargs.get('group'))}",
  1202. "error": f"{error}",
  1203. }
  1204. else:
  1205. error_msg_dict = {
  1206. "func_name": f"{func.__name__}",
  1207. "args": f"{args}, {kwargs}",
  1208. "error": f"{error}",
  1209. }
  1210. _c10d_error_logger.debug(error_msg_dict)
  1211. raise
  1212. return wrapper
  1213. @exception_handler
  1214. def broadcast_multigpu(tensor_list, src, group=None, async_op=False, src_tensor=0):
  1215. """
  1216. Broadcasts the tensor to the whole group with multiple GPU tensors
  1217. per node.
``tensor`` must have the same number of elements in all the GPUs from
all processes participating in the collective. Each tensor in the list must
be on a different GPU.
Only the nccl and gloo backends are currently supported;
tensors should only be GPU tensors.
  1223. Args:
  1224. tensor_list (List[Tensor]): Tensors that participate in the collective
  1225. operation. If ``src`` is the rank, then the specified ``src_tensor``
  1226. element of ``tensor_list`` (``tensor_list[src_tensor]``) will be
  1227. broadcast to all other tensors (on different GPUs) in the src process
  1228. and all tensors in ``tensor_list`` of other non-src processes.
  1229. You also need to make sure that ``len(tensor_list)`` is the same
  1230. for all the distributed processes calling this function.
  1231. src (int): Source rank.
  1232. group (ProcessGroup, optional): The process group to work on. If None,
  1233. the default process group will be used.
  1234. async_op (bool, optional): Whether this op should be an async op
  1235. src_tensor (int, optional): Source tensor rank within ``tensor_list``
  1236. Returns:
  1237. Async work handle, if async_op is set to True.
  1238. None, if not async_op or if not part of the group
  1239. """
  1240. warnings.warn(
  1241. "torch.distributed.broadcast_multigpu will be deprecated. If you must "
  1242. "use it, please revisit our documentation later at "
  1243. "https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions"
  1244. )
  1245. if _rank_not_in_group(group):
  1246. _warn_not_in_group("broadcast_multigpu")
  1247. return
  1248. opts = BroadcastOptions()
  1249. opts.rootRank = src
  1250. opts.rootTensor = src_tensor
  1251. if group is None or group is GroupMember.WORLD:
  1252. default_pg = _get_default_group()
  1253. work = default_pg.broadcast(tensor_list, opts)
  1254. else:
  1255. group_src_rank = get_group_rank(group, src)
  1256. opts.rootRank = group_src_rank
  1257. work = group.broadcast(tensor_list, opts)
  1258. if async_op:
  1259. return work
  1260. else:
  1261. work.wait()
  1262. @exception_handler
  1263. def broadcast(tensor, src, group=None, async_op=False):
  1264. """
  1265. Broadcasts the tensor to the whole group.
  1266. ``tensor`` must have the same number of elements in all processes
  1267. participating in the collective.
  1268. Args:
  1269. tensor (Tensor): Data to be sent if ``src`` is the rank of current
  1270. process, and tensor to be used to save received data otherwise.
  1271. src (int): Source rank.
  1272. group (ProcessGroup, optional): The process group to work on. If None,
  1273. the default process group will be used.
  1274. async_op (bool, optional): Whether this op should be an async op
  1275. Returns:
  1276. Async work handle, if async_op is set to True.
  1277. None, if not async_op or if not part of the group
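Example (a minimal sketch; assumes a 2-rank default group)::
>>> # xdoctest: +SKIP("need process group init")
>>> if dist.get_rank() == 0:
>>>     tensor = torch.tensor([1, 2])
>>> else:
>>>     tensor = torch.zeros(2, dtype=torch.int64)
>>> dist.broadcast(tensor, src=0)
>>> tensor
tensor([1, 2])  # on every rank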
  1278. """
  1279. _check_single_tensor(tensor, "tensor")
  1280. if _rank_not_in_group(group):
  1281. _warn_not_in_group("broadcast")
  1282. return
  1283. opts = BroadcastOptions()
  1284. opts.rootRank = src
  1285. opts.rootTensor = 0
  1286. if group is None or group is GroupMember.WORLD:
  1287. default_pg = _get_default_group()
  1288. work = default_pg.broadcast([tensor], opts)
  1289. else:
  1290. group_src_rank = get_group_rank(group, src)
  1291. opts.rootRank = group_src_rank
  1292. work = group.broadcast([tensor], opts)
  1293. if async_op:
  1294. return work
  1295. else:
  1296. work.wait()
  1297. @exception_handler
  1298. def all_reduce_multigpu(tensor_list, op=ReduceOp.SUM, group=None, async_op=False):
  1299. r"""
  1300. Reduces the tensor data across all machines in such a way that all get
  1301. the final result. This function reduces a number of tensors on every node,
  1302. while each tensor resides on different GPUs.
  1303. Therefore, the input tensor in the tensor list needs to be GPU tensors.
  1304. Also, each tensor in the tensor list needs to reside on a different GPU.
After the call, each ``tensor`` in ``tensor_list`` is going to be bitwise
identical in all processes.
Complex tensors are supported.
Only the nccl and gloo backends are currently supported;
tensors should only be GPU tensors.
  1310. Args:
  1311. tensor_list (List[Tensor]): List of input and output tensors of
  1312. the collective. The function operates in-place and requires that
  1313. each tensor to be a GPU tensor on different GPUs.
  1314. You also need to make sure that ``len(tensor_list)`` is the same for
  1315. all the distributed processes calling this function.
  1316. op (optional): One of the values from
  1317. ``torch.distributed.ReduceOp``
  1318. enum. Specifies an operation used for element-wise reductions.
  1319. group (ProcessGroup, optional): The process group to work on. If
  1320. ``None``, the default process group will be used.
  1321. async_op (bool, optional): Whether this op should be an async op
  1322. Returns:
  1323. Async work handle, if async_op is set to True.
  1324. None, if not async_op or if not part of the group
  1325. """
  1326. warnings.warn(
  1327. "torch.distributed.all_reduce_multigpu will be deprecated. If you must "
  1328. "use it, please revisit our documentation later at "
  1329. "https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions"
  1330. )
  1331. if _rank_not_in_group(group):
  1332. return
  1333. tensor_list = [
  1334. t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list
  1335. ]
  1336. opts = AllreduceOptions()
  1337. opts.reduceOp = op
  1338. if group is None:
  1339. default_pg = _get_default_group()
  1340. work = default_pg.allreduce(tensor_list, opts)
  1341. else:
  1342. work = group.allreduce(tensor_list, opts)
  1343. if async_op:
  1344. return work
  1345. else:
  1346. work.wait()
  1347. @exception_handler
  1348. def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
  1349. """
  1350. Reduces the tensor data across all machines in such a way that all get
  1351. the final result.
  1352. After the call ``tensor`` is going to be bitwise identical in all processes.
  1353. Complex tensors are supported.
  1354. Args:
  1355. tensor (Tensor): Input and output of the collective. The function
  1356. operates in-place.
  1357. op (optional): One of the values from
  1358. ``torch.distributed.ReduceOp``
  1359. enum. Specifies an operation used for element-wise reductions.
  1360. group (ProcessGroup, optional): The process group to work on. If None,
  1361. the default process group will be used.
  1362. async_op (bool, optional): Whether this op should be an async op
  1363. Returns:
  1364. Async work handle, if async_op is set to True.
  1365. None, if not async_op or if not part of the group
  1366. Examples:
  1367. >>> # xdoctest: +SKIP("no rank")
  1368. >>> # All tensors below are of torch.int64 type.
  1369. >>> # We have 2 process groups, 2 ranks.
  1370. >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
  1371. >>> tensor
  1372. tensor([1, 2]) # Rank 0
  1373. tensor([3, 4]) # Rank 1
  1374. >>> dist.all_reduce(tensor, op=ReduceOp.SUM)
  1375. >>> tensor
  1376. tensor([4, 6]) # Rank 0
  1377. tensor([4, 6]) # Rank 1
  1378. >>> # All tensors below are of torch.cfloat type.
  1379. >>> # We have 2 process groups, 2 ranks.
  1380. >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
  1381. >>> tensor
  1382. tensor([1.+1.j, 2.+2.j]) # Rank 0
  1383. tensor([3.+3.j, 4.+4.j]) # Rank 1
  1384. >>> dist.all_reduce(tensor, op=ReduceOp.SUM)
  1385. >>> tensor
  1386. tensor([4.+4.j, 6.+6.j]) # Rank 0
  1387. tensor([4.+4.j, 6.+6.j]) # Rank 1
  1388. """
  1389. _check_single_tensor(tensor, "tensor")
  1390. if _rank_not_in_group(group):
  1391. _warn_not_in_group("all_reduce")
  1392. return
  1393. if tensor.is_complex():
  1394. if not supports_complex(op):
  1395. raise RuntimeError(f"all_reduce does not support {op} on complex tensors")
  1396. tensor = torch.view_as_real(tensor)
  1397. opts = AllreduceOptions()
  1398. opts.reduceOp = op
  1399. if group is None:
  1400. default_pg = _get_default_group()
  1401. work = default_pg.allreduce([tensor], opts)
  1402. else:
  1403. work = group.allreduce([tensor], opts)
  1404. if async_op:
  1405. return work
  1406. else:
  1407. work.wait()
  1408. @exception_handler
  1409. def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
  1410. """
  1411. WARNING: at this time individual shape checking is not implemented across nodes.
  1412. For example, if the rank 0 node passes [torch.rand(4), torch.rand(2)] and the
  1413. rank 1 node passes [torch.rand(2), torch.rand(2), torch.rand(2)], the allreduce
  1414. operation will proceed without complaint and return erroneous outputs. This lack
  1415. of shape checking results in significant performance improvements but users of this
  1416. function should take extra care to ensure that each node passes in tensors whose
  1417. shapes match across nodes.
  1418. Reduces each tensor in tensors (residing on the same device) across all machines
  1419. in such a way that all get the final result.
After the call, each tensor in ``tensors`` is going to be bitwise identical
in all processes.
  1422. Complex tensors are supported.
  1423. Args:
  1424. tensors (List[Tensor]): Input and output of the collective. The function
  1425. operates in-place.
  1426. op (Optional[ReduceOp]): One of the values from
  1427. ``torch.distributed.ReduceOp`` enum. Specifies an operation used for
  1428. element-wise reductions.
  1429. group (ProcessGroup, optional): The process group to work on. If None,
  1430. the default process group will be used.
  1431. async_op (Optional[bool]): Whether this op should be an async op.
  1432. Returns:
  1433. Async work handle, if async_op is set to True.
  1434. None, if not async_op or if not part of the group.
  1435. """
  1436. warnings.warn(
  1437. "torch.distributed.all_reduce_coalesced will be deprecated. If you must "
  1438. "use it, please revisit our documentation later at "
  1439. "https://pytorch.org/docs/master/distributed.html#collective-functions"
  1440. )
  1441. _check_tensor_list(tensors, "tensor")
  1442. _ensure_all_tensors_same_dtype(tensors)
  1443. if _rank_not_in_group(group):
  1444. _warn_not_in_group("all_reduce_coalesced")
  1445. return
  1446. if any([t.is_complex() for t in tensors]) and not supports_complex(op):
  1447. raise RuntimeError(f"all_reduce does not support {op} on complex tensors")
  1448. tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors]
  1449. opts = AllreduceCoalescedOptions()
  1450. opts.reduceOp = op
  1451. if group is None:
  1452. default_pg = _get_default_group()
  1453. work = default_pg.allreduce_coalesced(tensors, opts)
  1454. else:
  1455. work = group.allreduce_coalesced(tensors, opts)
  1456. if async_op:
  1457. return work.get_future()
  1458. else:
  1459. work.wait()
  1460. @exception_handler
  1461. def reduce_multigpu(
  1462. tensor_list, dst, op=ReduceOp.SUM, group=None, async_op=False, dst_tensor=0
  1463. ):
  1464. """
  1465. Reduces the tensor data on multiple GPUs across all machines. Each tensor
  1466. in ``tensor_list`` should reside on a separate GPU
  1467. Only the GPU of ``tensor_list[dst_tensor]`` on the process with rank ``dst``
  1468. is going to receive the final result.
Only the nccl backend is currently supported;
tensors should only be GPU tensors.
  1471. Args:
  1472. tensor_list (List[Tensor]): Input and output GPU tensors of the
  1473. collective. The function operates in-place.
  1474. You also need to make sure that ``len(tensor_list)`` is the same for
  1475. all the distributed processes calling this function.
  1476. dst (int): Destination rank
  1477. op (optional): One of the values from
  1478. ``torch.distributed.ReduceOp``
  1479. enum. Specifies an operation used for element-wise reductions.
  1480. group (ProcessGroup, optional): The process group to work on. If None,
  1481. the default process group will be used.
  1482. async_op (bool, optional): Whether this op should be an async op
  1483. dst_tensor (int, optional): Destination tensor rank within
  1484. ``tensor_list``
  1485. Returns:
  1486. Async work handle, if async_op is set to True.
  1487. None, otherwise
  1488. """
  1489. warnings.warn(
  1490. "torch.distributed.reduce_multigpu will be deprecated. If you must "
  1491. "use it, please revisit our documentation later at "
  1492. "https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions"
  1493. )
  1494. if _rank_not_in_group(group):
  1495. _warn_not_in_group("reduce_multigpu")
  1496. return
  1497. opts = ReduceOptions()
  1498. opts.reduceOp = op
  1499. opts.rootRank = dst
  1500. opts.rootTensor = dst_tensor
  1501. if group is None or group is GroupMember.WORLD:
  1502. default_pg = _get_default_group()
  1503. work = default_pg.reduce(tensor_list, opts)
  1504. else:
  1505. group_dst_rank = get_group_rank(group, dst)
  1506. opts.rootRank = group_dst_rank
  1507. work = group.reduce(tensor_list, opts)
  1508. if async_op:
  1509. return work
  1510. else:
  1511. work.wait()
  1512. @exception_handler
  1513. def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
  1514. """
  1515. Reduces the tensor data across all machines.
  1516. Only the process with rank ``dst`` is going to receive the final result.
  1517. Args:
  1518. tensor (Tensor): Input and output of the collective. The function
  1519. operates in-place.
  1520. dst (int): Destination rank
  1521. op (optional): One of the values from
  1522. ``torch.distributed.ReduceOp``
  1523. enum. Specifies an operation used for element-wise reductions.
  1524. group (ProcessGroup, optional): The process group to work on. If None,
  1525. the default process group will be used.
  1526. async_op (bool, optional): Whether this op should be an async op
  1527. Returns:
  1528. Async work handle, if async_op is set to True.
  1529. None, if not async_op or if not part of the group
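Example (a minimal sketch; assumes a 2-rank default group)::
>>> # xdoctest: +SKIP("need process group init")
>>> tensor = torch.ones(2) * (dist.get_rank() + 1)
>>> dist.reduce(tensor, dst=0, op=ReduceOp.SUM)
>>> tensor
tensor([3., 3.])  # on rank 0; contents on the other ranks are unspecified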
  1530. """
  1531. _check_single_tensor(tensor, "tensor")
  1532. if _rank_not_in_group(group):
  1533. _warn_not_in_group("reduce")
  1534. return
  1535. opts = ReduceOptions()
  1536. opts.reduceOp = op
  1537. opts.rootRank = dst
  1538. if group is None or group is GroupMember.WORLD:
  1539. default_pg = _get_default_group()
  1540. work = default_pg.reduce([tensor], opts)
  1541. else:
  1542. group_dst_rank = get_group_rank(group, dst)
  1543. opts.rootRank = group_dst_rank
  1544. work = group.reduce([tensor], opts)
  1545. if async_op:
  1546. return work
  1547. else:
  1548. work.wait()
  1549. @exception_handler
  1550. def all_gather_multigpu(
  1551. output_tensor_lists, input_tensor_list, group=None, async_op=False
  1552. ):
  1553. """
  1554. Gathers tensors from the whole group in a list.
  1555. Each tensor in ``tensor_list`` should reside on a separate GPU
Only the nccl backend is currently supported;
tensors should only be GPU tensors.
  1558. Complex tensors are supported.
  1559. Args:
  1560. output_tensor_lists (List[List[Tensor]]): Output lists. It should
  1561. contain correctly-sized tensors on each GPU to be used for output
  1562. of the collective, e.g. ``output_tensor_lists[i]`` contains the
  1563. all_gather result that resides on the GPU of
  1564. ``input_tensor_list[i]``.
  1565. Note that each element of ``output_tensor_lists`` has the size of
  1566. ``world_size * len(input_tensor_list)``, since the function all
  1567. gathers the result from every single GPU in the group. To interpret
  1568. each element of ``output_tensor_lists[i]``, note that
``input_tensor_list[j]`` of rank k will appear in
  1570. ``output_tensor_lists[i][k * world_size + j]``
  1571. Also note that ``len(output_tensor_lists)``, and the size of each
  1572. element in ``output_tensor_lists`` (each element is a list,
  1573. therefore ``len(output_tensor_lists[i])``) need to be the same
  1574. for all the distributed processes calling this function.
  1575. input_tensor_list (List[Tensor]): List of tensors(on different GPUs) to
  1576. be broadcast from current process.
  1577. Note that ``len(input_tensor_list)`` needs to be the same for
  1578. all the distributed processes calling this function.
  1579. group (ProcessGroup, optional): The process group to work on. If None,
  1580. the default process group will be used.
  1581. async_op (bool, optional): Whether this op should be an async op
  1582. Returns:
  1583. Async work handle, if async_op is set to True.
  1584. None, if not async_op or if not part of the group
  1585. """
  1586. warnings.warn(
  1587. "torch.distributed.all_gather_multigpu will be deprecated. If you must "
  1588. "use it, please revisit our documentation later at "
  1589. "https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions"
  1590. )
  1591. if _rank_not_in_group(group):
  1592. _warn_not_in_group("all_gather_multigpu")
  1593. return
  1594. output_tensor_lists = [
  1595. [t if not t.is_complex() else torch.view_as_real(t) for t in l]
  1596. for l in output_tensor_lists
  1597. ]
  1598. input_tensor_list = [
  1599. t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list
  1600. ]
  1601. if group is None:
  1602. default_pg = _get_default_group()
  1603. work = default_pg.allgather(output_tensor_lists, input_tensor_list)
  1604. else:
  1605. work = group.allgather(output_tensor_lists, input_tensor_list)
  1606. if async_op:
  1607. return work
  1608. else:
  1609. work.wait()
  1610. def _object_to_tensor(obj, device):
  1611. f = io.BytesIO()
  1612. _pickler(f).dump(obj)
  1613. byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined]
  1614. # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
# Otherwise, it will cause a 100X slowdown.
  1616. # See: https://github.com/pytorch/pytorch/issues/65696
  1617. byte_tensor = torch.ByteTensor(byte_storage).to(device)
  1618. local_size = torch.LongTensor([byte_tensor.numel()]).to(device)
  1619. return byte_tensor, local_size
  1620. def _tensor_to_object(tensor, tensor_size):
  1621. tensor = tensor.cpu()
  1622. buf = tensor.numpy().tobytes()[:tensor_size]
  1623. return _unpickler(io.BytesIO(buf)).load()
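# Illustrative round-trip (a sketch only; these are private helpers and the
# real call sites are the object collectives below):
#   data, size = _object_to_tensor({"a": 1}, torch.device("cpu"))
#   restored = _tensor_to_object(data, int(size.item()))
#   # restored == {"a": 1}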
  1624. def _check_for_nccl_backend(group):
  1625. pg = group or _get_default_group()
  1626. # Gate PG wrapper check on Gloo availability.
  1627. if _GLOO_AVAILABLE:
  1628. # It is not expected for PG to be wrapped many times, but support it just
  1629. # in case
  1630. while isinstance(pg, _ProcessGroupWrapper):
  1631. pg = pg.wrapped_pg
  1632. return (
  1633. is_nccl_available() and
  1634. pg.name() == Backend.NCCL
  1635. )
  1636. @exception_handler
  1637. def all_gather_object(object_list, obj, group=None):
  1638. """
  1639. Gathers picklable objects from the whole group into a list. Similar to
  1640. :func:`all_gather`, but Python objects can be passed in. Note that the object
  1641. must be picklable in order to be gathered.
  1642. Args:
  1643. object_list (list[Any]): Output list. It should be correctly sized as the
  1644. size of the group for this collective and will contain the output.
obj (Any): Picklable Python object to be broadcast from current process.
  1646. group (ProcessGroup, optional): The process group to work on. If None,
  1647. the default process group will be used. Default is ``None``.
  1648. Returns:
  1649. None. If the calling rank is part of this group, the output of the
  1650. collective will be populated into the input ``object_list``. If the
  1651. calling rank is not part of the group, the passed in ``object_list`` will
  1652. be unmodified.
  1653. .. note:: Note that this API differs slightly from the :func:`all_gather`
  1654. collective since it does not provide an ``async_op`` handle and thus
  1655. will be a blocking call.
.. note:: For NCCL-based process groups, internal tensor representations
of objects must be moved to the GPU device before communication takes
place. In this case, the device used is given by
``torch.cuda.current_device()`` and it is the user's responsibility to
ensure that this is set so that each rank has an individual GPU, via
``torch.cuda.set_device()``.
  1662. .. warning::
  1663. :func:`all_gather_object` uses ``pickle`` module implicitly, which is
  1664. known to be insecure. It is possible to construct malicious pickle data
  1665. which will execute arbitrary code during unpickling. Only call this
  1666. function with data you trust.
  1667. Example::
  1668. >>> # xdoctest: +SKIP("need process group init")
  1669. >>> # Note: Process group initialization omitted on each rank.
  1670. >>> import torch.distributed as dist
  1671. >>> # Assumes world_size of 3.
  1672. >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
  1673. >>> output = [None for _ in gather_objects]
  1674. >>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
  1675. >>> output
  1676. ['foo', 12, {1: 2}]
  1677. """
  1678. if _rank_not_in_group(group):
  1679. _warn_not_in_group("all_gather_object")
  1680. return
  1681. current_device = _get_pg_device(group)
  1682. input_tensor, local_size = _object_to_tensor(obj, current_device)
  1683. # Gather all local sizes. This is so that we can find the max size, and index
  1684. # until the correct size when deserializing the tensors.
  1685. group_size = get_world_size(group=group)
  1686. object_sizes_tensor = torch.zeros(
  1687. group_size, dtype=torch.long, device=current_device
  1688. )
  1689. object_size_list = [
  1690. object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
  1691. ]
  1692. # Allgather tensor sizes
  1693. all_gather(object_size_list, local_size, group=group)
  1694. max_object_size = int(max(object_size_list).item()) # type: ignore[type-var]
  1695. # Resize tensor to max size across all ranks.
  1696. input_tensor.resize_(max_object_size)
  1697. coalesced_output_tensor = torch.empty(
  1698. max_object_size * group_size, dtype=torch.uint8, device=current_device
  1699. )
  1700. # Output tensors are nonoverlapping views of coalesced_output_tensor
  1701. output_tensors = [
  1702. coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
  1703. for i in range(group_size)
  1704. ]
  1705. all_gather(output_tensors, input_tensor, group=group)
  1706. # Deserialize outputs back to object.
  1707. for i, tensor in enumerate(output_tensors):
  1708. tensor = tensor.type(torch.uint8)
  1709. if tensor.device != torch.device("cpu"):
  1710. tensor = tensor.cpu()
  1711. tensor_size = object_size_list[i]
  1712. object_list[i] = _tensor_to_object(tensor, tensor_size)
  1713. @exception_handler
  1714. def gather_object(obj, object_gather_list=None, dst=0, group=None):
  1715. """
  1716. Gathers picklable objects from the whole group in a single process.
  1717. Similar to :func:`gather`, but Python objects can be passed in. Note that the
  1718. object must be picklable in order to be gathered.
  1719. Args:
  1720. obj (Any): Input object. Must be picklable.
  1721. object_gather_list (list[Any]): Output list. On the ``dst`` rank, it
  1722. should be correctly sized as the size of the group for this
  1723. collective and will contain the output. Must be ``None`` on non-dst
  1724. ranks. (default is ``None``)
  1725. dst (int, optional): Destination rank. (default is 0)
  1726. group: (ProcessGroup, optional): The process group to work on. If None,
  1727. the default process group will be used. Default is ``None``.
  1728. Returns:
  1729. None. On the ``dst`` rank, ``object_gather_list`` will contain the
  1730. output of the collective.
  1731. .. note:: Note that this API differs slightly from the gather collective
  1732. since it does not provide an async_op handle and thus will be a blocking
  1733. call.
.. note:: For NCCL-based process groups, internal tensor representations
of objects must be moved to the GPU device before communication takes
place. In this case, the device used is given by
``torch.cuda.current_device()`` and it is the user's responsibility to
ensure that this is set so that each rank has an individual GPU, via
``torch.cuda.set_device()``.
  1740. .. warning::
  1741. :func:`gather_object` uses ``pickle`` module implicitly, which is
  1742. known to be insecure. It is possible to construct malicious pickle data
  1743. which will execute arbitrary code during unpickling. Only call this
  1744. function with data you trust.
  1745. Example::
  1746. >>> # xdoctest: +SKIP("need process group init")
  1747. >>> # Note: Process group initialization omitted on each rank.
  1748. >>> import torch.distributed as dist
  1749. >>> # Assumes world_size of 3.
  1750. >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
  1751. >>> output = [None for _ in gather_objects]
  1752. >>> dist.gather_object(
  1753. ... gather_objects[dist.get_rank()],
  1754. ... output if dist.get_rank() == 0 else None,
  1755. ... dst=0
  1756. ... )
  1757. >>> # On rank 0
  1758. >>> output
  1759. ['foo', 12, {1: 2}]
  1760. """
  1761. if _rank_not_in_group(group):
  1762. _warn_not_in_group("gather_object")
  1763. return
# Ensure object_gather_list is specified appropriately.
  1765. my_rank = get_rank()
  1766. _validate_output_list_for_rank(my_rank, dst, object_gather_list)
  1767. current_device = _get_pg_device(group)
  1768. input_tensor, local_size = _object_to_tensor(obj, current_device)
  1769. # Gather all local sizes. This is so that we can find the max size, and index
  1770. # until the correct size when deserializing the tensors.
  1771. group_size = get_world_size(group=group)
  1772. object_sizes_tensor = torch.zeros(
  1773. group_size, dtype=torch.long, device=current_device
  1774. )
  1775. object_size_list = [
  1776. object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
  1777. ]
  1778. # Allgather tensor sizes. An all-gather is needed here despite this being a
  1779. # gather, since each rank needs to broadcast a tensor of the same (maximal)
  1780. # size.
  1781. all_gather(object_size_list, local_size, group=group)
  1782. max_object_size = int(max(object_size_list).item()) # type: ignore[type-var]
  1783. # Resize tensor to max size across all ranks.
  1784. input_tensor.resize_(max_object_size)
  1785. # Avoid populating output tensors if the result won't be gathered on this rank.
  1786. if my_rank == dst:
  1787. coalesced_output_tensor = torch.empty(
  1788. max_object_size * group_size, dtype=torch.uint8, device=current_device
  1789. )
  1790. # Output tensors are nonoverlapping views of coalesced_output_tensor
  1791. output_tensors = [
  1792. coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
  1793. for i in range(group_size)
  1794. ]
  1795. # All ranks call gather with equal-sized tensors.
  1796. gather(
  1797. input_tensor,
  1798. gather_list=output_tensors if my_rank == dst else None,
  1799. dst=dst,
  1800. group=group,
  1801. )
  1802. if my_rank != dst:
  1803. return
  1804. for i, tensor in enumerate(output_tensors):
  1805. tensor = tensor.type(torch.uint8)
  1806. tensor_size = object_size_list[i]
  1807. object_gather_list[i] = _tensor_to_object(tensor, tensor_size)
  1808. @exception_handler
  1809. def broadcast_object_list(object_list, src=0, group=None, device=None):
  1810. """
  1811. Broadcasts picklable objects in ``object_list`` to the whole group. Similar
  1812. to :func:`broadcast`, but Python objects can be passed in.
  1813. Note that all objects in ``object_list`` must be picklable in order to be
  1814. broadcasted.
  1815. Args:
  1816. object_list (List[Any]): List of input objects to broadcast.
  1817. Each object must be picklable. Only objects on the ``src`` rank will
  1818. be broadcast, but each rank must provide lists of equal sizes.
  1819. src (int): Source rank from which to broadcast ``object_list``.
  1820. group: (ProcessGroup, optional): The process group to work on. If None,
  1821. the default process group will be used. Default is ``None``.
  1822. device (``torch.device``, optional): If not None, the objects are
  1823. serialized and converted to tensors which are moved to the
  1824. ``device`` before broadcasting. Default is ``None``.
  1825. Returns:
  1826. ``None``. If rank is part of the group, ``object_list`` will contain the
  1827. broadcasted objects from ``src`` rank.
  1828. .. note:: For NCCL-based process groups, internal tensor representations
  1829. of objects must be moved to the GPU device before communication takes
  1830. place. In this case, the device used is given by
  1831. ``torch.cuda.current_device()`` and it is the user's responsibility to
  1832. ensure that this is set so that each rank has an individual GPU, via
  1833. ``torch.cuda.set_device()``.
.. note:: Note that this API differs slightly from the :func:`broadcast`
collective since it does not provide an ``async_op`` handle and thus
will be a blocking call.
  1837. .. warning::
  1838. :func:`broadcast_object_list` uses ``pickle`` module implicitly, which
  1839. is known to be insecure. It is possible to construct malicious pickle
  1840. data which will execute arbitrary code during unpickling. Only call this
  1841. function with data you trust.
  1842. Example::
  1843. >>> # xdoctest: +SKIP("need process group init")
  1844. >>> # Note: Process group initialization omitted on each rank.
  1845. >>> import torch.distributed as dist
  1846. >>> if dist.get_rank() == 0:
  1847. >>> # Assumes world_size of 3.
  1848. >>> objects = ["foo", 12, {1: 2}] # any picklable object
  1849. >>> else:
  1850. >>> objects = [None, None, None]
  1851. >>> # Assumes backend is not NCCL
  1852. >>> device = torch.device("cpu")
  1853. >>> dist.broadcast_object_list(objects, src=0, device=device)
  1854. >>> objects
  1855. ['foo', 12, {1: 2}]
  1856. """
  1857. if _rank_not_in_group(group):
  1858. _warn_not_in_group("broadcast_object_list")
  1859. return
# Current device selection.
# To preserve backwards compatibility, ``device`` defaults to ``None``,
# in which case the existing device-selection logic runs, i.e.
# ``current_device`` is CUDA if the backend is NCCL, otherwise the CPU device.
# If ``device`` is not ``None``, the size and object tensors to be
# broadcast are moved to that device.
  1866. current_device = device or _get_pg_device(group)
  1867. my_rank = get_rank()
  1868. # Serialize object_list elements to tensors on src rank.
  1869. if my_rank == src:
  1870. tensor_list, size_list = zip(*[_object_to_tensor(obj, current_device) for obj in object_list])
  1871. object_sizes_tensor = torch.cat(size_list)
  1872. else:
  1873. object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long, device=current_device)
  1874. # Broadcast object sizes
  1875. broadcast(object_sizes_tensor, src=src, group=group)
  1876. # Concatenate and broadcast serialized object tensors
  1877. if my_rank == src:
  1878. object_tensor = torch.cat(tensor_list)
  1879. else:
  1880. object_tensor = torch.empty( # type: ignore[call-overload]
  1881. torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type]
  1882. dtype=torch.uint8,
  1883. device=current_device
  1884. )
  1885. broadcast(object_tensor, src=src, group=group)
  1886. # Deserialize objects using their stored sizes.
  1887. offset = 0
  1888. if my_rank != src:
  1889. for i, obj_size in enumerate(object_sizes_tensor):
  1890. obj_view = object_tensor[offset : offset + obj_size]
  1891. obj_view = obj_view.type(torch.uint8)
  1892. if obj_view.device != torch.device("cpu"):
  1893. obj_view = obj_view.cpu()
  1894. offset += obj_size
  1895. object_list[i] = _tensor_to_object(obj_view, obj_size)
  1896. @exception_handler
  1897. def scatter_object_list(
  1898. scatter_object_output_list, scatter_object_input_list, src=0, group=None
  1899. ):
  1900. """
  1901. Scatters picklable objects in ``scatter_object_input_list`` to the whole
  1902. group. Similar to :func:`scatter`, but Python objects can be passed in. On
  1903. each rank, the scattered object will be stored as the first element of
  1904. ``scatter_object_output_list``. Note that all objects in
  1905. ``scatter_object_input_list`` must be picklable in order to be scattered.
  1906. Args:
  1907. scatter_object_output_list (List[Any]): Non-empty list whose first
  1908. element will store the object scattered to this rank.
  1909. scatter_object_input_list (List[Any]): List of input objects to scatter.
  1910. Each object must be picklable. Only objects on the ``src`` rank will
  1911. be scattered, and the argument can be ``None`` for non-src ranks.
  1912. src (int): Source rank from which to scatter
  1913. ``scatter_object_input_list``.
  1914. group: (ProcessGroup, optional): The process group to work on. If None,
  1915. the default process group will be used. Default is ``None``.
  1916. Returns:
  1917. ``None``. If rank is part of the group, ``scatter_object_output_list``
  1918. will have its first element set to the scattered object for this rank.
  1919. .. note:: Note that this API differs slightly from the scatter collective
  1920. since it does not provide an ``async_op`` handle and thus will be a
  1921. blocking call.
  1922. .. warning::
  1923. :func:`scatter_object_list` uses ``pickle`` module implicitly, which
  1924. is known to be insecure. It is possible to construct malicious pickle
  1925. data which will execute arbitrary code during unpickling. Only call this
  1926. function with data you trust.
  1927. Example::
  1928. >>> # xdoctest: +SKIP("need process group init")
  1929. >>> # Note: Process group initialization omitted on each rank.
  1930. >>> import torch.distributed as dist
  1931. >>> if dist.get_rank() == 0:
  1932. >>> # Assumes world_size of 3.
  1933. >>> objects = ["foo", 12, {1: 2}] # any picklable object
  1934. >>> else:
  1935. >>> # Can be any list on non-src ranks, elements are not used.
  1936. >>> objects = [None, None, None]
  1937. >>> output_list = [None]
  1938. >>> dist.scatter_object_list(output_list, objects, src=0)
  1939. >>> # Rank i gets objects[i]. For example, on rank 2:
  1940. >>> output_list
  1941. [{1: 2}]
  1942. """
  1943. if _rank_not_in_group(group):
  1944. _warn_not_in_group("scatter_object_list")
  1945. return
  1946. if (
  1947. not isinstance(scatter_object_output_list, list)
  1948. or len(scatter_object_output_list) < 1
  1949. ):
  1950. raise RuntimeError(
  1951. "Expected argument scatter_object_output_list to be a list of size at least 1."
  1952. )
  1953. my_rank = get_rank(group)
  1954. pg_device = _get_pg_device(group)
  1955. if my_rank == src:
  1956. tensor_list, tensor_sizes = zip(
  1957. *[_object_to_tensor(obj, pg_device) for obj in scatter_object_input_list]
  1958. )
  1959. tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes)
  1960. # Src rank broadcasts the maximum tensor size. This is because all ranks are
  1961. # expected to call into scatter() with equal-sized tensors.
  1962. if my_rank == src:
  1963. max_tensor_size = max(tensor_sizes)
  1964. for tensor in tensor_list:
  1965. tensor.resize_(max_tensor_size)
  1966. else:
  1967. max_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
  1968. broadcast(max_tensor_size, src=src, group=group)
  1969. # Scatter actual serialized objects
  1970. output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8, device=pg_device)
  1971. scatter(
  1972. output_tensor,
  1973. scatter_list=None if my_rank != src else tensor_list,
  1974. src=src,
  1975. group=group,
  1976. )
  1977. # Scatter per-object sizes to trim tensors when deserializing back to object
  1978. obj_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
  1979. scatter(
  1980. obj_tensor_size,
  1981. scatter_list=None if my_rank != src else tensor_sizes,
  1982. src=src,
  1983. group=group,
  1984. )
  1985. # Deserialize back to object
  1986. scatter_object_output_list[0] = _tensor_to_object(output_tensor, obj_tensor_size)
  1987. @exception_handler
  1988. def all_gather(tensor_list, tensor, group=None, async_op=False):
  1989. """
  1990. Gathers tensors from the whole group in a list.
  1991. Complex tensors are supported.
  1992. Args:
  1993. tensor_list (list[Tensor]): Output list. It should contain
  1994. correctly-sized tensors to be used for output of the collective.
  1995. tensor (Tensor): Tensor to be broadcast from current process.
  1996. group (ProcessGroup, optional): The process group to work on. If None,
  1997. the default process group will be used.
  1998. async_op (bool, optional): Whether this op should be an async op
  1999. Returns:
  2000. Async work handle, if async_op is set to True.
  2001. None, if not async_op or if not part of the group
  2002. Examples:
  2003. >>> # xdoctest: +SKIP("need process group init")
  2004. >>> # All tensors below are of torch.int64 dtype.
  2005. >>> # We have 2 process groups, 2 ranks.
  2006. >>> tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
  2007. >>> tensor_list
  2008. [tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1
  2009. >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
  2010. >>> tensor
  2011. tensor([1, 2]) # Rank 0
  2012. tensor([3, 4]) # Rank 1
  2013. >>> dist.all_gather(tensor_list, tensor)
  2014. >>> tensor_list
  2015. [tensor([1, 2]), tensor([3, 4])] # Rank 0
  2016. [tensor([1, 2]), tensor([3, 4])] # Rank 1
  2017. >>> # All tensors below are of torch.cfloat dtype.
  2018. >>> # We have 2 process groups, 2 ranks.
  2019. >>> tensor_list = [torch.zeros(2, dtype=torch.cfloat) for _ in range(2)]
  2020. >>> tensor_list
  2021. [tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1
  2022. >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
  2023. >>> tensor
  2024. tensor([1.+1.j, 2.+2.j]) # Rank 0
  2025. tensor([3.+3.j, 4.+4.j]) # Rank 1
  2026. >>> dist.all_gather(tensor_list, tensor)
  2027. >>> tensor_list
  2028. [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 0
  2029. [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1
  2030. """
  2031. _check_tensor_list(tensor_list, "tensor_list")
  2032. _check_single_tensor(tensor, "tensor")
  2033. _ensure_all_tensors_same_dtype(tensor_list, tensor)
  2034. if _rank_not_in_group(group):
  2035. _warn_not_in_group("all_gather")
  2036. return
  2037. tensor_list = [
  2038. t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list
  2039. ]
  2040. tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)
  2041. if group is None:
  2042. default_pg = _get_default_group()
  2043. work = default_pg.allgather([tensor_list], [tensor])
  2044. else:
  2045. work = group.allgather([tensor_list], [tensor])
  2046. if async_op:
  2047. return work
  2048. else:
  2049. work.wait()
  2050. @exception_handler
  2051. def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False):
  2052. """
  2053. Gather tensors from all ranks and put them in a single output tensor.
  2054. Args:
  2055. output_tensor (Tensor): Output tensor to accommodate tensor elements
  2056. from all ranks. It must be correctly sized to have one of the
  2057. following forms:
  2058. (i) a concatenation of all the input tensors along the primary
  2059. dimension; for definition of "concatenation", see ``torch.cat()``;
  2060. (ii) a stack of all the input tensors along the primary dimension;
  2061. for definition of "stack", see ``torch.stack()``.
  2062. Examples below may better explain the supported output forms.
  2063. input_tensor (Tensor): Tensor to be gathered from current rank.
  2064. Different from the ``all_gather`` API, the input tensors in this
  2065. API must have the same size across all ranks.
  2066. group (ProcessGroup, optional): The process group to work on. If None,
  2067. the default process group will be used.
  2068. async_op (bool, optional): Whether this op should be an async op
  2069. Returns:
  2070. Async work handle, if async_op is set to True.
  2071. None, if not async_op or if not part of the group
  2072. Examples:
  2073. >>> # xdoctest: +SKIP("need process group init")
  2074. >>> # All tensors below are of torch.int64 dtype and on CUDA devices.
  2075. >>> # We have two ranks.
  2076. >>> device = torch.device(f'cuda:{rank}')
  2077. >>> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
  2078. >>> tensor_in
  2079. tensor([1, 2], device='cuda:0') # Rank 0
  2080. tensor([3, 4], device='cuda:1') # Rank 1
  2081. >>> # Output in concatenation form
  2082. >>> tensor_out = torch.zeros(world_size * 2, dtype=torch.int64, device=device)
  2083. >>> dist.all_gather_into_tensor(tensor_out, tensor_in)
  2084. >>> tensor_out
  2085. tensor([1, 2, 3, 4], device='cuda:0') # Rank 0
  2086. tensor([1, 2, 3, 4], device='cuda:1') # Rank 1
  2087. >>> # Output in stack form
  2088. >>> tensor_out2 = torch.zeros(world_size, 2, dtype=torch.int64, device=device)
  2089. >>> dist.all_gather_into_tensor(tensor_out2, tensor_in)
  2090. >>> tensor_out2
  2091. tensor([[1, 2],
  2092. [3, 4]], device='cuda:0') # Rank 0
  2093. tensor([[1, 2],
  2094. [3, 4]], device='cuda:1') # Rank 1
  2095. .. warning::
  2096. The Gloo backend does not support this API.
  2097. """
  2098. _check_single_tensor(input_tensor, "input_tensor")
  2099. _check_single_tensor(output_tensor, "output_tensor")
  2100. if _rank_not_in_group(group):
  2101. _warn_not_in_group("all_gather_into_tensor")
  2102. return
  2103. output_tensor = (
  2104. output_tensor
  2105. if not output_tensor.is_complex()
  2106. else torch.view_as_real(output_tensor)
  2107. )
  2108. input_tensor = (
  2109. input_tensor
  2110. if not input_tensor.is_complex()
  2111. else torch.view_as_real(input_tensor)
  2112. )
  2113. if group is None:
  2114. default_pg = _get_default_group()
  2115. work = default_pg._allgather_base(output_tensor, input_tensor)
  2116. else:
  2117. work = group._allgather_base(output_tensor, input_tensor)
  2118. if async_op:
  2119. return work
  2120. else:
  2121. work.wait()
  2122. @exception_handler
  2123. def _all_gather_base(output_tensor, input_tensor, group=None, async_op=False):
  2124. """
  2125. Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor.
  2126. Args:
  2127. output_tensor (Tensor): Output tensor. It should contain
  2128. correctly-sized tensors to be used for output of the collective.
  2129. input_tensor (Tensor): Tensor to be broadcast from current process.
  2130. group (ProcessGroup, optional): The process group to work on. If None,
  2131. the default process group will be used.
  2132. async_op (bool, optional): Whether this op should be an async op
  2133. Returns:
  2134. Async work handle, if async_op is set to True.
  2135. None, if not async_op or if not part of the group
  2136. .. warning::
  2137. `_all_gather_base` is a private function. Users should use
  2138. `all_gather_into_tensor` instead.
  2139. """
  2140. warnings.warn(
  2141. "torch.distributed._all_gather_base is a private function and will be "
  2142. "deprecated. Please use torch.distributed.all_gather_into_tensor "
  2143. "instead."
  2144. )
  2145. return all_gather_into_tensor(output_tensor, input_tensor, group, async_op)
  2146. @exception_handler
  2147. def all_gather_coalesced(
  2148. output_tensor_lists, input_tensor_list, group=None, async_op=False
  2149. ):
  2150. """
  2151. Gathers input tensors from the whole group in a list in a coalesced manner.
  2152. Complex tensors are supported.
  2153. Args:
  2154. output_tensor_lists (list[list[Tensor]]): Output list. It should contain
  2155. correctly-sized tensors to be used for output of the collective.
  2156. input_tensor_list (list[Tensor]): Tensors to be broadcast from
current process. At least one tensor has to be non-empty.
  2158. group (ProcessGroup, optional): The process group to work on. If None,
  2159. the default process group will be used.
  2160. async_op (bool, optional): Whether this op should be an async op.
  2161. Returns:
  2162. Async work handle, if async_op is set to True.
  2163. None, if not async_op or if not part of the group
  2164. Example:
  2165. we have 2 process groups, 2 ranks.
  2166. rank 0 passes:
  2167. input_tensor_list = [[[1, 1], [1, 1]], [2], [3, 3]]
  2168. output_tensor_lists =
  2169. [[[[-1, -1], [-1, -1]], [-1], [-1, -1]],
  2170. [[[-1, -1], [-1, -1]], [-1], [-1, -1]]]
  2171. rank 1 passes:
  2172. input_tensor_list = [[[3, 3], [3, 3]], [5], [1, 1]]
  2173. output_tensor_lists =
  2174. [[[[-1, -1], [-1, -1]], [-1], [-1, -1]],
  2175. [[[-1, -1], [-1, -1]], [-1], [-1, -1]]]
  2176. both rank 0 and 1 get:
  2177. output_tensor_lists =
[[[[1, 1], [1, 1]], [2], [3, 3]],
[[[3, 3], [3, 3]], [5], [1, 1]]]
  2180. WARNING: at this time individual shape checking is not implemented across nodes.
  2181. For example, if the rank 0 node passes [torch.rand(4), torch.rand(2)] and the
  2182. rank 1 node passes [torch.rand(2), torch.rand(2), torch.rand(2)], the
  2183. all_gather_coalesced operation will proceed without complaint and return
  2184. erroneous outputs. This lack of shape checking results in significant
  2185. performance improvements but users of this function should take extra care
  2186. to ensure that each node passes in tensors whose shapes match across nodes.
  2187. """
  2188. warnings.warn(
  2189. "torch.distributed.all_gather_coalesced will be deprecated. If you must "
  2190. "use it, please revisit our documentation later at "
  2191. "https://pytorch.org/docs/master/distributed.html#collective-functions"
  2192. )
  2193. # We only check basic compatibility with C++ params here, C++ code will
  2194. # do shape and type checking.
  2195. if _rank_not_in_group(group):
  2196. _warn_not_in_group("all_gather_coalesced")
  2197. return
  2198. _check_tensor_list(input_tensor_list, "input_tensor_list")
  2199. _ensure_all_tensors_same_dtype(input_tensor_list)
  2200. if not isinstance(output_tensor_lists, list):
raise RuntimeError(
"Invalid function argument: output_tensor_lists should be a list"
)
  2204. for output_tensor_list in output_tensor_lists:
  2205. _check_tensor_list(output_tensor_list, "output_tensor_lists")
  2206. _ensure_all_tensors_same_dtype(output_tensor_list)
  2207. output_tensor_lists = [
  2208. [t if not t.is_complex() else torch.view_as_real(t) for t in l]
  2209. for l in output_tensor_lists
  2210. ]
  2211. input_tensor_list = [
  2212. t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list
  2213. ]
  2214. if group is None:
  2215. default_pg = _get_default_group()
  2216. work = default_pg.allgather_coalesced(output_tensor_lists, input_tensor_list)
  2217. else:
  2218. work = group.allgather_coalesced(output_tensor_lists, input_tensor_list)
  2219. if async_op:
  2220. return work.get_future()
  2221. else:
  2222. work.wait()
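
# Illustrative sketch only: because shapes are not validated across ranks (see
# the warning above), every rank must pass input tensors with identical shapes.
# Assumes a small CPU (gloo) process group is already initialized:
#
#     import torch
#     import torch.distributed as dist
#
#     rank = dist.get_rank()
#     inputs = [torch.full((2,), float(rank)), torch.full((3,), float(rank))]
#     outputs = [
#         [torch.zeros(2), torch.zeros(3)]   # one correctly-sized slot per rank
#         for _ in range(dist.get_world_size())
#     ]
#     dist.all_gather_coalesced(outputs, inputs)
#     # outputs[r] now holds rank r's input tensors.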


def _validate_output_list_for_rank(my_rank, dst, gather_list):
    if dst == my_rank:
        if not gather_list:
            raise ValueError(
                "Argument ``gather_list`` must be specified on destination rank."
            )
    elif gather_list:
        raise ValueError(
            "Argument ``gather_list`` must NOT be specified "
            "on non-destination ranks."
        )


@exception_handler
def gather(tensor, gather_list=None, dst=0, group=None, async_op=False):
    """
    Gathers a list of tensors onto a single destination process.

    Args:
        tensor (Tensor): Input tensor.
        gather_list (list[Tensor], optional): List of appropriately-sized
            tensors to use for gathered data (default is None, must be specified
            on the destination rank)
        dst (int, optional): Destination rank (default is 0)
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    _check_single_tensor(tensor, "tensor")

    # Parameter ``gather_list`` may be left unspecified on non-dst ranks.
    if gather_list:
        _check_tensor_list(gather_list, "gather_list")
    else:
        gather_list = []
    _ensure_all_tensors_same_dtype(tensor, gather_list)

    if _rank_not_in_group(group):
        _warn_not_in_group("gather")
        return

    my_rank = get_rank()
    _validate_output_list_for_rank(my_rank, dst, gather_list)
    output_tensors = [gather_list] if dst == my_rank else []
    input_tensors = [tensor]

    opts = GatherOptions()
    opts.rootRank = dst

    if group is None or group is GroupMember.WORLD:
        default_pg = _get_default_group()
        work = default_pg.gather(output_tensors, input_tensors, opts)
    else:
        group_dst_rank = get_group_rank(group, dst)
        opts.rootRank = group_dst_rank
        work = group.gather(output_tensors, input_tensors, opts)

    if async_op:
        return work
    else:
        work.wait()
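
# A minimal usage sketch for gather(), assuming a process group is already
# initialized (the rank/world_size values are illustrative):
#
#     import torch
#     import torch.distributed as dist
#
#     rank = dist.get_rank()
#     world_size = dist.get_world_size()
#     t = torch.tensor([float(rank)])
#     # Only the destination rank allocates the output list; others pass None.
#     gather_list = [torch.zeros(1) for _ in range(world_size)] if rank == 0 else None
#     dist.gather(t, gather_list=gather_list, dst=0)
#     # On rank 0, gather_list[i] now equals tensor([float(i)]).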


@exception_handler
def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
    """
    Scatters a list of tensors to all processes in a group.

    Each process will receive exactly one tensor and store its data in the
    ``tensor`` argument.

    Complex tensors are supported.

    Args:
        tensor (Tensor): Output tensor.
        scatter_list (list[Tensor], optional): List of tensors to scatter (default is
            None, must be specified on the source rank)
        src (int): Source rank (default is 0)
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    .. note:: All tensors in ``scatter_list`` must have the same size.

    Example::
        >>> # xdoctest: +SKIP("need process group init")
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> tensor_size = 2
        >>> t_ones = torch.ones(tensor_size)
        >>> t_fives = torch.ones(tensor_size) * 5
        >>> output_tensor = torch.zeros(tensor_size)
        >>> if dist.get_rank() == 0:
        >>>     # Assumes world_size of 2.
        >>>     # Only tensors, all of which must be the same size.
        >>>     scatter_list = [t_ones, t_fives]
        >>> else:
        >>>     scatter_list = None
        >>> dist.scatter(output_tensor, scatter_list, src=0)
        >>> # Rank i gets scatter_list[i]. For example, on rank 1:
        >>> output_tensor
        tensor([5., 5.])

    """
    _check_single_tensor(tensor, "tensor")

    # Parameter ``scatter_list`` may be left unspecified on non-src ranks.
    if scatter_list:
        _check_tensor_list(scatter_list, "scatter_list")
    else:
        scatter_list = []
    _ensure_all_tensors_same_dtype(tensor, scatter_list)

    if _rank_not_in_group(group):
        _warn_not_in_group("scatter")
        return
    scatter_list = [
        t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list
    ]
    tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)

    my_rank = get_rank()
    if src == my_rank:
        if not scatter_list:
            raise ValueError(
                "Argument ``scatter_list`` must be specified on source rank."
            )
        input_tensors = [scatter_list]
        output_tensors = [tensor]
    else:
        if scatter_list:
            raise ValueError(
                "Argument ``scatter_list`` must NOT be specified "
                "on non-source ranks."
            )
        input_tensors = []
        output_tensors = [tensor]

    opts = ScatterOptions()
    opts.rootRank = src

    if group is None or group is GroupMember.WORLD:
        default_pg = _get_default_group()
        work = default_pg.scatter(output_tensors, input_tensors, opts)
    else:
        group_src_rank = get_group_rank(group, src)
        opts.rootRank = group_src_rank
        work = group.scatter(output_tensors, input_tensors, opts)

    if async_op:
        return work
    else:
        work.wait()


@exception_handler
def reduce_scatter_multigpu(
    output_tensor_list, input_tensor_lists, op=ReduceOp.SUM, group=None, async_op=False
):
    """
    Reduce and scatter a list of tensors to the whole group. Only the NCCL
    backend is currently supported.

    Each tensor in ``output_tensor_list`` should reside on a separate GPU, as
    should each list of tensors in ``input_tensor_lists``.

    Args:
        output_tensor_list (List[Tensor]): Output tensors (on different GPUs)
            to receive the result of the operation.

            Note that ``len(output_tensor_list)`` needs to be the same for all
            the distributed processes calling this function.

        input_tensor_lists (List[List[Tensor]]): Input lists. It should
            contain correctly-sized tensors on each GPU to be used for input of
            the collective, e.g. ``input_tensor_lists[i]`` contains the
            reduce_scatter input that resides on the GPU of
            ``output_tensor_list[i]``.

            Note that each element of ``input_tensor_lists`` has the size of
            ``world_size * len(output_tensor_list)``, since the function
            scatters the result from every single GPU in the group. To
            interpret each element of ``input_tensor_lists[i]``, note that
            ``output_tensor_list[j]`` of rank k receives the reduce-scattered
            result from ``input_tensor_lists[i][k * world_size + j]``.

            Also note that ``len(input_tensor_lists)``, and the size of each
            element in ``input_tensor_lists`` (each element is a list,
            therefore ``len(input_tensor_lists[i])``) need to be the same for
            all the distributed processes calling this function.

        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group.

    """
    warnings.warn(
        "torch.distributed.reduce_scatter_multigpu will be deprecated. If you must "
        "use it, please revisit our documentation later at "
        "https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions"
    )
    if _rank_not_in_group(group):
        _warn_not_in_group("reduce_scatter_multigpu")
        return

    opts = ReduceScatterOptions()
    opts.reduceOp = op

    if group is None:
        default_pg = _get_default_group()
        work = default_pg.reduce_scatter(output_tensor_list, input_tensor_lists, opts)
    else:
        work = group.reduce_scatter(output_tensor_list, input_tensor_lists, opts)

    if async_op:
        return work
    else:
        work.wait()


@exception_handler
def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
    """
    Reduces, then scatters a list of tensors to all processes in a group.

    Args:
        output (Tensor): Output tensor.
        input_list (list[Tensor]): List of tensors to reduce and scatter.
        op (optional): One of the values from the ``torch.distributed.ReduceOp``
            enum. Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group.

    """
    _check_single_tensor(output, "output")
    _check_tensor_list(input_list, "input_list")
    _ensure_all_tensors_same_dtype(output, input_list)
    if _rank_not_in_group(group):
        _warn_not_in_group("reduce_scatter")
        return

    opts = ReduceScatterOptions()
    opts.reduceOp = op

    if group is None:
        default_pg = _get_default_group()
        work = default_pg.reduce_scatter([output], [input_list], opts)
    else:
        work = group.reduce_scatter([output], [input_list], opts)

    if async_op:
        return work
    else:
        work.wait()
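
# A minimal usage sketch for reduce_scatter(), assuming a process group is
# already initialized (values are illustrative):
#
#     import torch
#     import torch.distributed as dist
#
#     world_size = dist.get_world_size()
#     # Every rank contributes one tensor per rank; SUM is the default op.
#     input_list = [torch.ones(2) for _ in range(world_size)]
#     output = torch.zeros(2)
#     dist.reduce_scatter(output, input_list)
#     # Rank i receives the reduction of every rank's input_list[i];
#     # here output == tensor([world_size, world_size]) on every rank.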


@exception_handler
def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False):
    """
    Reduces, then scatters a tensor to all ranks in a group.

    Args:
        output (Tensor): Output tensor. It should have the same size across all
            ranks.
        input (Tensor): Input tensor to be reduced and scattered. Its size
            should be output tensor size times the world size. The input tensor
            can have one of the following shapes:
            (i) a concatenation of the output tensors along the primary
            dimension, or
            (ii) a stack of the output tensors along the primary dimension.
            For definition of "concatenation", see ``torch.cat()``.
            For definition of "stack", see ``torch.stack()``.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group.

    Examples:
        >>> # xdoctest: +SKIP("need process group init")
        >>> # All tensors below are of torch.int64 dtype and on CUDA devices.
        >>> # We have two ranks.
        >>> device = torch.device(f'cuda:{rank}')
        >>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device)
        >>> # Input in concatenation form
        >>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device)
        >>> tensor_in
        tensor([0, 1, 2, 3], device='cuda:0') # Rank 0
        tensor([0, 1, 2, 3], device='cuda:1') # Rank 1
        >>> dist.reduce_scatter_tensor(tensor_out, tensor_in)
        >>> tensor_out
        tensor([0, 2], device='cuda:0') # Rank 0
        tensor([4, 6], device='cuda:1') # Rank 1
        >>> # Input in stack form
        >>> tensor_in = torch.reshape(tensor_in, (world_size, 2))
        >>> tensor_in
        tensor([[0, 1],
                [2, 3]], device='cuda:0') # Rank 0
        tensor([[0, 1],
                [2, 3]], device='cuda:1') # Rank 1
        >>> dist.reduce_scatter_tensor(tensor_out, tensor_in)
        >>> tensor_out
        tensor([0, 2], device='cuda:0') # Rank 0
        tensor([4, 6], device='cuda:1') # Rank 1

    .. warning::
        The Gloo backend does not support this API.

    """
    _check_single_tensor(output, "output")
    _check_single_tensor(input, "input")
    if _rank_not_in_group(group):
        _warn_not_in_group("reduce_scatter_tensor")
        return

    opts = ReduceScatterOptions()
    opts.reduceOp = op

    if group is None:
        default_pg = _get_default_group()
        work = default_pg._reduce_scatter_base(output, input, opts)
    else:
        work = group._reduce_scatter_base(output, input, opts)

    if async_op:
        return work
    else:
        work.wait()


def _reduce_scatter_base(output, input, op=ReduceOp.SUM, group=None, async_op=False):
    """
    Reduces, then scatters a flattened tensor to all processes in a group.

    Args:
        output (Tensor): Output tensor.
        input (Tensor): Input tensor whose size is the output tensor size
            times the world size.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group.

    .. warning::
        `_reduce_scatter_base` is a private function. Users should use
        `reduce_scatter_tensor` instead.

    """
    warnings.warn(
        "torch.distributed._reduce_scatter_base is a private function and will "
        "be deprecated. Please use torch.distributed.reduce_scatter_tensor "
        "instead."
    )
    return reduce_scatter_tensor(output, input, op, group, async_op)


@exception_handler
def all_to_all_single(
    output,
    input,
    output_split_sizes=None,
    input_split_sizes=None,
    group=None,
    async_op=False,
):
    """
    Each process splits the input tensor and then scatters the split list to
    all processes in a group. It then concatenates the received tensors from
    all the processes in the group and returns a single output tensor.

    Complex tensors are supported.

    Args:
        output (Tensor): Gathered concatenated output tensor.
        input (Tensor): Input tensor to scatter.
        output_split_sizes (list[Int], optional): Output split sizes for dim 0.
            If None or empty, dim 0 of the ``output`` tensor must be divisible
            by ``world_size``.
        input_split_sizes (list[Int], optional): Input split sizes for dim 0.
            If None or empty, dim 0 of the ``input`` tensor must be divisible
            by ``world_size``.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group.

    .. warning::
        `all_to_all_single` is experimental and subject to change.

    Examples:
        >>> # xdoctest: +SKIP("Undefined rank")
        >>> input = torch.arange(4) + rank * 4
        >>> input
        tensor([0, 1, 2, 3])      # Rank 0
        tensor([4, 5, 6, 7])      # Rank 1
        tensor([8, 9, 10, 11])    # Rank 2
        tensor([12, 13, 14, 15])  # Rank 3
        >>> output = torch.empty([4], dtype=torch.int64)
        >>> dist.all_to_all_single(output, input)
        >>> output
        tensor([0, 4, 8, 12])     # Rank 0
        tensor([1, 5, 9, 13])     # Rank 1
        tensor([2, 6, 10, 14])    # Rank 2
        tensor([3, 7, 11, 15])    # Rank 3

        >>> # Essentially, it is similar to the following operation:
        >>> scatter_list = list(input.chunk(world_size))
        >>> gather_list = list(output.chunk(world_size))
        >>> for i in range(world_size):
        >>>     dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)

        >>> # Another example with uneven split
        >>> input
        tensor([0, 1, 2, 3, 4, 5])                    # Rank 0
        tensor([10, 11, 12, 13, 14, 15, 16, 17, 18])  # Rank 1
        tensor([20, 21, 22, 23, 24])                  # Rank 2
        tensor([30, 31, 32, 33, 34, 35, 36])          # Rank 3
        >>> input_splits
        [2, 2, 1, 1]  # Rank 0
        [3, 2, 2, 2]  # Rank 1
        [2, 1, 1, 1]  # Rank 2
        [2, 2, 2, 1]  # Rank 3
        >>> output_splits
        [2, 3, 2, 2]  # Rank 0
        [2, 2, 1, 2]  # Rank 1
        [1, 2, 1, 2]  # Rank 2
        [1, 2, 1, 1]  # Rank 3
        >>> output = ...
        >>> dist.all_to_all_single(output, input, output_splits, input_splits)
        >>> output
        tensor([ 0,  1, 10, 11, 12, 20, 21, 30, 31])  # Rank 0
        tensor([ 2,  3, 13, 14, 22, 32, 33])          # Rank 1
        tensor([ 4, 15, 16, 23, 34, 35])              # Rank 2
        tensor([ 5, 17, 18, 24, 36])                  # Rank 3

        >>> # Another example with tensors of torch.cfloat type.
        >>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j)
        >>> input
        tensor([1+1j, 2+2j, 3+3j, 4+4j])          # Rank 0
        tensor([5+5j, 6+6j, 7+7j, 8+8j])          # Rank 1
        tensor([9+9j, 10+10j, 11+11j, 12+12j])    # Rank 2
        tensor([13+13j, 14+14j, 15+15j, 16+16j])  # Rank 3
        >>> output = torch.empty([4], dtype=torch.cfloat)
        >>> dist.all_to_all_single(output, input)
        >>> output
        tensor([1+1j, 5+5j, 9+9j, 13+13j])        # Rank 0
        tensor([2+2j, 6+6j, 10+10j, 14+14j])      # Rank 1
        tensor([3+3j, 7+7j, 11+11j, 15+15j])      # Rank 2
        tensor([4+4j, 8+8j, 12+12j, 16+16j])      # Rank 3
    """
    if _rank_not_in_group(group):
        _warn_not_in_group("all_to_all_single")
        return

    opts = AllToAllOptions()
    _check_single_tensor(output, "output")
    _check_single_tensor(input, "input")
    _ensure_all_tensors_same_dtype(output, input)

    if input.is_complex():
        input = torch.view_as_real(input)
    if output.is_complex():
        output = torch.view_as_real(output)

    output_split_sizes = [] if output_split_sizes is None else output_split_sizes
    input_split_sizes = [] if input_split_sizes is None else input_split_sizes

    if group is None:
        default_pg = _get_default_group()
        work = default_pg.alltoall_base(
            output, input, output_split_sizes, input_split_sizes, opts
        )
    else:
        work = group.alltoall_base(
            output, input, output_split_sizes, input_split_sizes, opts
        )

    if async_op:
        return work
    else:
        work.wait()


@exception_handler
def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False):
    """
    Each process scatters a list of input tensors to all processes in a group
    and returns the gathered list of tensors in the output list.

    Complex tensors are supported.

    Args:
        output_tensor_list (list[Tensor]): List of tensors to be gathered, one
            per rank.
        input_tensor_list (list[Tensor]): List of tensors to scatter, one per rank.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group.

    .. warning::
        `all_to_all` is experimental and subject to change.

    Examples:
        >>> # xdoctest: +SKIP("Undefined rank")
        >>> input = torch.arange(4) + rank * 4
        >>> input = list(input.chunk(4))
        >>> input
        [tensor([0]), tensor([1]), tensor([2]), tensor([3])]      # Rank 0
        [tensor([4]), tensor([5]), tensor([6]), tensor([7])]      # Rank 1
        [tensor([8]), tensor([9]), tensor([10]), tensor([11])]    # Rank 2
        [tensor([12]), tensor([13]), tensor([14]), tensor([15])]  # Rank 3
        >>> output = list(torch.empty([4], dtype=torch.int64).chunk(4))
        >>> dist.all_to_all(output, input)
        >>> output
        [tensor([0]), tensor([4]), tensor([8]), tensor([12])]     # Rank 0
        [tensor([1]), tensor([5]), tensor([9]), tensor([13])]     # Rank 1
        [tensor([2]), tensor([6]), tensor([10]), tensor([14])]    # Rank 2
        [tensor([3]), tensor([7]), tensor([11]), tensor([15])]    # Rank 3

        >>> # Essentially, it is similar to the following operation:
        >>> scatter_list = input
        >>> gather_list = output
        >>> for i in range(world_size):
        >>>     dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)

        >>> # Another example with uneven split
        >>> input
        tensor([0, 1, 2, 3, 4, 5])                    # Rank 0
        tensor([10, 11, 12, 13, 14, 15, 16, 17, 18])  # Rank 1
        tensor([20, 21, 22, 23, 24])                  # Rank 2
        tensor([30, 31, 32, 33, 34, 35, 36])          # Rank 3
        >>> input_splits
        [2, 2, 1, 1]  # Rank 0
        [3, 2, 2, 2]  # Rank 1
        [2, 1, 1, 1]  # Rank 2
        [2, 2, 2, 1]  # Rank 3
        >>> output_splits
        [2, 3, 2, 2]  # Rank 0
        [2, 2, 1, 2]  # Rank 1
        [1, 2, 1, 2]  # Rank 2
        [1, 2, 1, 1]  # Rank 3
        >>> input = list(input.split(input_splits))
        >>> input
        [tensor([0, 1]), tensor([2, 3]), tensor([4]), tensor([5])]                    # Rank 0
        [tensor([10, 11, 12]), tensor([13, 14]), tensor([15, 16]), tensor([17, 18])]  # Rank 1
        [tensor([20, 21]), tensor([22]), tensor([23]), tensor([24])]                  # Rank 2
        [tensor([30, 31]), tensor([32, 33]), tensor([34, 35]), tensor([36])]          # Rank 3
        >>> output = ...
        >>> dist.all_to_all(output, input)
        >>> output
        [tensor([0, 1]), tensor([10, 11, 12]), tensor([20, 21]), tensor([30, 31])]    # Rank 0
        [tensor([2, 3]), tensor([13, 14]), tensor([22]), tensor([32, 33])]            # Rank 1
        [tensor([4]), tensor([15, 16]), tensor([23]), tensor([34, 35])]               # Rank 2
        [tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])]                   # Rank 3

        >>> # Another example with tensors of torch.cfloat type.
        >>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j)
        >>> input = list(input.chunk(4))
        >>> input
        [tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])]          # Rank 0
        [tensor([5+5j]), tensor([6+6j]), tensor([7+7j]), tensor([8+8j])]          # Rank 1
        [tensor([9+9j]), tensor([10+10j]), tensor([11+11j]), tensor([12+12j])]    # Rank 2
        [tensor([13+13j]), tensor([14+14j]), tensor([15+15j]), tensor([16+16j])]  # Rank 3
        >>> output = list(torch.empty([4], dtype=torch.cfloat).chunk(4))
        >>> dist.all_to_all(output, input)
        >>> output
        [tensor([1+1j]), tensor([5+5j]), tensor([9+9j]), tensor([13+13j])]        # Rank 0
        [tensor([2+2j]), tensor([6+6j]), tensor([10+10j]), tensor([14+14j])]      # Rank 1
        [tensor([3+3j]), tensor([7+7j]), tensor([11+11j]), tensor([15+15j])]      # Rank 2
        [tensor([4+4j]), tensor([8+8j]), tensor([12+12j]), tensor([16+16j])]      # Rank 3
    """
    if _rank_not_in_group(group):
        _warn_not_in_group("all_to_all")
        return

    opts = AllToAllOptions()
    _check_tensor_list(output_tensor_list, "output_tensor_list")
    _check_tensor_list(input_tensor_list, "input_tensor_list")
    _ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list)

    input_tensor_list = [
        t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list
    ]
    output_tensor_list = [
        t if not t.is_complex() else torch.view_as_real(t) for t in output_tensor_list
    ]

    if group is None:
        default_pg = _get_default_group()
        work = default_pg.alltoall(output_tensor_list, input_tensor_list, opts)
    else:
        work = group.alltoall(output_tensor_list, input_tensor_list, opts)

    if async_op:
        return work
    else:
        work.wait()


def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
    """
    Synchronizes all processes.

    This collective blocks processes until the whole group enters this function
    (if async_op is False), or until wait() is called on the async work handle.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op
        device_ids ([int], optional): List of device/GPU ids.
            Valid only for the NCCL backend.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    if _rank_not_in_group(group):
        _warn_not_in_group("barrier")
        return

    opts = BarrierOptions()
    if device_ids is not None:
        if get_backend(group) != Backend.NCCL:
            raise RuntimeError(
                "Function argument device_ids not supported "
                "for the selected backend {}".format(get_backend(group))
            )
        if isinstance(device_ids, list):
            opts.device_ids = device_ids
        else:
            raise RuntimeError(
                "Invalid function argument: device_ids type should be List[int]"
            )

    if group is None:
        default_pg = _get_default_group()
        work = default_pg.barrier(opts=opts)
    else:
        work = group.barrier(opts=opts)

    if async_op:
        return work
    else:
        work.wait()
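
# A minimal usage sketch for barrier(), assuming a process group is already
# initialized; device_ids is only meaningful for the NCCL backend:
#
#     import torch
#     import torch.distributed as dist
#
#     dist.barrier()  # block until every rank in the default group reaches this point
#
#     # With NCCL, the barrier can be pinned to a specific GPU:
#     # dist.barrier(device_ids=[torch.cuda.current_device()])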


def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=False):
    """
    Synchronizes all processes similar to ``torch.distributed.barrier``, but takes
    a configurable timeout and is able to report ranks that did not pass this
    barrier within that timeout. Specifically, non-zero ranks will block
    until a send/recv is processed from rank 0. Rank 0 will block until all
    send/recv from other ranks are processed, and will report failures for ranks
    that failed to respond in time. Note that if one rank does not reach the
    monitored_barrier (for example due to a hang), all other ranks would fail
    in monitored_barrier.

    This collective will block all processes/ranks in the group, until the
    whole group exits the function successfully, making it useful for debugging
    and synchronizing. However, it can have a performance impact and should only
    be used for debugging or scenarios that require full synchronization points
    on the host-side. For debugging purposes, this barrier can be inserted
    before the application's collective calls to check if any ranks are
    desynchronized.

    .. note:: This collective is only supported with the GLOO backend.

    Args:
        group (ProcessGroup, optional): The process group to work on. If
            ``None``, the default process group will be used.
        timeout (datetime.timedelta, optional): Timeout for monitored_barrier.
            If ``None``, the default process group timeout will be used.
        wait_all_ranks (bool, optional): Whether to collect all failed ranks or
            not. By default, this is ``False`` and ``monitored_barrier`` on rank 0
            will throw on the first failed rank it encounters in order to fail
            fast. By setting ``wait_all_ranks=True``, ``monitored_barrier`` will
            collect all failed ranks and throw an error containing information
            about all failed ranks.

    Returns:
        ``None``.

    Example::
        >>> # xdoctest: +SKIP("need process group init")
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> if dist.get_rank() != 1:
        >>>     dist.monitored_barrier() # Raises exception indicating that
        >>>     # rank 1 did not call into monitored_barrier.
        >>> # Example with wait_all_ranks=True
        >>> if dist.get_rank() == 0:
        >>>     dist.monitored_barrier(wait_all_ranks=True) # Raises exception
        >>>     # indicating that ranks 1, 2, ... world_size - 1 did not call into
        >>>     # monitored_barrier.
    """
    # Need to call rank not in group before using the group, otherwise
    # "Invalid process group" error is raised.
    if _rank_not_in_group(group):
        _warn_not_in_group("monitored_barrier")
        return

    if get_backend(group) != Backend.GLOO:
        raise RuntimeError("monitored_barrier is only implemented for GLOO backend.")

    if timeout is None:
        timeout = default_pg_timeout

    group_to_use = _get_default_group() if group is None else group
    return group_to_use.monitored_barrier(timeout, wait_all_ranks=wait_all_ranks)


def _create_process_group_wrapper(
    wrapped_pg: ProcessGroup,
    store_prefix: str,
    store: Store,
    rank: int,
    world_size: int,
    timeout: timedelta = default_pg_timeout,
):
    # Create a separate prefix store for the helper process group.
    prefix = f"{PG_WRAPPER_STORE_PREFIX}:{store_prefix}"
    store = PrefixStore(prefix, store)
    helper_pg = ProcessGroupGloo(store, rank, world_size, timeout=timeout)
    # Wrap the underlying pg with ProcessGroupWrapper.
    wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg)
    return wrapped_pg


def new_group(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=None):
    """
    Creates a new distributed group.

    This function requires that all processes in the main group (i.e. all
    processes that are part of the distributed job) enter this function, even
    if they are not going to be members of the group. Additionally, groups
    should be created in the same order in all processes.

    .. warning::
        Using multiple process groups with the ``NCCL`` backend concurrently
        is not safe and the user should perform explicit synchronization in
        their application to ensure only one process group is used at a time.
        This means collectives from one process group should have completed
        execution on the device (not just enqueued since CUDA execution is
        async) before collectives from another process group are enqueued.
        See `Using multiple NCCL communicators concurrently <https://docs.nvid
        ia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#using
        -multiple-nccl-communicators-concurrently>`_ for more details.

    Args:
        ranks (list[int]): List of ranks of group members. If ``None``, will be
            set to all ranks. Default is ``None``.
        timeout (timedelta, optional): Timeout for operations executed against
            the process group. Default value equals 30 minutes.
            This is applicable for the ``gloo`` backend. For ``nccl``, this is
            applicable only if the environment variable ``NCCL_BLOCKING_WAIT``
            or ``NCCL_ASYNC_ERROR_HANDLING`` is set to 1. When
            ``NCCL_BLOCKING_WAIT`` is set, this is the duration for which the
            process will block and wait for collectives to complete before
            throwing an exception. When ``NCCL_ASYNC_ERROR_HANDLING`` is set,
            this is the duration after which collectives will be aborted
            asynchronously and the process will crash. ``NCCL_BLOCKING_WAIT``
            will provide errors to the user which can be caught and handled,
            but due to its blocking nature, it has a performance overhead. On
            the other hand, ``NCCL_ASYNC_ERROR_HANDLING`` has very little
            performance overhead, but crashes the process on errors. This is
            done since CUDA execution is async and it is no longer safe to
            continue executing user code since failed async NCCL operations
            might result in subsequent CUDA operations running on corrupted
            data. Only one of these two environment variables should be set.
        backend (str or Backend, optional): The backend to use. Depending on
            build-time configurations, valid values are ``gloo`` and ``nccl``.
            By default uses the same backend as the global group. This field
            should be given as a lowercase string (e.g., ``"gloo"``), which can
            also be accessed via :class:`Backend` attributes (e.g.,
            ``Backend.GLOO``). If ``None`` is passed in, the backend
            corresponding to the default process group will be used. Default is
            ``None``.
        pg_options (ProcessGroupOptions, optional): process group options
            specifying what additional options need to be passed in during
            the construction of specific process groups. e.g. for the ``nccl``
            backend, ``is_high_priority_stream`` can be specified so that
            the process group can pick up high priority cuda streams.

    Returns:
        A handle of the distributed group that can be given to collective calls.
    """
    global _world

    default_pg = _get_default_group()
    default_backend, default_store = _world.pg_map[default_pg]
    global_rank = default_pg.rank()
    global_world_size = default_pg.size()

    # Default to the same backend as the global process group
    # if the backend is not specified.
    if not backend:
        backend = default_backend

    # checks the input ranks
    if ranks is not None:
        ranks = sorted(ranks)
        group_world_size = len(ranks)
        if group_world_size > global_world_size:
            raise RuntimeError(
                "the new group's world size should be less or "
                "equal to the world size set by "
                "init_process_group"
            )
        # check ranks' sanity
        for rank in ranks:
            if rank < 0 or rank >= global_world_size:
                raise RuntimeError(
                    "The new group's rank should be within "
                    "the world_size set by init_process_group"
                )
        if global_rank in ranks:
            group_rank = ranks.index(global_rank)
        else:
            group_rank = None
    else:
        ranks = list(range(global_world_size))
        group_world_size = global_world_size
        group_rank = global_rank

    backend = Backend(backend)
    with record_function(f"## process_group:init with ranks: {ranks}"):
        pg = _new_process_group_helper(
            group_world_size,
            group_rank,
            ranks,
            backend,
            default_store,
            pg_options=pg_options,
            timeout=timeout,
        )

    # Create the global rank to group rank mapping
    _world.pg_group_ranks[pg] = {
        global_rank: group_rank for group_rank, global_rank in enumerate(ranks)
    }

    # barrier at the end to ensure that once we return from this method, all
    # process groups including global variables are updated correctly on all
    # ranks.
    if backend == Backend.MPI:
        # MPI doesn't have store.
        barrier()
    else:
        # Use store based barrier here since barrier() used a bunch of
        # default devices and messes up NCCL internal state.
        _store_based_barrier(global_rank, default_store, timeout)

    return pg
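
# A minimal usage sketch for new_group(), assuming the default process group is
# already initialized on 4 ranks (the rank values are illustrative):
#
#     import torch
#     import torch.distributed as dist
#
#     # Every rank must call new_group(), even ranks that will not be members.
#     even_group = dist.new_group(ranks=[0, 2])
#     t = torch.ones(1)
#     if dist.get_rank() in (0, 2):
#         dist.all_reduce(t, group=even_group)   # sums only across ranks 0 and 2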


def new_subgroups(
    group_size=None,
    group=None,
    timeout=default_pg_timeout,
    backend=None,
    pg_options=None,
):
    """
    Creates GPU subgroups of equal size. By default, it creates intra-machine
    subgroups, each of which contains all the ranks of a machine, based on the
    assumption that each machine has the same number of CUDA devices.

    This is a convenience API that calls ``new_group`` to generate multiple subgroups.
    It requires that all processes in the main group (i.e. all
    processes that are part of the distributed job) enter this function, even
    if they are not going to be members of the group.

    .. warning::
        This API only works when CUDA is available.

    .. warning::
        If ``group_size`` is passed in, the world size must be divisible by ``group_size``.
        If no ``group_size`` is passed in, and not all the machines have the same number
        of devices, the subgroup division will be different across nodes and can cause
        unexpected behaviors.

    .. warning::
        Using multiple process groups with the ``NCCL`` backend concurrently
        is not safe and the user should perform explicit synchronization in
        their application to ensure only one process group is used at a time.
        This means collectives from one process group should have completed
        execution on the device (not just enqueued since CUDA execution is
        async) before collectives from another process group are enqueued.
        See `Using multiple NCCL communicators concurrently <https://docs.nvid
        ia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#using
        -multiple-nccl-communicators-concurrently>`_ for more details.

    Args:
        group_size (int, optional): The size of each subgroup. If ``None``,
            the default subgroup size is equal to the number of devices on each machine,
            based on the assumption that each machine has exactly the same
            number of devices. Default is ``None``.
        timeout (timedelta, optional): Timeout for operations executed against
            the process group. Default value equals 30 minutes.
            This is applicable for the ``gloo`` backend. For ``nccl``, this is
            applicable only if the environment variable ``NCCL_BLOCKING_WAIT``
            or ``NCCL_ASYNC_ERROR_HANDLING`` is set to 1. When
            ``NCCL_BLOCKING_WAIT`` is set, this is the duration for which the
            process will block and wait for collectives to complete before
            throwing an exception. When ``NCCL_ASYNC_ERROR_HANDLING`` is set,
            this is the duration after which collectives will be aborted
            asynchronously and the process will crash. ``NCCL_BLOCKING_WAIT``
            will provide errors to the user which can be caught and handled,
            but due to its blocking nature, it has a performance overhead. On
            the other hand, ``NCCL_ASYNC_ERROR_HANDLING`` has very little
            performance overhead, but crashes the process on errors. This is
            done since CUDA execution is async and it is no longer safe to
            continue executing user code since failed async NCCL operations
            might result in subsequent CUDA operations running on corrupted
            data. Only one of these two environment variables should be set.
        backend (str or Backend, optional): The backend to use. Depending on
            build-time configurations, valid values are ``gloo`` and ``nccl``.
            By default uses the same backend as the global group. This field
            should be given as a lowercase string (e.g., ``"gloo"``), which can
            also be accessed via :class:`Backend` attributes (e.g.,
            ``Backend.GLOO``). If ``None`` is passed in, the backend
            corresponding to the default process group will be used. Default is
            ``None``.
        pg_options (ProcessGroupOptions, optional): process group options
            specifying what additional options need to be passed in during
            the construction of specific process groups. e.g. for the ``nccl``
            backend, ``is_high_priority_stream`` can be specified so that
            the process group can pick up high priority cuda streams.

    Returns:
        The subgroup containing the current rank, and all the subgroups used for cleanup.

    Examples:
        >>> # Create intra-machine subgroups.
        >>> # xdoctest: +SKIP("need process group init")
        >>> cur_subgroup, subgroups = dist.new_subgroups()
        >>> # Allreduce within the machine.
        >>> rank = dist.get_rank()
        >>> tensor = torch.ones(1, device=rank) * rank
        >>> dist.all_reduce(tensor, group=cur_subgroup)
        >>> tensor
        tensor([28])  # Sum of ranks 0..7, assuming 8 CUDA devices per machine.
        >>> # Cleanup.
        >>> for subgroup in subgroups:
        >>>     dist.destroy_process_group(subgroup)
    """
    if not torch.cuda.is_available():
        raise ValueError("Subgroups can only be created when CUDA is available")

    if group_size is None:
        group_size = torch.cuda.device_count()
    world_size = get_world_size()
    if world_size < group_size:
        raise ValueError(
            f"The arg 'group_size' ({group_size}) must not exceed the world size ({world_size})"
        )
    if world_size % group_size != 0:
        raise ValueError("The world size must be divisible by 'group_size'")

    subgroups = []
    cur_subgroup = None

    for subgroup_id in range(world_size // group_size):
        start_rank = subgroup_id * group_size
        end_rank = start_rank + group_size
        ranks_in_subgroup = list(range(start_rank, end_rank))
        subgroup = new_group(
            ranks=ranks_in_subgroup,
            timeout=timeout,
            backend=backend,
            pg_options=pg_options,
        )
        subgroups.append(subgroup)

        rank = get_rank()
        if rank in ranks_in_subgroup:
            cur_subgroup = subgroup
            logger.info(
                "Rank {} is assigned to subgroup {}".format(rank, ranks_in_subgroup)
            )

    return cur_subgroup, subgroups
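
# A minimal usage sketch with an explicit group_size, assuming 4 ranks with
# CUDA available; the world size must be divisible by group_size:
#
#     import torch
#     import torch.distributed as dist
#
#     cur_subgroup, all_subgroups = dist.new_subgroups(group_size=2)
#     # Ranks 0-1 share one subgroup, ranks 2-3 another.
#     rank = dist.get_rank()
#     t = torch.ones(1, device=rank) * rank
#     dist.all_reduce(t, group=cur_subgroup)   # reduces only within the subgroup
#     for sg in all_subgroups:
#         dist.destroy_process_group(sg)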


def new_subgroups_by_enumeration(
    ranks_per_subgroup_list,
    timeout=default_pg_timeout,
    backend=None,
    pg_options=None,
):
    """
    Creates GPU subgroups by dividing the global world, where the division is specified by
    a nested list of ranks. The subgroups cannot overlap, and some ranks may not be
    in any subgroup.

    This is a convenience API that calls ``new_group`` to generate multiple subgroups.
    It requires that all processes in the main group (i.e. all
    processes that are part of the distributed job) enter this function, even
    if they are not going to be members of the group.

    .. warning::
        Using multiple process groups with the ``NCCL`` backend concurrently
        is not safe and the user should perform explicit synchronization in
        their application to ensure only one process group is used at a time.
        This means collectives from one process group should have completed
        execution on the device (not just enqueued since CUDA execution is
        async) before collectives from another process group are enqueued.
        See `Using multiple NCCL communicators concurrently <https://docs.nvid
        ia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#using
        -multiple-nccl-communicators-concurrently>`_ for more details.

    Args:
        ranks_per_subgroup_list (list[list[int]]): A nested list of ranks of
            group members.
        timeout (timedelta, optional): Timeout for operations executed against
            the process group. Default value equals 30 minutes.
            This is applicable for the ``gloo`` backend. For ``nccl``, this is
            applicable only if the environment variable ``NCCL_BLOCKING_WAIT``
            or ``NCCL_ASYNC_ERROR_HANDLING`` is set to 1. When
            ``NCCL_BLOCKING_WAIT`` is set, this is the duration for which the
            process will block and wait for collectives to complete before
            throwing an exception. When ``NCCL_ASYNC_ERROR_HANDLING`` is set,
            this is the duration after which collectives will be aborted
            asynchronously and the process will crash. ``NCCL_BLOCKING_WAIT``
            will provide errors to the user which can be caught and handled,
            but due to its blocking nature, it has a performance overhead. On
            the other hand, ``NCCL_ASYNC_ERROR_HANDLING`` has very little
            performance overhead, but crashes the process on errors. This is
            done since CUDA execution is async and it is no longer safe to
            continue executing user code since failed async NCCL operations
            might result in subsequent CUDA operations running on corrupted
            data. Only one of these two environment variables should be set.
        backend (str or Backend, optional): The backend to use. Depending on
            build-time configurations, valid values are ``gloo`` and ``nccl``.
            By default uses the same backend as the global group. This field
            should be given as a lowercase string (e.g., ``"gloo"``), which can
            also be accessed via :class:`Backend` attributes (e.g.,
            ``Backend.GLOO``). If ``None`` is passed in, the backend
            corresponding to the default process group will be used. Default is
            ``None``.
        pg_options (ProcessGroupOptions, optional): process group options
            specifying what additional options need to be passed in during
            the construction of specific process groups. e.g. for the ``nccl``
            backend, ``is_high_priority_stream`` can be specified so that
            the process group can pick up high priority cuda streams.

    Returns:
        The subgroup containing the current rank, and all the subgroups used for cleanup.

    Examples:
        >>> # Create two subgroups, where each has 2 processes.
        >>> # xdoctest: +SKIP("need process group init")
        >>> cur_subgroup, subgroups = dist.new_subgroups_by_enumeration([[0, 2], [1, 3]])
        >>> rank = dist.get_rank()
        >>> tensor = torch.ones(1, device=rank) * rank
        >>> dist.all_reduce(tensor, group=cur_subgroup)
        >>> tensor
        tensor([2])  # Subgroup 0: ranks 0 and 2
        tensor([4])  # Subgroup 1: ranks 1 and 3
    """
    if not torch.cuda.is_available():
        raise ValueError("Subgroups can only be created when CUDA is available")
    if ranks_per_subgroup_list is None or len(ranks_per_subgroup_list) == 0:
        raise ValueError("The arg 'ranks_per_subgroup_list' cannot be empty")

    world_size = get_world_size()

    subgroups = []
    cur_subgroup = None
    # Create a mapping from rank to subgroup to check if there is any subgroup overlap.
    rank_to_ranks_dict = {}  # type: ignore[var-annotated]
    for ranks in ranks_per_subgroup_list:
        subgroup = new_group(
            ranks=ranks,
            timeout=timeout,
            backend=backend,
            pg_options=pg_options,
        )
        subgroups.append(subgroup)
        my_rank = get_rank()
        for rank in ranks:
            if rank in rank_to_ranks_dict:
                raise ValueError(
                    "Rank {} has appeared in both subgroup {} and {}".format(
                        rank, rank_to_ranks_dict[rank], ranks
                    )
                )
            rank_to_ranks_dict[rank] = ranks
            if my_rank == rank:
                cur_subgroup = subgroup
                logger.info("Rank {} is assigned to subgroup {}".format(rank, ranks))

    return cur_subgroup, subgroups