# fsdp.py

import copy
import warnings
from typing import cast, List, NamedTuple, Optional, Tuple

import torch
import torch.distributed as dist
import torch.distributed._shard.sharding_spec as shard_spec
import torch.distributed.distributed_c10d as c10d
from torch.distributed.fsdp._common_utils import _set_fsdp_flattened
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import (
    ChunkShardingSpec,
)
from torch.distributed._tensor import (
    DeviceMesh,
    DTensor as DistributedTensor,
    Shard as DShard,
)
from torch.distributed._tensor.placement_types import Placement
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
from torch.distributed.remote_device import _remote_device

__all__ = ["enable_2d_with_fsdp"]


def enable_2d_with_fsdp() -> bool:
    """
    Register the extension which is needed for Tensor Parallelism (TP) to work
    with FullyShardedDataParallel (FSDP). We first parallelize parameters
    within one module or sub_modules based on a parallelize_plan and then let
    FSDP reshard the local tensor of the distributed parameter, which is
    essentially a DTensor.

    Return:
        A `bool` indicating whether the extension registration succeeded.
    """
    try:
        from torch.distributed.fsdp._fsdp_extensions import (
            _set_fsdp_extensions,
            FSDPExtensions,
        )

        class DTensorExtensions(FSDPExtensions):
            def pre_flatten_transform(
                self,
                tensor: torch.Tensor,
            ) -> Tuple[torch.Tensor, Optional[_STShardingInfo]]:
                return _flatten_tensor(tensor)

            def post_unflatten_transform(
                self, tensor: torch.Tensor, param_extension: _STShardingInfo
            ) -> torch.Tensor:
                return _unflatten_tensor(tensor, param_extension)

            def chunk_tensor(
                self,
                tensor: torch.Tensor,
                rank: int,
                world_size: int,
                num_devices_per_node: int,
                pg: dist.ProcessGroup,
            ) -> torch.Tensor:
                return _chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg)

            def pre_load_state_dict_transform(
                self,
                tensor: torch.Tensor,
            ) -> Tuple[torch.Tensor, List[Shard]]:
                return _pre_load_state_dict(tensor)

        _set_fsdp_extensions(DTensorExtensions())
        return True
    except BaseException as e:
        warnings.warn(
            "PyTorch doesn't have the TensorFlattener extension point available. "
            "2D parallelism won't work with FSDP. "
            f"Exception: {e}"
        )
        return False
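

# Usage sketch (illustrative only, not part of the original module): how this
# extension point is typically exercised for 2D (TP + FSDP) parallelism with
# the PyTorch 2.0-era `parallelize_module`/`PairwiseParallel` APIs. The helper
# name, the "cuda" mesh, the mesh shape, and the choice of PairwiseParallel
# plan are all assumptions for illustration; the key point is that
# `enable_2d_with_fsdp()` must run before the model is wrapped in FSDP so that
# FSDP flattens only the local tensor of each DTensor parameter.
def _example_2d_setup(model: torch.nn.Module, tp_size: int) -> torch.nn.Module:
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module

    assert enable_2d_with_fsdp(), "FSDP DTensor extension point not available"

    world_size = dist.get_world_size()
    # 2D mesh: dim 0 is the data-parallel (FSDP) dimension, dim 1 is TP.
    twod_mesh = DeviceMesh(
        "cuda",
        torch.arange(world_size).view(world_size // tp_size, tp_size),
    )
    # Tensor-parallelize along mesh dim 1 ...
    model = parallelize_module(model, twod_mesh, PairwiseParallel(), tp_mesh_dim=1)
    # ... and shard with FSDP along mesh dim 0.
    dp_pg = twod_mesh.get_dim_groups()[0]
    return FSDP(model, process_group=dp_pg)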


class _STShardingInfo(NamedTuple):
    """:class:`ShardedTensor` sharding information."""

    sharding_spec: Optional[shard_spec.ShardingSpec]
    global_size: Optional[torch.Size]
    process_group: Optional[c10d.ProcessGroup]
    device_mesh: Optional[DeviceMesh]
    placements: Optional[List[Placement]]


def _get_box(tensor: DistributedTensor) -> Tuple[torch.Size, torch.Size]:
    device_mesh = tensor.device_mesh
    assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

    placement = tensor.placements[0]
    offsets = [0] * len(tensor.size())
    num_chunks = device_mesh.size(dim=0)

    if tensor.placements[0].is_shard():
        shard_dim = cast(DShard, placement).dim
        chunk_size = tensor.size(shard_dim) // num_chunks
        offsets[shard_dim] = chunk_size

    return (torch.Size(offsets), tensor._local_tensor.size())


def _get_box_for(tensor: DistributedTensor, idx: int) -> Tuple[torch.Size, torch.Size]:
    offsets, size = _get_box(tensor)
    return (torch.Size([val * idx for val in offsets]), size)


def _get_local_box(tensor: DistributedTensor) -> Tuple[torch.Size, torch.Size]:
    device_mesh = tensor.device_mesh
    dim_0_coord = device_mesh.get_coordinate_on_dim(0)
    assert dim_0_coord is not None
    return _get_box_for(tensor, dim_0_coord)
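
# Worked example for the box helpers above (illustrative, not part of the
# original module): for a DTensor with global size (8, 4) sharded on dim 0
# across a 1D mesh of 4 ranks, _get_box returns per-chunk offsets (2, 0) and
# local size (2, 4); _get_box_for(tensor, 2) then yields offsets (4, 0), i.e.
# the box of the third chunk, and _get_local_box picks the chunk matching this
# rank's coordinate on mesh dim 0.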


def _create_shard_md_from_dt(dt: DistributedTensor, current_rank: int) -> ShardMetadata:
    mesh = dt.device_mesh
    assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

    offsets, sizes = _get_local_box(dt)
    return ShardMetadata(
        shard_offsets=list(offsets),
        shard_sizes=list(sizes),
        placement=f"rank:{current_rank}/{dt._local_tensor.device}",
    )


def _create_sharded_tensor_md_from_dt(
    dt: DistributedTensor, dt_pg: c10d.ProcessGroup
) -> ShardedTensorMetadata:
    # This is where it gets tricky: we have to produce a ShardedTensor that has
    # full coverage and yet has only one valid shard for the current rank.
    shards_md = []
    my_rank = dist.get_rank(dt_pg)
    scapegoat_rank = 0 if my_rank > 0 else 1

    if dt.placements[0].is_shard():
        shard_count = dt_pg.size()
    else:
        shard_count = 1

    for i in range(shard_count):
        offsets, sizes = _get_box_for(dt, i)
        shards_md.append(
            ShardMetadata(
                shard_offsets=list(offsets),
                shard_sizes=list(sizes),
                placement=(
                    f"rank:{scapegoat_rank if i > 0 else my_rank}/{dt._local_tensor.device}"
                ),
            )
        )

    return ShardedTensorMetadata(
        shards_metadata=shards_md,
        size=dt.size(),
        tensor_properties=TensorProperties(
            dtype=dt.dtype,
            layout=dt.layout,
            requires_grad=dt.requires_grad,
            # ignore memory_format and pin_memory as those are not supported by DT
        ),
    )
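
# Illustration of the "scapegoat" trick above (not part of the original
# module): on rank 2 of a 4-rank dt_pg with a dim-0-sharded DTensor, the four
# ShardMetadata entries get placements rank:2, rank:0, rank:0, rank:0. The
# entries together tile the full global size, but only one of them is
# attributed to the calling rank; all others are pushed onto a scapegoat rank.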


def _get_dt_pg(dt: DistributedTensor) -> c10d.ProcessGroup:
    mesh = dt.device_mesh
    assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
    return mesh.get_dim_groups()[0]


def _rewrite_spec_if_needed(
    spec: shard_spec.ShardingSpec, tensor: torch.Tensor, rank: int
) -> shard_spec.ShardingSpec:
    """
    Rewrite ``spec`` to match the device of ``tensor``.

    FSDP.sharded_optim_state_dict sneakily moves the optimizer state to CPU,
    so if the original ShardingSpec produces CUDA metadata, ShardedTensor
    construction fails.
    """
    if not isinstance(spec, ChunkShardingSpec):
        return spec

    # Check whether any placement for this rank disagrees with the tensor's device.
    rewrite = False
    for p in spec.placements:
        p = cast(_remote_device, p)
        if p.rank() == rank and p.device() != tensor.device:
            rewrite = True
            break
    if rewrite:
        spec = copy.deepcopy(spec)
        for i, placement in enumerate(spec.placements):
            placement = cast(_remote_device, placement)
            if placement.rank() == rank and placement.device() != tensor.device:
                spec.placements[i] = _remote_device(f"rank:{rank}/{tensor.device}")

    return spec
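
# Illustration (not part of the original module): if FSDP has offloaded the
# local tensor to CPU while the ChunkShardingSpec still records "rank:1/cuda:1"
# for this rank, _rewrite_spec_if_needed returns a deep copy whose placement
# for this rank reads "rank:1/cpu", so ShardedTensor construction no longer
# trips over the device mismatch. Placements for other ranks are left alone.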


def _flatten_tensor(
    tensor: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[_STShardingInfo]]:
    if type(tensor) is ShardedTensor:
        return tensor.local_tensor(), _STShardingInfo(
            tensor.sharding_spec(),
            tensor.size(),
            tensor._process_group,
            None,
            None,
        )
    elif type(tensor) is DistributedTensor:
        tensor._local_tensor.requires_grad_()
        return tensor._local_tensor, _STShardingInfo(
            None,
            None,
            None,
            tensor.device_mesh,
            list(tensor.placements),
        )
    return tensor, None


def _unflatten_tensor(
    tensor: torch.Tensor, sharding_info: _STShardingInfo
) -> torch.Tensor:
    result: torch.Tensor

    if sharding_info.sharding_spec is not None:
        assert sharding_info.global_size is not None
        result = ShardedTensor._init_from_local_tensor(
            tensor,
            _rewrite_spec_if_needed(
                sharding_info.sharding_spec,
                tensor,
                dist.get_rank(sharding_info.process_group),
            ),
            sharding_info.global_size,
            process_group=cast(dist.ProcessGroup, sharding_info.process_group),
        )
    else:
        result = DistributedTensor.from_local(
            tensor,
            device_mesh=sharding_info.device_mesh,
            placements=sharding_info.placements,
            run_check=False,
        )
    _set_fsdp_flattened(result)
    return result
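

# Round-trip sketch (illustrative only, not part of the original module): FSDP
# calls pre_flatten_transform (_flatten_tensor) to reduce a DTensor parameter
# to its local tensor plus sharding info, and post_unflatten_transform
# (_unflatten_tensor) to rebuild an equivalent DTensor later. The hypothetical
# helper below just demonstrates that the pair composes as expected for a
# DTensor input.
def _example_flatten_roundtrip(dt: DistributedTensor) -> DistributedTensor:
    local, info = _flatten_tensor(dt)
    assert info is not None and info.device_mesh is not None
    rebuilt = _unflatten_tensor(local, info)
    assert isinstance(rebuilt, DistributedTensor)
    assert rebuilt.size() == dt.size()
    return rebuilt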


def _chunk_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
) -> torch.Tensor:
    if type(tensor) is ShardedTensor:
        assert len(tensor.local_shards()) == 1

        inner_param = tensor.local_tensor()
        inner_st = _create_chunk_sharded_tensor(
            inner_param,
            rank,
            world_size,
            num_devices_per_node,
            pg,
        )

        outer_local_shard = tensor.local_shards()[0]
        shards: List[Shard] = [
            Shard(inner_st, copy.deepcopy(outer_local_shard.metadata))
        ]
        st_meta = copy.deepcopy(tensor.metadata())
        st_meta.tensor_properties.requires_grad = False

        st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
            shards,
            sharded_tensor_metadata=st_meta,
            process_group=tensor._process_group,
            init_rrefs=False,
        )
        return st_outer
    elif type(tensor) is DistributedTensor:
        device_mesh = tensor.device_mesh
        assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

        inner_param = tensor._local_tensor
        inner_st = _create_chunk_sharded_tensor(
            inner_param,
            rank,
            world_size,
            torch.cuda.device_count(),
            pg,
        )

        dt_pg = _get_dt_pg(tensor)
        # We do this differently here: we build a ShardedTensor with a single
        # local shard and hand-crafted global metadata derived from the DTensor.
        shards = [
            Shard(inner_st, _create_shard_md_from_dt(tensor, dist.get_rank(dt_pg)))
        ]
        st_meta = _create_sharded_tensor_md_from_dt(tensor, dt_pg)
        st_meta.tensor_properties.requires_grad = False

        st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
            shards,
            sharded_tensor_metadata=st_meta,
            process_group=dt_pg,
            init_rrefs=False,
        )
        return st_outer
    else:
        return _create_chunk_sharded_tensor(
            tensor,
            rank,
            world_size,
            num_devices_per_node,
            pg,
        )
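
# Illustration (not part of the original module): for a DTensor parameter,
# _chunk_tensor above returns a two-level ShardedTensor. The outer one is laid
# out over the process group taken from the DTensor's mesh (the TP group), and
# its single local shard wraps an inner ShardedTensor that chunks the local
# tensor over the FSDP process group `pg`. _pre_load_state_dict below unwraps
# exactly this nesting when the state dict is loaded back.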


def _pre_load_state_dict(
    tensor: torch.Tensor,
) -> Tuple[torch.Tensor, List[Shard]]:
    shards = cast(ShardedTensor, tensor).local_shards()
    if len(shards) == 1 and type(shards[0].tensor) is ShardedTensor:
        inner_tensor = shards[0].tensor
        shards = inner_tensor.local_shards()  # pyre-ignore[16]
        tensor = inner_tensor

    return (tensor, shards if len(shards) > 0 else [])