# device_mesh.py

# Copyright (c) Meta Platforms, Inc. and affiliates
import os
import warnings
from typing import List, Optional, Sequence, TypeVar, Union

import torch
from torch.distributed.distributed_c10d import (
    _get_default_group,
    all_gather,
    all_reduce,
    all_to_all,
    broadcast,
    get_global_rank,
    get_rank,
    get_world_size,
    GroupMember,
    init_process_group,
    is_initialized,
    new_group,
    ProcessGroup,
    reduce_scatter,
    ReduceOp,
    scatter,
    Work,
)

_global_device_mesh: Optional["DeviceMesh"] = None


def get_global_device_mesh() -> "DeviceMesh":
    global _global_device_mesh
    assert _global_device_mesh is not None, "Could not get a default device mesh!"
    return _global_device_mesh


def set_global_device_mesh(mesh: Optional["DeviceMesh"]) -> None:
    global _global_device_mesh
    _global_device_mesh = mesh


# We want a type for "can be passed to torch.as_tensor()";
# this is a recursive sequence type, which isn't fully supported
# in Python yet. This construct simulates that up to depth 7.
T = TypeVar("T")
_L = Union[T, Sequence[T]]
NDIntList = _L[_L[_L[_L[_L[_L[_L[int]]]]]]]

MeshExprT = Union[
    torch.Tensor,
    NDIntList,
]
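
# Illustration (a hedged sketch, values chosen arbitrarily): both of the
# following describe the same 2x2 layout and are accepted as MeshExprT:
#   [[0, 1], [2, 3]]                  # nested ints (NDIntList)
#   torch.tensor([[0, 1], [2, 3]])    # an integer torch.Tensor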


class DeviceMesh:
    """
    DeviceMesh represents a mesh of devices, where the layout of devices can be
    represented as an n-dimensional array, and each value of the n-dimensional
    array is the global id of a rank in the default process group.

    DeviceMesh can be used to describe the layout of devices across the cluster,
    and serves as a proxy for communication among the device lists within the
    cluster. We use the default ProcessGroup in this DeviceMesh class to
    implement proper communications. Note that we also add collective wrappers
    in this class; this is used to decouple the detailed communication backend
    from the underlying DTensor implementation.

    DeviceMesh can be used as a context manager.

    Args:
        device_type (str): device type of the mesh. Currently supports: cpu, cuda.
        mesh (ndarray): a multi-dimensional array or an integer tensor that
            describes the layout of devices; the ids are global ids of the
            default process group.
        dim_groups (List[ProcessGroup], optional): the ProcessGroup used per
            mesh dimension.

    Returns:
        A :class:`DeviceMesh` object.

    Example (2 hosts with 4 GPUs each):
    ```
    # The following program runs on each process/rank in SPMD manner.
    # initialize the default world
    torch.distributed.init_process_group(backend="nccl", world_size=8)
    # initialize the device mesh as (2, 4) to represent the topology
    # of cross-host (dim 0) and within-host (dim 1)
    mesh = DeviceMesh(device_type="cuda",
                      mesh=[
                          [0, 1, 2, 3],
                          [4, 5, 6, 7]
                      ])
    ```
    A reduction over the first dimension of the mesh reduces across
    columns (0, 4), ..., (3, 7); a reduction over the second dimension
    reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).
    """
    device_type: str
    mesh: torch.Tensor
    _backend: str

    def __init__(
        self,
        device_type: str,
        mesh: MeshExprT,
        dim_groups: Optional[List[ProcessGroup]] = None,
    ) -> None:
        self.device_type = device_type
        self.mesh = (
            mesh.detach()
            if isinstance(mesh, torch.Tensor)
            else torch.tensor(mesh, dtype=torch.int)
        )
        default_pg = self._get_or_create_default_group()
        self._backend = default_pg._get_backend_name()
        # TODO: if the user wants to pass pg_options, offer a way to do it
        # check the default pg backend, it should support the device_type
        if device_type == "cpu":
            assert (
                self._backend == "gloo" or self._backend == "threaded"
            ), f"ProcessGroup backend: {self._backend} does not support CPU!"
        elif device_type == "cuda":
            if self._backend == "gloo":
                warnings.warn(
                    "We recommend using the nccl backend for the cuda device type; "
                    "the gloo backend might only have partial support!"
                )
            assert self._backend in ("gloo", "nccl", "threaded")
        else:
            raise RuntimeError(
                f"DeviceMesh only supports cpu or cuda device types, but got {device_type}"
            )

        world_size = get_world_size()
        if self.mesh.numel() > world_size:
            raise RuntimeError(
                f"Mesh should not be bigger than the default world size, but found {self.mesh.numel()} ranks!"
            )
        unique_mesh_values = self.mesh.unique(sorted=True)
        if unique_mesh_values.numel() != self.mesh.numel():
            raise RuntimeError(
                f"DeviceMesh cannot have duplicate values, but found {self.mesh.tolist()}"
            )

        # coordinates of this rank on the mesh
        rank_coords = (self.mesh == get_rank()).nonzero()
        assert rank_coords.size(0) in (0, 1)
        self._coordinate_on_dim: Optional[List[int]] = (
            rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
        )

        # groups created by dimension, each dimension should have exactly
        # one valid process group per rank
        self._dim_groups: List[ProcessGroup] = []
        if dim_groups is not None:
            # if the user hand-created the dimension-based groups,
            # we just take them and use them for communication
            if not isinstance(dim_groups, list):
                raise RuntimeError(
                    "dim_groups expected to be Optional[List[ProcessGroup]]"
                )
            for group in dim_groups:
                if not isinstance(group, ProcessGroup):
                    raise RuntimeError(
                        f"found object in dim_groups that is not a ProcessGroup: {group}"
                    )
            if self.get_rank() in self.mesh:
                if len(dim_groups) != self.mesh.ndim:
                    raise RuntimeError(
                        f"length of dim_groups ({len(dim_groups)}) expected to be equal to mesh.ndim ({self.mesh.ndim})"
                    )
            else:
                if len(dim_groups) != 0:
                    raise RuntimeError(
                        f"length of dim_groups ({len(dim_groups)}) expected to be equal to 0 on rank {self.get_rank()} "
                        f"for mesh {self.mesh}"
                    )
            self._dim_groups = dim_groups
            return

        if self.mesh.ndim == 1 and unique_mesh_values[-1] == world_size - 1:
            # if the mesh is the same as the world pg, we just append the default
            # pg to the first dim groups, as new_group cannot have the exact
            # same ranks as world
            self._dim_groups.append(default_pg)
        else:
            # create sub pgs based on the mesh argument specified:
            # handle a multi-dim mesh by creating subgroups,
            # looping over pg_ranks_by_dim for each dim
            for dim in range(self.mesh.ndim):
                # swap the current dim to the last dim,
                # then reshape to flatten out the other dims
                pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
                    -1, self.mesh.size(dim)
                )
                # for each dim, loop over the rows of pg_ranks_by_dim
                # and create one subgroup per row
                for dim_mesh in pg_ranks_by_dim:
                    subgroup_ranks = dim_mesh.tolist()
                    # call new_group regardless of whether the current rank is
                    # in the pg or not; all ranks are required to participate
                    # in subgroup construction
                    new_subgroup = new_group(
                        ranks=subgroup_ranks, backend=self._backend
                    )
                    # only add to dim_groups if the current rank is in the subgroup
                    if self.get_rank() in subgroup_ranks:
                        if len(self._dim_groups) > dim:
                            raise RuntimeError(
                                f"Each device mesh dimension should get only one process group, but got rank {self.get_rank()} "
                                f"in {subgroup_ranks}!"
                            )
                        self._dim_groups.append(new_subgroup)
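
    # Illustration of the subgroup construction above (a sketch, assuming the
    # 2x4 mesh [[0, 1, 2, 3], [4, 5, 6, 7]] from the class docstring on 8 ranks):
    #   dim 0: pg_ranks_by_dim is [[0, 4], [1, 5], [2, 6], [3, 7]], so each rank
    #          joins exactly one 2-rank cross-host subgroup, e.g. rank 5 joins (1, 5).
    #   dim 1: pg_ranks_by_dim is [[0, 1, 2, 3], [4, 5, 6, 7]], so each rank joins
    #          exactly one 4-rank within-host subgroup, e.g. rank 5 joins (4, 5, 6, 7).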

    def _get_or_create_default_group(self):
        if not is_initialized():
            # TODO: we will support a mesh on a subset of WORLD in the future
            world_size = int(os.getenv("WORLD_SIZE", 1))
            if self.mesh.numel() < world_size:
                raise RuntimeError(
                    "DeviceMesh must include every process in WORLD, "
                    f"but WORLD_SIZE({world_size}) != mesh size({self.mesh.numel()})"
                )
            unique_mesh_values = self.mesh.unique(sorted=True)
            if unique_mesh_values.numel() != self.mesh.numel():
                raise RuntimeError(
                    f"DeviceMesh cannot have duplicate values, but found {self.mesh.tolist()}"
                )
            # ranks in the mesh must start from 0
            if unique_mesh_values[0] != 0:
                raise RuntimeError(
                    "DeviceMesh ranks must start from 0, "
                    f"but found min rank = {unique_mesh_values[0]}"
                )
            # the mesh must be contiguous (i.e. cover 0 to N-1)
            if 2 * unique_mesh_values.sum().item() != world_size * (world_size - 1):
                raise RuntimeError(
                    f"DeviceMesh should have all ranks of WORLD, but found {self.mesh.tolist()}"
                )
            _backend = "gloo" if self.device_type == "cpu" else "nccl"
            init_process_group(backend=_backend)
        return _get_default_group()
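
    # Note on the contiguity check above: the sum of the integers 0..N-1 is
    # N * (N - 1) / 2, so together with the uniqueness and min-rank-0 checks,
    # comparing 2 * sum(mesh) against world_size * (world_size - 1) verifies
    # that the mesh contains exactly the ranks 0..world_size-1.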

    def __enter__(self) -> "DeviceMesh":
        # set the global device mesh to this instance
        set_global_device_mesh(self)
        return self

    # pyre-fixme[2]: Parameter must be annotated.
    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        # unset the global device mesh
        set_global_device_mesh(None)
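
    # Usage sketch for the context manager protocol above (illustrative values):
    # entering the `with` block installs this mesh as the global default and
    # exiting clears it, e.g.:
    #
    #     with DeviceMesh("cuda", [[0, 1], [2, 3]]) as mesh:
    #         assert get_global_device_mesh() is mesh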

    def __repr__(self) -> str:
        return f"DeviceMesh:({self.mesh.tolist()})"

    def __hash__(self):
        return hash((self.mesh, id(self)))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DeviceMesh):
            return False
        if id(self) == id(other):
            return True
        return self.mesh.equal(other.mesh)

    def get_dim_groups(self) -> List[ProcessGroup]:
        return self._dim_groups

    # pyre-fixme[3]: Return type must be annotated.
    def size(self, dim: int = 0):
        return self.mesh.size(dim)

    @property
    def ndim(self) -> int:
        return self.mesh.ndim

    def backend(self) -> str:
        return self._backend

    def get_rank(self) -> int:
        return get_rank()

    def get_coordinate_on_dim(self, dim: int) -> Optional[int]:
        """
        Return the index of this rank along the given dimension of the mesh.
        If this rank is not part of the mesh, return None.
        """
        return self._coordinate_on_dim[dim] if self._coordinate_on_dim else None
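
    # Illustration (a sketch, assuming a 2x2 mesh [[0, 1], [2, 3]]): rank 3 sits
    # at coordinate [1, 1], so on that rank get_coordinate_on_dim(0) == 1 and
    # get_coordinate_on_dim(1) == 1, while a rank outside the mesh gets None.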

    def scatter(
        self,
        output: torch.Tensor,
        scatter_list: List[torch.Tensor],
        mesh_dim: int = 0,
        async_op: bool = False,
    ) -> Optional[Work]:
        """
        Scatter a list of tensors along a device mesh dimension. By default we
        use the first rank of the mesh dimension as the source of truth, i.e.
        for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we
        scatter the tensor list on rank 0 to ranks 0/1, and the tensor list on
        rank 2 to ranks 2/3.

        Args:
            output (torch.Tensor): the tensor to receive the scattered result.
            scatter_list (List[torch.Tensor]): the tensor list to be scattered.
            mesh_dim (int, optional): indicates which mesh dimension we want
                to scatter on; by default we choose the first rank of the
                mesh dimension as the source of truth.

        Returns:
            A :class:`Work` object
        """
        # TODO: Ideally we should use the meta tensor way
        # (to register a meta kernel for the collective op)
        # so that it would avoid the communication. Need to
        # remove the check below once that is done.
        if output.is_meta:
            return None
        dim_group = self._dim_groups[mesh_dim]
        # src needs to be a global rank
        src_for_dim = 0
        if dim_group is not GroupMember.WORLD:
            src_for_dim = get_global_rank(dim_group, 0)
        if src_for_dim == get_rank():
            fut = scatter(
                output,
                scatter_list=scatter_list,
                src=src_for_dim,
                group=dim_group,
                async_op=async_op,
            )
        else:
            fut = scatter(
                output,
                scatter_list=None,
                src=src_for_dim,
                group=dim_group,
                async_op=async_op,
            )
        return fut
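
    # Usage sketch (illustrative, run in SPMD style on every rank of an assumed
    # 2x2 mesh [[0, 1], [2, 3]]): scattering on mesh_dim=1 means rank 0's list
    # feeds ranks 0/1 and rank 2's list feeds ranks 2/3.
    #
    #     out = torch.empty(4)
    #     shards = [torch.full((4,), float(i)) for i in range(2)]
    #     mesh.scatter(out, shards, mesh_dim=1)
    #     # rank 1 now holds shards[1] as produced on rank 0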

    def broadcast(
        self,
        tensor: torch.Tensor,
        mesh_dim: int = 0,
        async_op: bool = False,
    ) -> Optional[Work]:
        """
        Broadcast the tensor along a device mesh dimension. By default we use
        the first rank of the mesh dimension as the source of truth, i.e. for
        a 2d mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, we
        broadcast the tensor on rank 0 to ranks 0/1, and the tensor on rank 2
        to ranks 2/3.

        Args:
            tensor (torch.Tensor): tensor to broadcast.
            mesh_dim (int, optional): indicates which mesh dimension we want
                to broadcast on; by default we choose the first rank of the
                mesh dimension as the source of truth.

        Returns:
            A :class:`Work` object
        """
        # TODO: Ideally we should use the meta tensor way
        # (to register a meta kernel for the collective op)
        # so that it would avoid the communication. Need to
        # remove the check below once that is done.
        if tensor.is_meta:
            return None
        dim_group = self._dim_groups[mesh_dim]
        # src needs to be a global rank
        src_for_dim = 0
        if dim_group is not GroupMember.WORLD:
            src_for_dim = get_global_rank(dim_group, 0)
        return broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op)

    def all_gather(
        self,
        tensor_list: List[torch.Tensor],
        tensor: torch.Tensor,
        mesh_dim: int = 0,
        async_op: bool = False,
    ) -> Optional[Work]:
        """
        All-gather the tensor on each rank into tensor_list along a device
        mesh dimension.

        Args:
            tensor_list (List[torch.Tensor]): the gathered tensor list.
            tensor (torch.Tensor): the tensor to be gathered from each rank.
            mesh_dim (int, optional): indicates which mesh dimension we want
                to gather on.

        Returns:
            A :class:`Work` object
        """
        dim_group = self._dim_groups[mesh_dim]
        return all_gather(tensor_list, tensor, group=dim_group, async_op=async_op)
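
    # Usage sketch (illustrative, assuming a 1-d mesh [0, 1]): each rank
    # contributes its own tensor and receives every rank's tensor in group order.
    #
    #     gathered = [torch.empty(4) for _ in range(mesh.size(0))]
    #     mesh.all_gather(gathered, torch.full((4,), float(mesh.get_rank())))
    #     # gathered[i] now holds the tensor from the i-th rank of the dim group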

    def all_reduce(
        self,
        tensor: torch.Tensor,
        op: ReduceOp = ReduceOp.SUM,  # type: ignore[assignment]
        mesh_dim: int = 0,
        async_op: bool = False,
    ) -> Optional[Work]:
        """
        All-reduce the tensor on each rank along a device mesh dimension, and
        return an output tensor on each rank after the all_reduce.

        Args:
            tensor (torch.Tensor): the tensor to be all-reduced on each rank.
            op (:class:`torch.distributed.distributed_c10d.ReduceOp`, optional):
                the reduction op for all_reduce (e.g. ReduceOp.SUM).
            mesh_dim (int, optional): indicates which mesh dimension we want
                to reduce on.

        Returns:
            A :class:`Work` object
        """
        dim_group = self._dim_groups[mesh_dim]
        return all_reduce(tensor, op=op, group=dim_group, async_op=async_op)
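
    # Usage sketch (illustrative, assuming the 2x4 cuda mesh from the class
    # docstring): all_reduce(t, mesh_dim=0) sums t across host pairs such as
    # (0, 4), while all_reduce(t, mesh_dim=1) sums t within a host, e.g. across
    # (0, 1, 2, 3).
    #
    #     t = torch.ones(2, device="cuda")
    #     mesh.all_reduce(t, mesh_dim=1)  # t becomes tensor([4., 4.])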

    def reduce_scatter(
        self,
        output: torch.Tensor,
        input_list: List[torch.Tensor],
        op: ReduceOp = ReduceOp.SUM,  # type: ignore[assignment]
        mesh_dim: int = 0,
        async_op: bool = False,
    ) -> Optional[Work]:
        """
        Reduce the input_list on each rank along a device mesh dimension, and
        scatter the results to the output tensor on each rank.

        Args:
            output (torch.Tensor): the tensor to receive the scattered result.
            input_list (List[torch.Tensor]): the tensor list to be reduced and
                scattered on each rank.
            op (:class:`torch.distributed.distributed_c10d.ReduceOp`, optional):
                the reduction op for reduce_scatter (e.g. ReduceOp.SUM).
            mesh_dim (int, optional): indicates which mesh dimension we want
                to scatter on.

        Returns:
            A :class:`Work` object
        """
        if self._backend == "nccl":
            dim_group = self._dim_groups[mesh_dim]
            fut = reduce_scatter(
                output, input_list, op=op, group=dim_group, async_op=async_op
            )
        elif self._backend == "gloo":
            # gloo does not have a reduce_scatter implementation,
            # so we have to do all_reduce + local slicing instead
            warnings.warn(
                "ProcessGroupGloo does not support reduce_scatter, falling back to all_reduce!"
            )
            my_coordinate = self.get_coordinate_on_dim(mesh_dim)
            # TODO: what should happen if rank is not in the mesh?
            # see issue https://github.com/pytorch/tau/pull/492
            assert (
                my_coordinate is not None
            ), "Rank is not part of the mesh"  # TODO: figure out behavior here
            fut = None
            flattened_list = []
            offset_list = []
            offset = 0
            for input in input_list:
                offset_list.append(offset)
                offset += input.numel()
                flattened_list.append(input.flatten())
            # all_reduce since gloo does not support reduce_scatter
            flat_tensor = torch.cat(flattened_list).clone(
                memory_format=torch.contiguous_format
            )
            fut = self.all_reduce(
                flat_tensor, op=op, mesh_dim=mesh_dim, async_op=async_op
            )
            # copy this rank's slice of the reduced tensor into the output
            output_offset = offset_list[my_coordinate]
            output.copy_(
                flat_tensor[output_offset : output_offset + output.numel()].view(
                    output.shape
                )
            )
        else:
            raise RuntimeError(
                f"backend {self._backend} does not support reduce_scatter!"
            )
        return fut
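
    # Illustration of the gloo fallback above (a sketch, assuming a 1-d mesh
    # [0, 1] and two same-sized inputs per rank): the inputs are flattened and
    # concatenated into a single tensor, all-reduced across the dim group, and
    # then each rank copies its own slice (selected by its mesh coordinate)
    # into `output`, which matches reduce_scatter semantics at the cost of
    # reducing the full concatenated tensor on every rank.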

    # TODO: test uneven split on GLOO and NCCL
    def all_to_all(
        self,
        output_tensor_list: List[torch.Tensor],
        input_tensor_list: List[torch.Tensor],
        mesh_dim: int = 0,
        async_op: bool = False,
    ) -> Optional[Work]:
        dim_group = self._dim_groups[mesh_dim]
        work = None
        # no direct dist.all_to_all support on 'gloo', so we emulate it with scatters
        if self.backend() == "gloo":
            # TODO: handle the uneven case, see #492
            dim_group_size = get_world_size(dim_group)
            for i in range(dim_group_size):
                # src needs to be a global rank
                src_for_dim = i
                if dim_group is not GroupMember.WORLD:
                    src_for_dim = get_global_rank(dim_group, i)
                work = scatter(
                    output_tensor_list[i],
                    input_tensor_list if self.get_rank() == src_for_dim else [],
                    group=dim_group,
                    src=src_for_dim,
                    async_op=async_op,
                )
        elif self.backend() == "nccl":
            work = all_to_all(
                output_tensor_list,
                input_tensor_list,
                dim_group,
                async_op=async_op,
            )
        else:
            raise RuntimeError(
                f"DeviceMesh does not support all-to-all collective operations on the {self.backend()} backend."
            )
        return work
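

# A minimal end-to-end usage sketch. It assumes 4 CPU ranks launched with
# torchrun (e.g. `torchrun --nproc_per_node=4 device_mesh.py`) so the gloo
# default group can be created; the function name below is illustrative only.
def _example_usage() -> None:
    # build a 2x2 mesh over ranks 0..3; the constructor initializes the default
    # group if needed and creates one gloo subgroup per mesh dimension
    mesh = DeviceMesh("cpu", [[0, 1], [2, 3]])
    # sum a tensor across the second mesh dimension (rows [0, 1] and [2, 3])
    t = torch.ones(2)
    mesh.all_reduce(t, mesh_dim=1)  # t == tensor([2., 2.]) on every rank
    # the mesh can also be installed as the global default via the context manager
    with mesh:
        assert get_global_device_mesh() is mesh


if __name__ == "__main__":
    _example_usage()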