import collections
import enum
from typing import cast, Dict, List, Set, Tuple

import torch
import torch.distributed as dist

from ._utils import _group_membership_management, _update_group_membership
from . import api
from . import constants as rpc_constants

__all__ = ["backend_registered", "register_backend", "construct_rpc_backend_options", "init_backend",
           "BackendValue", "BackendType"]

BackendValue = collections.namedtuple(
    "BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"]
)


def _backend_type_repr(self):
    return "BackendType." + self.name


_backend_type_doc = """
    An enum class of available backends.

    PyTorch ships with a builtin ``BackendType.TENSORPIPE`` backend.
    Additional ones can be registered using the
    :func:`~torch.distributed.rpc.backend_registry.register_backend` function.
"""

# Create an enum type, `BackendType`, with empty members.
# Can't handle Function Enum API (mypy bug #9079)
BackendType = enum.Enum(value="BackendType", names=dict())  # type: ignore[misc]
# Unable to assign a function a method (mypy bug #2427)
BackendType.__repr__ = _backend_type_repr  # type: ignore[assignment]

if BackendType.__doc__:
    BackendType.__doc__ = _backend_type_doc


def backend_registered(backend_name):
    """
    Checks if backend_name is registered as an RPC backend.

    Args:
        backend_name (str): string to identify the RPC backend.
    Returns:
        True if the backend has been registered with ``register_backend``, else
        False.
    """
    return backend_name in BackendType.__members__.keys()


def register_backend(
    backend_name, construct_rpc_backend_options_handler, init_backend_handler
):
    """Registers a new RPC backend.

    Args:
        backend_name (str): backend string to identify the handler.
        construct_rpc_backend_options_handler (function):
            Handler that is invoked when
            rpc_backend.construct_rpc_backend_options(**dict) is called.
        init_backend_handler (function): Handler that is invoked when the
            `_init_rpc_backend()` function is called with a backend.
            This returns the agent.
    """
    global BackendType
    if backend_registered(backend_name):
        raise RuntimeError("RPC backend {}: already registered".format(backend_name))
    # Create a new enum type, `BackendType`, with extended members.
    existing_enum_dict = {member.name: member.value for member in BackendType}
    extended_enum_dict = dict(
        {
            backend_name: BackendValue(
                construct_rpc_backend_options_handler=construct_rpc_backend_options_handler,
                init_backend_handler=init_backend_handler,
            )
        },
        **existing_enum_dict
    )
    # Can't handle Function Enum API (mypy bug #9079)
    BackendType = enum.Enum(value="BackendType", names=extended_enum_dict)  # type: ignore[misc]
    # Unable to assign a function a method (mypy bug #2427)
    BackendType.__repr__ = _backend_type_repr  # type: ignore[assignment]
    if BackendType.__doc__:
        BackendType.__doc__ = _backend_type_doc
    return BackendType[backend_name]
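
# Illustrative sketch (comment only, not executed by this module): how a
# hypothetical custom backend could be registered through the API above. The
# handler names are made-up placeholders; a real backend's handlers must
# return an RpcBackendOptions-like object and an RpcAgent, respectively.
#
#     def _my_construct_opts_handler(rpc_timeout, init_method, **kwargs):
#         ...  # build and return backend-specific options
#
#     def _my_init_backend_handler(store, name, rank, world_size, rpc_backend_options):
#         ...  # construct and return the RPC agent
#
#     MY_BACKEND = register_backend(
#         "MY_BACKEND", _my_construct_opts_handler, _my_init_backend_handler
#     )
#     assert backend_registered("MY_BACKEND")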


def construct_rpc_backend_options(
    backend,
    rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC,
    init_method=rpc_constants.DEFAULT_INIT_METHOD,
    **kwargs
):
    return backend.value.construct_rpc_backend_options_handler(
        rpc_timeout, init_method, **kwargs
    )


def init_backend(backend, *args, **kwargs):
    return backend.value.init_backend_handler(*args, **kwargs)
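
# Rough call sequence (comment only): for the builtin backend, these two thin
# dispatchers forward to the TensorPipe handlers registered at the bottom of
# this file; the timeout and address below are illustrative values.
#
#     opts = construct_rpc_backend_options(
#         BackendType.TENSORPIPE,
#         rpc_timeout=60,
#         init_method="tcp://localhost:29500",
#     )
#     # init_backend(BackendType.TENSORPIPE, store, name, rank, world_size, opts)
#     # then constructs and returns the TensorPipeAgent.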


def _init_process_group(store, rank, world_size):
    # Initialize ProcessGroup.
    process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT

    # We're using a bunch of private APIs here since `new_group` requires the
    # default group to be initialized.
    group = dist.ProcessGroupGloo(store, rank, world_size, process_group_timeout)

    assert group is not None, "Failed to initialize default ProcessGroup."

    if (rank != -1) and (rank != group.rank()):
        raise RuntimeError(
            "rank argument {} doesn't match pg rank {}".format(rank, group.rank())
        )
    if (world_size != -1) and (world_size != group.size()):
        raise RuntimeError(
            "world_size argument {} doesn't match pg size {}".format(
                world_size, group.size()
            )
        )
    return group


def _tensorpipe_construct_rpc_backend_options_handler(
    rpc_timeout,
    init_method,
    num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS,
    _transports=None,
    _channels=None,
    **kwargs
):
    from . import TensorPipeRpcBackendOptions

    return TensorPipeRpcBackendOptions(
        rpc_timeout=rpc_timeout,
        init_method=init_method,
        num_worker_threads=num_worker_threads,
        _transports=_transports,
        _channels=_channels,
    )


def _tensorpipe_validate_devices(devices, device_count):
    return all(
        d.type == "cpu" or (d.type == "cuda" and 0 <= d.index < device_count)
        for d in devices
    )
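
# Illustrative check (comment only, assuming a machine reporting 2 CUDA devices):
#     _tensorpipe_validate_devices(
#         [torch.device("cpu"), torch.device("cuda:1")], device_count=2)    # True
#     _tensorpipe_validate_devices([torch.device("cuda:2")], device_count=2)  # False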


# Detect if any worker has invalid device_map configurations, and return
# reverse device maps.
def _tensorpipe_exchange_and_check_all_device_maps(
    my_name, my_device_count, my_device_maps, my_devices, group
):
    gathered: List[Tuple[
        str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]
    ]] = [("", 0, {}, []) for _ in range(group.size())]
    dist.all_gather_object(
        gathered, (my_name, my_device_count, my_device_maps, my_devices), group
    )
    all_names = [name for name, _, _, _ in gathered]
    all_device_counts = {name: count for name, count, _, _ in gathered}
    all_device_maps = {name: map_ for name, _, map_, _ in gathered}
    all_devices = {name: devices for name, _, _, devices in gathered}

    _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices)

    # Passed all checks; construct reverse mapping and get the list of devices handled by this agent.
    reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)
    my_devices = _create_device_list(my_devices, my_device_maps, reverse_device_maps)
    return reverse_device_maps, my_devices


def _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True):
    for node in all_names:
        devices = all_devices[node]
        if len(set(devices)) != len(devices):
            raise ValueError(
                f"Node {node} has duplicated devices\n"
                f"devices = {devices}"
            )
        if not _tensorpipe_validate_devices(devices, all_device_counts[node]):
            raise ValueError(
                f"Node {node} has devices with invalid indices\n"
                f"devices = {devices}\n"
                f"device count = {all_device_counts[node]}"
            )

    for source_node in all_names:
        # For a dynamic (non-static) group, do not check the target node name since it may not have joined yet.
        if is_static_group and not set(all_device_maps[source_node].keys()).issubset(all_names):
            raise ValueError(
                f"Node {source_node} has invalid target node names in its device maps\n"
                f"device maps = {all_device_maps[source_node].keys()}\n"
                f"node names = {all_names}"
            )
        for target_node, map_ in all_device_maps[source_node].items():
            if len(set(map_.values())) != len(map_):
                raise ValueError(
                    f"Node {source_node} has duplicated target devices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}"
                )
            if all_devices[source_node]:
                if not set(map_.keys()).issubset(all_devices[source_node]):
                    raise ValueError(
                        f"Node {source_node} has unexpected source devices "
                        f"in its device map for {target_node}\n"
                        f"device map = {map_}\n"
                        f"devices = {all_devices[source_node]}"
                    )
            elif not _tensorpipe_validate_devices(
                map_.keys(), all_device_counts[source_node]
            ):
                raise ValueError(
                    f"Node {source_node} has source devices with invalid indices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}\n"
                    f"device count = {all_device_counts[source_node]}"
                )
            if all_devices.get(target_node, []):
                if not set(map_.values()).issubset(all_devices[target_node]):
                    raise ValueError(
                        f"Node {source_node} has unexpected target devices "
                        f"in its device map for {target_node}\n"
                        f"device map = {map_}\n"
                        f"devices = {all_devices[target_node]}"
                    )
            elif target_node in all_device_counts and not _tensorpipe_validate_devices(
                map_.values(), all_device_counts[target_node]
            ):
                raise ValueError(
                    f"Node {source_node} has target devices with invalid indices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}\n"
                    f"device count = {all_device_counts[target_node]}"
                )


def _create_device_list(my_devices, my_device_maps, reverse_device_maps):
    if not my_devices:
        devices_set: Set[torch.device] = set()
        for _, map_ in my_device_maps.items():
            devices_set.update(map_.keys())
        for _, map_ in reverse_device_maps.items():
            devices_set.update(map_.keys())
        devices_set.discard(torch.device("cpu"))
        my_devices = list(devices_set)
    my_devices = sorted(my_devices, key=lambda d: d.index)
    return my_devices


def _create_reverse_mapping(my_name, all_names, all_device_maps):
    reverse_device_maps: Dict[str, Dict[torch.device, torch.device]] = {}
    for node in all_names:
        if my_name in all_device_maps[node]:
            reverse_device_maps[node] = {
                v: k for k, v in all_device_maps[node][my_name].items()
            }
    return reverse_device_maps
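
# Worked example (comment only, made-up worker names): if "worker0" maps its
# cuda:0 to this worker's ("worker1") cuda:1, i.e.
#     all_device_maps["worker0"] == {"worker1": {torch.device("cuda:0"): torch.device("cuda:1")}}
# then _create_reverse_mapping("worker1", all_names, all_device_maps) returns
#     {"worker0": {torch.device("cuda:1"): torch.device("cuda:0")}}
# i.e. each peer's map into this worker is inverted so it can be applied to
# tensors arriving from that peer.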


def _get_device_infos():
    from . import TensorPipeAgent
    agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
    opts = agent._get_backend_options()
    device_count = torch.cuda.device_count()
    if torch.cuda.is_available() and opts.devices:
        torch.cuda.init()
    return device_count, opts.device_maps, opts.devices


def _set_devices_and_reverse_device_map(agent):
    from . import TensorPipeAgent
    agent = cast(TensorPipeAgent, agent)
    # Group state is retrieved from the local agent.
    # On initialization, the TensorPipe agent retrieves information from all existing workers, so group state is valid.
    my_worker_info = agent.get_worker_info()
    my_name = my_worker_info.name
    all_worker_infos = agent.get_worker_infos()
    # One round to get device_maps of all workers and construct reverse device maps.
    all_device_counts, all_device_maps, all_devices, all_names = {}, {}, {}, []
    for worker_info in all_worker_infos:
        worker_name = worker_info.name
        if worker_name != my_name:
            # TODO: make async?
            device_count, device_map, devices = api.rpc_sync(worker_name, _get_device_infos)
        else:
            opts = agent._get_backend_options()
            device_count, device_map, devices = torch.cuda.device_count(), opts.device_maps, opts.devices
        all_device_counts[worker_name] = device_count
        all_device_maps[worker_name] = device_map
        all_devices[worker_name] = devices
        all_names.append(worker_name)

    _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices, is_static_group=False)
    reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)

    # Perform an RPC call to all workers, including itself, to include the newly joined worker's information and device maps.
    for worker_name in all_names:
        # Set the device list for each worker.
        all_devices[worker_name] = _create_device_list(all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps)
        api.rpc_sync(worker_name, _update_group_membership,
                     args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True))


def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_options):
    from . import TensorPipeAgent
    from . import TensorPipeRpcBackendOptions
    if not isinstance(store, dist.Store):
        raise TypeError("`store` must be a c10d::Store. {}".format(store))

    if not isinstance(
        rpc_backend_options, TensorPipeRpcBackendOptions
    ):
        raise TypeError(
            "`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {}".format(
                rpc_backend_options
            )
        )

    device_count = torch.cuda.device_count()

    is_static_group = True if world_size else False
    # world_size is specified, so this is a static group (ranks cannot join and leave).
    if is_static_group:
        # The agent's join method is required to behave like a barrier and perform
        # collective operations, for which it relies on a process group, instead of
        # re-implementing this on top of RPCs.
        group = _init_process_group(store, rank, world_size)

        reverse_device_maps, devices = _tensorpipe_exchange_and_check_all_device_maps(
            name,
            device_count,
            rpc_backend_options.device_maps,
            rpc_backend_options.devices,
            group,
        )

        if torch.cuda.is_available() and devices:
            # It's necessary to initialize PyTorch CUDA states here (e.g.,
            # CUDACachingAllocator). If this is missing, we could hit errors like
            # "allocator not initialized", because other processes might send
            # CUDA-related RPC requests to this process before user code in this
            # process initializes its PyTorch CUDA states.
            torch.cuda.init()

        # TODO: add try-except and destroy _agent in all processes if any fails.
        agent = TensorPipeAgent(
            store,
            name,
            rank,
            world_size,
            rpc_backend_options,
            reverse_device_maps,
            devices,
        )

        api._init_rpc_states(agent)

        # Run one dummy round of RPC to initialize channels/transports. Without
        # this, it's easy to hit timeout in rpc.shutdown() if there is no other RPC
        # on that process before rpc.shutdown(), as the agent initialization can
        # take longer than 5s.
        api._all_gather(None, timeout=rpc_backend_options.rpc_timeout)
        # Need a barrier here to make sure no peers leave before rank 0 finishes
        # _all_gather.
        group.barrier().wait()

        return agent
    # Initialization for dynamic RPC (ranks can join and leave).
    else:
        with _group_membership_management(store, name, True):
            # Construct the TensorPipeAgent with an empty reverse_device_map and devices;
            # these properties will be updated after initialization.
            agent = TensorPipeAgent(
                store,
                name,
                rank,
                world_size,
                rpc_backend_options,
                {},
                [],
            )
            api._init_rpc_states(agent)

            try:
                # Notify all workers in the group that this rank has joined and set devices and reverse_device_map.
                # This is a synchronous operation that completes once all existing ranks are updated.
                _set_devices_and_reverse_device_map(agent)
            except Exception:
                api.shutdown()
                raise

            return agent


register_backend(
    "TENSORPIPE",
    _tensorpipe_construct_rpc_backend_options_handler,
    _tensorpipe_init_backend_handler,
)
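
# Minimal end-user sketch (comment only): this registration is what lets
# torch.distributed.rpc.init_rpc default to BackendType.TENSORPIPE and route
# through the handlers above; the address and sizes below are illustrative.
#
#     import torch.distributed.rpc as rpc
#     opts = rpc.TensorPipeRpcBackendOptions(
#         init_method="tcp://localhost:29500", num_worker_threads=16
#     )
#     rpc.init_rpc("worker0", rank=0, world_size=2, rpc_backend_options=opts)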