api.py

# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict, Union

import torch
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
    DTensor,
    distribute_module,
    distribute_tensor,
    Replicate,
    Shard,
)
from torch.distributed._tensor.sharding_prop import _CachingPropagator
from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh
from torch.distributed.tensor.parallel.multihead_attention_tp import (
    TensorParallelMultiheadAttention,
)
from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,
    PairwiseParallel,
    ParallelStyle,
    RowwiseParallel,
)

__all__ = [
    "parallelize_module",
]

# Switch the DTensor propagator to use the caching propagator to speed up
# the TP eager execution time.
DTensor._propagator = _CachingPropagator(DTensor._propagator.op_to_rules)


def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    The API to apply Tensor Parallelism (TP) in PyTorch. We parallelize the module
    or sub_modules based on a parallelize_plan. The parallelize_plan contains
    :class:`ParallelStyle`, which indicates how the user wants the module or sub_module
    to be parallelized.
    Users can also specify different parallel styles per module fully qualified name (FQN).
    The API supports 2D parallelism natively by accepting an n-dimensional device_mesh;
    users just need to specify the dimension on which to perform tensor parallelism.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
            of devices for the DTensor.
        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]):
            The plan used to parallelize the module. It can be either a
            :class:`ParallelStyle` object which contains how
            we prepare input/output for Tensor Parallelism or it can be a
            dict of module FQNs and their corresponding :class:`ParallelStyle` objects.
        tp_mesh_dim (int):
            The dimension of ``device_mesh`` on which we perform
            Tensor Parallelism.

    Return:
        A :class:`nn.Module` object parallelized.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
        >>>
        >>> # Define the module and parallelize it over a pre-built ``device_mesh``.
        >>> m = Model(...)
        >>> m = parallelize_module(m, device_mesh, PairwiseParallel())
        >>>
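
    The plan can also be a dict keyed by module FQN. The snippet below is an
    illustrative sketch only: the submodule names ``"net1"`` and ``"net2"`` are
    placeholders for whatever linear submodules the model actually defines.

        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel.style import ColwiseParallel, RowwiseParallel
        >>> m = Model(...)
        >>> m = parallelize_module(
        ...     m, device_mesh, {"net1": ColwiseParallel(), "net2": RowwiseParallel()}
        ... )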

    .. warning::
        ``PairwiseParallel`` comes with constraints for now. If you need finer
        granularity, you need to pass in a dict of module FQN and parallel style instead.
    """
    if device_mesh.ndim > 1:
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)

    if isinstance(parallelize_plan, ParallelStyle):
        # RowwiseParallel or ColwiseParallel
        if isinstance(parallelize_plan, (ColwiseParallel, RowwiseParallel)):
            return _parallelize_linear(module, device_mesh, parallelize_plan)
        # PairwiseParallel
        if _is_mha_for_pairwise_parallel(module):
            return _parallelize_multihead_attn(module, device_mesh)
        elif _is_mlp_for_pairwise_parallel(module):
            return _parallelize_mlp(module, device_mesh)
        else:
            for n, m in module.named_children():
                module.register_module(
                    n, parallelize_module(m, device_mesh, parallelize_plan)
                )
            return module
    elif isinstance(parallelize_plan, dict):
        for module_path, parallelize_style in parallelize_plan.items():
            sub_module = module.get_submodule(module_path)
            parent_module = module
            if "." in module_path:
                parent_module_path = ".".join(module_path.split(".")[:-1])
                parent_module = module.get_submodule(parent_module_path)
                module_path = module_path.split(".")[-1]
            parent_module.register_module(  # type: ignore[call-arg] # pyre-ignore[20]
                module_path,
                parallelize_module(  # type: ignore[arg-type]
                    sub_module, device_mesh, parallelize_style  # type: ignore[arg-type] # pyre-ignore[6]
                ),
            )
        return module
    else:
        raise RuntimeError(  # pyre-ignore[7]
            "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
            f" parallelize_plan, {type(parallelize_plan)} found!"
        )


def _is_mha_for_pairwise_parallel(module: nn.Module) -> bool:
    """
    Check whether the given module is a multihead attention module that
    Pairwise parallel can handle.

    Args:
        module (:class:`nn.Module`):
            Module to be checked.

    Return:
        A boolean which specifies whether the module is an MHA supported by Pairwise parallel or not.
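
    For illustration, both ``nn.MultiheadAttention`` and the TP-specific
    ``TensorParallelMultiheadAttention`` qualify:

        >>> import torch.nn as nn
        >>> _is_mha_for_pairwise_parallel(nn.MultiheadAttention(embed_dim=16, num_heads=4))
        True
        >>> _is_mha_for_pairwise_parallel(nn.Linear(16, 16))
        False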
  117. """
  118. return isinstance(module, (TensorParallelMultiheadAttention, nn.MultiheadAttention))


def _is_mlp_for_pairwise_parallel(module: nn.Module) -> bool:
    """
    Traverse all the immediate children of the given module and count the
    number of :class:`nn.Linear` modules. If the number is more than one, we return True.

    Args:
        module (:class:`nn.Module`):
            Module to be traversed and counted.

    Return:
        A bool which specifies whether the module is a supported MLP or not.

    .. warning::
        The traversal is not recursive for now.
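
    For illustration, a module with at least two immediate ``nn.Linear``
    children qualifies:

        >>> import torch.nn as nn
        >>> mlp = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 8))
        >>> _is_mlp_for_pairwise_parallel(mlp)
        True
        >>> _is_mlp_for_pairwise_parallel(nn.Linear(8, 8))
        False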
  130. """
  131. linear_submodules = list(
  132. filter(lambda x: isinstance(x, nn.Linear), module.children())
  133. )
  134. return len(linear_submodules) > 1


def _rowwise_parallelize_linear_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    """
    This function parallelizes the input :class:`nn.Linear` module in
    :class:`RowwiseParallel` style.

    Args:
        name (str):
            Name of the input module.
        module (:class:`nn.Module`):
            The :class:`nn.Linear` module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.

    Returns:
        None
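
    The following is an illustrative sketch only; it assumes ``device_mesh``
    is an existing 1-D :class:`DeviceMesh` and passes this function as the
    ``partition_fn`` of :func:`distribute_module`:

        >>> # xdoctest: +SKIP("distributed")
        >>> linear = nn.Linear(16, 16)
        >>> distribute_module(linear, device_mesh, _rowwise_parallelize_linear_fn)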
  152. """
  153. for name, param in module.named_parameters():
  154. dist_spec = (
  155. [Shard(1)] if name == "weight" else [Replicate()] # type: ignore[list-item]
  156. )
  157. dist_param = torch.nn.Parameter(
  158. distribute_tensor(param, device_mesh, dist_spec)
  159. )
  160. module.register_parameter(name, dist_param)


def _colwise_parallelize_linear_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    """
    This function parallelizes the input :class:`nn.Linear` module in
    :class:`ColwiseParallel` style.

    Args:
        name (str):
            Name of the input module.
        module (:class:`nn.Module`):
            The :class:`nn.Linear` module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.

    Returns:
        None
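
    As with the row-wise helper above, an illustrative sketch only
    (``device_mesh`` is assumed to be an existing 1-D :class:`DeviceMesh`):

        >>> # xdoctest: +SKIP("distributed")
        >>> linear = nn.Linear(16, 16)
        >>> distribute_module(linear, device_mesh, _colwise_parallelize_linear_fn)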
  178. """
  179. for name, param in module.named_parameters():
  180. dist_param = torch.nn.Parameter(
  181. distribute_tensor(param, device_mesh, [Shard(0)])
  182. )
  183. module.register_parameter(name, dist_param)


def _parallelize_linear(
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallel_style: ParallelStyle = ColwiseParallel(),
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    This function requires that the input module be an object
    of :class:`nn.Linear`.
    The module will be parallelized over a 1-d :class:`DeviceMesh`
    based on the :class:`ParallelStyle`.

    Args:
        module (:class:`nn.Module`):
            The module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices for the :class:`DTensor`.
            If the mesh is more than 1-dimensional, we will use the mesh dim of
            ``device_mesh`` specified by ``tp_mesh_dim``.
        parallel_style (:class:`ParallelStyle`, optional):
            The object which describes how the :class:`nn.Linear` module
            should be distributed over the :class:`DeviceMesh` and how the input
            and output should be prepared for Tensor Parallelism.
            :class:`RowwiseParallel`: weight is sharded on dim 1 and bias is
            replicated.
            :class:`ColwiseParallel`: weight and bias are both sharded on dim 0.
            Default: :class:`ColwiseParallel`
        tp_mesh_dim (int):
            The dimension of :class:`DeviceMesh` on which we
            perform Tensor Parallelism.
            Default: 0

    Return:
        A :class:`nn.Module` object parallelized.
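
    The following is an illustrative sketch only; ``device_mesh`` is assumed
    to be an existing :class:`DeviceMesh`:

        >>> # xdoctest: +SKIP("distributed")
        >>> linear = nn.Linear(16, 32)
        >>> linear = _parallelize_linear(linear, device_mesh, ColwiseParallel())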
  216. """
  217. if not isinstance(module, nn.Linear):
  218. raise RuntimeError(
  219. f"Expect a torch.nn.Linear module but received {type(module)}!"
  220. )
  221. if not isinstance(parallel_style, ParallelStyle):
  222. raise RuntimeError(
  223. "Expect a ParallelStyle object but received" f" {type(parallel_style)}!"
  224. )
  225. if device_mesh.ndim > 1:
  226. device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)
  227. if isinstance(parallel_style, RowwiseParallel):
  228. distribute_module(
  229. module,
  230. device_mesh,
  231. _rowwise_parallelize_linear_fn,
  232. input_fn=parallel_style._prepare_input, # type: ignore[arg-type, misc] # pyre-ignore[6]
  233. output_fn=parallel_style._prepare_output, # type: ignore[arg-type, misc] # pyre-ignore[6]
  234. )
  235. elif isinstance(parallel_style, ColwiseParallel):
  236. distribute_module(
  237. module,
  238. device_mesh,
  239. _colwise_parallelize_linear_fn,
  240. input_fn=parallel_style._prepare_input, # type: ignore[arg-type, misc] # pyre-ignore[6]
  241. output_fn=parallel_style._prepare_output, # type: ignore[arg-type, misc] # pyre-ignore[6]
  242. )
  243. else:
  244. raise RuntimeError(f"{type(parallel_style)} is not supported!")
  245. return module


def _parallelize_multihead_attn(
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallel_style: ParallelStyle = PairwiseParallel(),
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    This function parallelizes a multihead attention module
    (:class:`nn.MultiheadAttention` or :class:`TensorParallelMultiheadAttention`)
    based on the given parallel style.
    We don't change the FQN of each sub-module and replace each parameter
    in place.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.
        parallel_style (:class:`ParallelStyle`):
            Object which contains how we prepare input/output
            for Tensor Parallelism.
        tp_mesh_dim (int):
            The dimension of ``device_mesh`` on which we perform
            Tensor Parallelism.

    Return:
        A :class:`nn.Module` object parallelized.

    .. warning::
        We only support ``PairwiseParallel`` right now.
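
    The following is an illustrative sketch only; ``device_mesh`` is assumed
    to be an existing :class:`DeviceMesh`:

        >>> # xdoctest: +SKIP("distributed")
        >>> mha = nn.MultiheadAttention(embed_dim=16, num_heads=4)
        >>> mha = _parallelize_multihead_attn(mha, device_mesh)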
  272. """
  273. if not isinstance(parallel_style, PairwiseParallel):
  274. raise NotImplementedError(
  275. "Only support PairwiseParallel for Multihead Attention" " parallelization."
  276. )
  277. if device_mesh.ndim > 1:
  278. device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)
  279. if isinstance(module, nn.MultiheadAttention):
  280. tp_multi_head_attention = TensorParallelMultiheadAttention(
  281. module.embed_dim,
  282. module.num_heads,
  283. device=torch.device(device_mesh.device_type),
  284. tp_size=device_mesh.size(tp_mesh_dim),
  285. add_bias_kv=module.bias_k is not None,
  286. )
  287. tp_multi_head_attention.copy(module)
  288. module = tp_multi_head_attention
  289. if isinstance(module, TensorParallelMultiheadAttention): # shard TPMA
  290. for n, m in module.named_children():
  291. if n == "qkv":
  292. # Col-wise Parallelize the qkv layer.
  293. distribute_module(
  294. m,
  295. device_mesh,
  296. _colwise_parallelize_linear_fn,
  297. input_fn=parallel_style._prepare_input, # type: ignore[arg-type, misc] # pyre-ignore[6]
  298. )
  299. elif n == "proj":
  300. # Row-wise Parallelize the proj layer
  301. distribute_module(
  302. m,
  303. device_mesh,
  304. _rowwise_parallelize_linear_fn,
  305. output_fn=parallel_style._prepare_output, # type: ignore[arg-type, misc] # pyre-ignore[6]
  306. )
  307. return module


def _parallelize_mlp(
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallel_style: ParallelStyle = PairwiseParallel(),
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    This function assumes the input module is a sequence of nn.Linear
    and we parallelize the module based on the given parallel style.
    We don't change the FQN of each sub-module and replace each parameter
    in place.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.
        parallel_style (:class:`ParallelStyle`):
            Object which contains how we prepare input/output
            for Tensor Parallelism.
        tp_mesh_dim (int):
            The dimension of ``device_mesh`` on which we perform
            Tensor Parallelism.

    Return:
        A :class:`nn.Module` object parallelized.

    .. warning::
        We only support ``PairwiseParallel`` right now.
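
    The following is an illustrative sketch only; ``device_mesh`` is assumed
    to be an existing :class:`DeviceMesh`:

        >>> # xdoctest: +SKIP("distributed")
        >>> mlp = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 8))
        >>> mlp = _parallelize_mlp(mlp, device_mesh)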
  334. """
  335. if not isinstance(parallel_style, PairwiseParallel):
  336. raise NotImplementedError(
  337. "Only support PairwiseParallel for MLP parallelization."
  338. )
  339. if not _is_mlp_for_pairwise_parallel(module):
  340. raise RuntimeError("More than one nn.Linear needed for a MLP.")
  341. if device_mesh.ndim > 1:
  342. device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)
  343. linear_submodules = list(
  344. filter(lambda x: isinstance(x, nn.Linear), module.children())
  345. )
  346. mlp_last_even_layer = (len(linear_submodules) // 2) * 2
  347. for i in range(mlp_last_even_layer):
  348. m = linear_submodules[i]
  349. if i % 2 == 0:
  350. # Col-wise Parallelize the linear layer
  351. distribute_module(
  352. m,
  353. device_mesh,
  354. _colwise_parallelize_linear_fn,
  355. input_fn=parallel_style._prepare_input # type: ignore[arg-type, misc] # pyre-ignore[6]
  356. if i == 0
  357. else None,
  358. )
  359. else:
  360. # Row-wise Parallelize the linear layer
  361. distribute_module(
  362. m,
  363. device_mesh,
  364. _rowwise_parallelize_linear_fn,
  365. output_fn=parallel_style._prepare_output # type: ignore[arg-type, misc] # pyre-ignore[6]
  366. if i == (mlp_last_even_layer - 1)
  367. else None,
  368. )
  369. return module