api.py

from contextlib import contextmanager

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    ShardedTensor,
    _PartialTensor
)

from .replicated_tensor import ReplicatedTensor
from .sharding_spec import (
    ShardingSpec,
    ChunkShardingSpec
)
from .sharding_plan import (
    ShardingPlan
)
from .sharder import Sharder


def _shard_tensor(
    tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None
) -> ShardedTensor:
    """
    Given a :class:`torch.Tensor`, shards that tensor according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank that is used as the
    ground truth of the data, which is scattered as shards across the rest of
    the ranks.

    Args:
        tensor (:class:`torch.Tensor`): Tensor that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        A :class:`ShardedTensor` sharded from the given tensor.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    if not tensor.is_contiguous():
        raise ValueError('input tensor is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    current_rank = dist.get_rank(pg)

    # Validate that src_rank and sharding_spec are the same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {current_rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {current_rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)
    return st
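
# Illustrative sketch (not executed here): assuming two ranks with a default
# process group already initialized and one GPU per rank, ``_shard_tensor`` can
# chunk a tensor row-wise across the ranks. The placement strings below follow
# the ``"rank:<rank>/<device>"`` convention used by ``ChunkShardingSpec``.
#
#     spec = ChunkShardingSpec(
#         dim=0,
#         placements=["rank:0/cuda:0", "rank:1/cuda:1"],
#     )
#     st = _shard_tensor(torch.rand(4, 16), spec)  # each rank holds a [2, 16] local shard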


def shard_parameter(
        module: torch.nn.Module,
        param_name: str,
        sharding_spec: ShardingSpec,
        src_rank=0,
        process_group=None):
    """
    Given a :class:`torch.nn.Module` and a ``param_name`` for a parameter in that
    module, shards that parameter according to the provided ``sharding_spec``.
    ``src_rank`` denotes the source rank that is used as the ground truth of the
    data, which is scattered as shards across the rest of the ranks.

    This method replaces ``module.param_name`` with a
    :class:`torch.distributed._shard.sharded_tensor.ShardedTensor`.

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    # Perform some validation first.
    if not hasattr(module, param_name):
        raise AttributeError(f'{module._get_name()} has no attribute `{param_name}`')

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}')

    if not tensor.is_contiguous():
        raise ValueError(f'param: {param_name} is not a contiguous Tensor')

    st = _shard_tensor(tensor, sharding_spec, src_rank, process_group)

    # Replace the param with a ShardedTensor.
    module.register_parameter(param_name, nn.Parameter(st))
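
# Illustrative sketch (not executed here): assuming two initialized ranks, one GPU
# per rank, and an ``nn.Linear`` created identically on every rank, the module's
# weight can be sharded in place along its first dimension. ``ChunkShardingSpec``
# and the "rank:<rank>/<device>" placement strings follow the same conventions
# as above.
#
#     linear = torch.nn.Linear(16, 32).cuda()
#     spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
#     shard_parameter(linear, "weight", spec)  # linear.weight is now a ShardedTensor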


def _replicate_tensor(tensor: torch.Tensor, process_group=None) -> ReplicatedTensor:
    """
    Given a :class:`torch.Tensor`, mark it as a ReplicatedTensor where all
    ranks have the same value.

    Args:
        tensor (:class:`torch.Tensor`): the tensor to be marked as replicated.

    Keyword args:
        process_group (ProcessGroup, optional): The process group to replicate on.
            If None, the default process group will be used.

    Returns:
        A :class:`ReplicatedTensor` from the given tensor.
    """
    return ReplicatedTensor(tensor, process_group=process_group)
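
# Illustrative sketch (not executed here): a tensor that is already identical on
# every rank (e.g. an unsharded bias) can be wrapped to mark it as replicated
# across the process group:
#
#     replica = _replicate_tensor(torch.zeros(32))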


# Tracks the current process group in the load context manager.
_CURRENT_PROCESS_GROUP = None


@contextmanager
def load_with_process_group(process_group):
    """
    Context manager to set the process group with which to load a ShardedTensor/ReplicatedTensor.
    """
    global _CURRENT_PROCESS_GROUP
    if _CURRENT_PROCESS_GROUP is not None:
        raise RuntimeError(
            'ProcessGroup already set by previous "load_with_process_group" '
            'context manager')
    _CURRENT_PROCESS_GROUP = process_group
    try:
        yield process_group
    finally:
        _CURRENT_PROCESS_GROUP = None
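
# Illustrative sketch (not executed here): the context manager is meant to wrap a
# checkpoint load so that deserialized ShardedTensor/ReplicatedTensor objects attach
# to the given process group instead of the default one. ``pg`` is a placeholder for
# an already-created ProcessGroup and ``"checkpoint.pt"`` for a saved state dict
# containing sharded tensors.
#
#     with load_with_process_group(pg):
#         state_dict = torch.load("checkpoint.pt")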


def _get_current_process_group():
    """
    Retrieves the current process group set by ``load_with_process_group``.
    If not set, it just returns the default group.
    """
    global _CURRENT_PROCESS_GROUP
    if _CURRENT_PROCESS_GROUP is None:
        return distributed_c10d._get_default_group()
    else:
        return _CURRENT_PROCESS_GROUP


def _reshard_output(
        module: torch.nn.Module,
        resharding_spec: ShardingSpec) -> torch.nn.Module:
    """
    Hook a module with output resharding in the forward pass according
    to the given ``resharding_spec``.

    Args:
        module (:class:`torch.nn.Module`): Module whose output needs to be resharded.
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how the output of the module will be resharded.

    Returns:
        A :class:`torch.nn.Module` object with reshard API hooked.
    """
    def hook_func(_module, _input, output):
        if isinstance(output, (ShardedTensor, _PartialTensor)):
            return output.reshard(resharding_spec)
        return output
    module.register_forward_hook(hook_func)
    return module
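
# Illustrative sketch (not executed here): if ``sharded_linear`` is a module whose
# forward pass produces a ShardedTensor (or _PartialTensor), a forward hook can
# reshard that output according to a new spec. ``sharded_linear`` is a placeholder
# name for such a module.
#
#     reshard_spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
#     _reshard_output(sharded_linear, reshard_spec)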


def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module:
    """
    Hook a module with local shard collection in the forward pass.

    This API is typically used to convert a sharded representation back to a data
    parallel representation. In particular, it returns the local tensor of this
    shard. If the size along the sharding dimension of the local tensor is 1, this
    dimension is removed from the final result. For example, a [4, 16] ShardedTensor
    across 4 ranks is typically a local Tensor of size [16] on each rank and not
    [1, 16] on each rank.

    Args:
        module (:class:`torch.nn.Module`): Module whose output is a ShardedTensor and
            whose local tensor value needs to be returned.

    Returns:
        A :class:`torch.nn.Module` object with collection API hooked.
    """
    def hook_func(_module, _input, output):
        if isinstance(output, ShardedTensor):
            local_tensor = output.local_tensor()
            # Squeeze the number of dimensions manually; only applicable to ChunkShardingSpec.
            sharding_spec = output._sharding_spec
            if isinstance(sharding_spec, ChunkShardingSpec) \
               and local_tensor.size(sharding_spec.dim) == 1:  # type: ignore[attr-defined, arg-type]
                local_tensor = local_tensor.squeeze(
                    output._sharding_spec.dim  # type: ignore[attr-defined]
                )
            return local_tensor
    module.register_forward_hook(hook_func)
    return module
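
# Illustrative sketch (not executed here): chaining the two hooks above converts a
# module's ShardedTensor output back into a regular local tensor on every rank,
# i.e. a data parallel style output. ``sharded_linear``, ``reshard_spec`` and
# ``inp`` are placeholders continuing the sketch above.
#
#     _collect_local_shard(_reshard_output(sharded_linear, reshard_spec))
#     out = sharded_linear(inp)  # ``out`` is a plain torch.Tensor on each rank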


def shard_module(
    module: nn.Module,
    plan: ShardingPlan,
    src_rank=0,
    process_group=None
):
    """
    Shards a given module according to the provided sharding ``plan``. This method
    first shards all the parameters according to the given sharding ``plan``. Then, if
    ``output_plan`` and ``return_local_tensor`` are specified in the sharding ``plan``, it
    will tag the output of modules according to ``output_plan`` and convert the module's
    output back to data parallel according to ``return_local_tensor``.

    Needs to be called on all ranks in an SPMD fashion.

    Args:
        module (:class:`torch.nn.Module`): The module to apply sharding to.
        plan (:class:`torch.distributed._shard.sharding_plan.ShardingPlan`):
            The ShardingPlan which specifies the param name to ShardingSpec mapping
            to apply to each parameter.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the module that would be sharded and scattered across the rest
            of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
    """
    # Record Sharder paths for a sanity check on the plan, to ensure items in the plan
    # do not conflict with the submodule tree that the Sharder is working with.
    sharder_paths = []
    for name, spec in plan.plan.items():
        if isinstance(spec, Sharder):
            sharder_paths.append(name)

    # Shard the parameters according to the ShardingPlan.
    for name, spec in plan.plan.items():
        if isinstance(spec, ShardingSpec):
            # If we found a sharding spec, try to shard the parameter.
            module_path, _, param_name = name.rpartition(".")

            for sharder_path in sharder_paths:
                if module_path.startswith(sharder_path):
                    raise RuntimeError(f"ShardingPlan is invalid, trying to shard a parameter: {name},"
                                       f" but there's already a Sharder entry for module {sharder_path},"
                                       f" parameter sharding should not conflict with the submodule tree"
                                       f" that a Sharder is working with!")

            mod = module.get_submodule(module_path)
            shard_parameter(
                mod,
                param_name,
                spec,
                src_rank=src_rank,
                process_group=process_group
            )
        elif isinstance(spec, Sharder):
            parent_mod_path, _, mod_name = name.rpartition(".")
            if name == "":
                raise KeyError("Module path must not be empty for custom sharder!")
            mod = module.get_submodule(name)
            parent_mod = module.get_submodule(parent_mod_path)
            sharded_mod = spec.shard(mod)
            # Swap this submodule with the sharded module.
            setattr(parent_mod, mod_name, sharded_mod)
        else:
            raise TypeError(f"Only `ShardingSpec` and `Sharder` are supported to shard '{name}'")

    # Reshard the output if there's an entry in `output_plan` for this module.
    if plan.output_plan is not None:
        for module_path, output_spec in plan.output_plan.items():
            if isinstance(output_spec, ShardingSpec):
                mod = module.get_submodule(module_path)
                _reshard_output(mod, output_spec)
            else:
                raise TypeError(f"Only `ShardingSpec` is supported as output_plan for '{module_path}'")

    # Convert the output back to data parallel for the modules that appear in
    # `return_local_tensor` of the plan; we call `_collect_local_shard`
    # to collect the local tensor for the output of those modules.
    if plan.return_local_tensor is not None:
        for module_path in plan.return_local_tensor:
            mod = module.get_submodule(module_path)
            _collect_local_shard(mod)
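
# Illustrative sketch (not executed here): assuming two initialized ranks and a
# module with an ``fc`` Linear submodule created identically on every rank, a
# single plan can shard the weight, reshard the output of ``fc``, and return a
# local tensor. The ``ShardingPlan`` fields used here (``plan``, ``output_plan``,
# ``return_local_tensor``) are the ones read by ``shard_module`` above; ``model``
# is a placeholder.
#
#     spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
#     output_spec = ChunkShardingSpec(dim=-1, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
#     sharding_plan = ShardingPlan(
#         plan={"fc.weight": spec},
#         output_plan={"fc": output_spec},
#         return_local_tensor=["fc"],
#     )
#     shard_module(model, sharding_plan)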