# Copyright (c) Meta Platforms, Inc. and affiliates
from abc import ABC, abstractmethod
from typing import Optional, Union

import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.distributed.tensor.parallel._utils import (
    _prepare_input_validate,
    _prepare_output_validate,
    _PrepareInputType,
    _PrepareOutputType,
)

__all__ = [
    "ParallelStyle",
    "RowwiseParallel",
    "ColwiseParallel",
    "PairwiseParallel",
    "make_input_replicate_1d",
    "make_input_shard_1d",
    "make_input_shard_1d_last_dim",
    "make_output_replicate_1d",
    "make_output_tensor",
    "make_output_shard_1d",
]


class ParallelStyle(ABC):
    """
    The parallel style with which the user wants the module or submodule to be parallelized.
    Users can extend this class to build their own parallel style with customized input/output preparations.
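
    Example (a minimal sketch of a custom style; ``ShardLastDimParallel`` is a hypothetical name,
    and the prepare callables are the helpers defined later in this module)::

        >>> class ShardLastDimParallel(ParallelStyle):
        ...     def __init__(self) -> None:
        ...         super().__init__(make_input_shard_1d_last_dim, make_output_replicate_1d)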
    """

    _prepare_input: _PrepareInputType
    _prepare_output: _PrepareOutputType

    @abstractmethod
    def __init__(self, _prepare_input, _prepare_output) -> None:
        self._prepare_input = _prepare_input  # type: ignore[assignment, misc]
        self._prepare_output = _prepare_output  # type: ignore[assignment, misc]


class PairwiseParallel(ParallelStyle):
    """
    PairwiseParallel concatenates colwise and rowwise styles as a fixed
    pair, like what Megatron-LM (https://arxiv.org/abs/1909.08053) does.
    We assume both input and output need to be replicated DTensors.

    .. warning::
        PairwiseParallel only supports ``nn.MultiheadAttention``,
        ``nn.Transformer`` or even-number-layer MLP for now.
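
    Example (illustrative sketch only; assumes the default process group is initialized over 4 ranks
    and that ``parallelize_module`` from ``torch.distributed.tensor.parallel`` is available)::

        >>> import torch.nn as nn
        >>> from torch.distributed._tensor import DeviceMesh
        >>> from torch.distributed.tensor.parallel import parallelize_module
        >>> mesh = DeviceMesh("cuda", list(range(4)))
        >>> # a two-layer (even-number-layer) MLP
        >>> mlp = nn.Sequential(nn.Linear(16, 64), nn.Linear(64, 16))
        >>> mlp = parallelize_module(mlp, mesh, PairwiseParallel())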
    """

    def __init__(self) -> None:
        super().__init__(make_input_replicate_1d, make_output_tensor)


class RowwiseParallel(ParallelStyle):
    """
    Partition the rows of a module.
    We assume the input to be a sharded :class:`DTensor` and the output to be a replicated :class:`DTensor`.
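
    Example (sketch; ``model``, ``mesh``, and the submodule name ``"net2"`` are hypothetical, and
    ``parallelize_module`` is assumed to come from ``torch.distributed.tensor.parallel``)::

        >>> from torch.distributed.tensor.parallel import parallelize_module
        >>> model = parallelize_module(model, mesh, {"net2": RowwiseParallel()})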
    """

    def __init__(self) -> None:
        super().__init__(make_input_shard_1d_last_dim, make_output_replicate_1d)


class ColwiseParallel(ParallelStyle):
    """
    Partition the columns of a tensor or module.
    We assume the input to be a replicated :class:`DTensor` and the output to be a sharded :class:`DTensor`.
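
    Example (sketch; pairs a column-wise linear with a row-wise one, Megatron-LM style;
    ``model``, ``mesh``, ``"net1"`` and ``"net2"`` are hypothetical names, and ``parallelize_module``
    is assumed to come from ``torch.distributed.tensor.parallel``)::

        >>> from torch.distributed.tensor.parallel import parallelize_module
        >>> plan = {"net1": ColwiseParallel(), "net2": RowwiseParallel()}
        >>> model = parallelize_module(model, mesh, plan)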
    """

    def __init__(self) -> None:
        super().__init__(make_input_replicate_1d, make_output_replicate_1d)


@_prepare_input_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_input_shard_1d(
    input: Union[torch.Tensor, DTensor],
    device_mesh: Optional[DeviceMesh] = None,
    dim: int = 0,
) -> DTensor:
    """
    Shard the input tensor on ``dim`` over a 1-D device mesh. This function will be used in ParallelStyle.

    Args:
        input (Union[:class:`torch.Tensor`, :class:`DTensor`]):
            Single tensor will be sharded on dimension ``dim``
            over the 1-D :class:`DeviceMesh`.
        device_mesh (:class:`DeviceMesh`, optional):
            The 1-D device mesh where ``input`` will be sharded.
            If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`,
            ``input.device_mesh`` will be used.
            If the :class:`DeviceMesh` is not 1-D, an exception will be thrown.
            Default: ``None``
        dim (int, optional): The sharding dimension of the ``input`` tensor.
            Default: 0

    Returns:
        A :class:`DTensor` sharded on dimension ``dim`` over ``device_mesh``.
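
    Example (a minimal sketch; assumes the default process group is initialized over 4 ranks)::

        >>> mesh = DeviceMesh("cuda", list(range(4)))
        >>> local_tensor = torch.randn(8, 16)
        >>> # each rank passes its local shard; the result is a DTensor with a Shard(0) placement
        >>> dtensor = make_input_shard_1d(local_tensor, mesh, dim=0)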
    """
    shard_spec = [Shard(dim)]
    if isinstance(input, DTensor):
        return input.redistribute(device_mesh, shard_spec)
    elif isinstance(input, torch.Tensor):
        return DTensor.from_local(input, device_mesh, shard_spec, run_check=False)
    else:
        raise RuntimeError(
            "Tensor parallel module expects torch.Tensor or DTensor input but"
            f" received {type(input)}!"
        )


def make_input_shard_1d_last_dim(
    input: Union[torch.Tensor, DTensor],
    device_mesh: Optional[DeviceMesh] = None,
) -> DTensor:
    """
    Wrapper around ``make_input_shard_1d`` with ``dim`` = -1.

    Args:
        input (Union[:class:`torch.Tensor`, :class:`DTensor`]):
            This single tensor will be sharded on its last dimension
            over the 1-D :class:`DeviceMesh`.
        device_mesh (:class:`DeviceMesh`, optional):
            The 1-D device mesh where ``input`` will be sharded.
            If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`,
            ``input.device_mesh`` will be used.
            If the :class:`DeviceMesh` is not 1-D, an exception will be thrown.
            Default: ``None``

    Returns:
        A :class:`DTensor` sharded on its last dimension over ``device_mesh``.
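
    Example (sketch; assumes the default process group is initialized over 4 ranks)::

        >>> mesh = DeviceMesh("cuda", list(range(4)))
        >>> dtensor = make_input_shard_1d_last_dim(torch.randn(8, 16), mesh)  # same as dim=-1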
    """
    return make_input_shard_1d(input, device_mesh, dim=-1)  # type: ignore[call-arg]


@_prepare_input_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_input_replicate_1d(
    input: Union[torch.Tensor, DTensor],
    device_mesh: Optional[DeviceMesh] = None,
) -> DTensor:
    """
    Replicate the input tensor over a 1-D device mesh. This function will be used in ParallelStyle.

    Args:
        input (Union[:class:`torch.Tensor`, :class:`DTensor`]):
            This input tensor will be replicated over the 1-D :class:`DeviceMesh`.
        device_mesh (:class:`DeviceMesh`, optional):
            The 1-D device mesh where ``input`` will be replicated.
            If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`,
            ``input.device_mesh`` will be used.
            If the :class:`DeviceMesh` is not 1-D, an exception will be thrown.
            Default: ``None``

    Returns:
        A :class:`DTensor` replicated over ``device_mesh``.
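
    Example (a minimal sketch; assumes the default process group is initialized over 4 ranks)::

        >>> mesh = DeviceMesh("cuda", list(range(4)))
        >>> # every rank is assumed to pass the same tensor; the result has a Replicate() placement
        >>> dtensor = make_input_replicate_1d(torch.randn(8, 16), mesh)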
    """
    replicate = [Replicate()]
    if isinstance(input, DTensor):
        return input.redistribute(device_mesh, replicate)
    elif isinstance(input, torch.Tensor):
        return DTensor.from_local(input, device_mesh, replicate, run_check=False)
    else:
        raise RuntimeError(
            "Tensor parallel module expects torch.Tensor or DTensor input but"
            f" received {type(input)}!"
        )


@_prepare_output_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_output_shard_1d(
    output: DTensor, device_mesh: Optional[DeviceMesh] = None, dim: int = 0
) -> DTensor:
    """
    Convert the output DTensor to a sharded DTensor. This will be used in ParallelStyle.

    Args:
        output (:class:`DTensor`):
            Output of the module to be converted.
        device_mesh (:class:`DeviceMesh`, optional):
            The device mesh used to shard the output. It needs to be a 1-D ``device_mesh``,
            and we will throw an exception if a non-1-D ``device_mesh`` is passed in.
            If no ``device_mesh`` is passed in, we will reuse the one from ``output``.
            Default: ``None``
        dim (int): Sharding dim for the output. Default: 0

    Returns:
        A :class:`DTensor` object sharded on the given dim.
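
    Example (sketch; assumes the default process group is initialized over 4 ranks)::

        >>> mesh = DeviceMesh("cuda", list(range(4)))
        >>> replicated = DTensor.from_local(torch.randn(8, 16), mesh, [Replicate()], run_check=False)
        >>> sharded = make_output_shard_1d(replicated, mesh, dim=0)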
    """
    return output.redistribute(device_mesh, [Shard(dim)])


@_prepare_output_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_output_replicate_1d(
    output: DTensor, device_mesh: Optional[DeviceMesh] = None
) -> DTensor:
    """
    Convert the output DTensor to a replicated DTensor. This will be used in ParallelStyle.

    Args:
        output (:class:`DTensor`):
            Output of the module to be converted.
        device_mesh (:class:`DeviceMesh`, optional):
            The device mesh used to replicate the output. It needs to be a 1-D ``device_mesh``,
            and we will throw an exception if a non-1-D ``device_mesh`` is passed in.
            If no ``device_mesh`` is passed in, we will reuse the one from ``output``.
            Default: ``None``

    Returns:
        A :class:`DTensor` object made replicated.
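
    Example (sketch; assumes the default process group is initialized over 4 ranks)::

        >>> mesh = DeviceMesh("cuda", list(range(4)))
        >>> sharded = DTensor.from_local(torch.randn(8, 16), mesh, [Shard(0)], run_check=False)
        >>> # redistributing from Shard(0) to Replicate() all-gathers the shards on every rank
        >>> replicated = make_output_replicate_1d(sharded, mesh)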
    """
    return output.redistribute(device_mesh, [Replicate()])


@_prepare_output_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_output_tensor(
    output: DTensor, device_mesh: Optional[DeviceMesh] = None
) -> torch.Tensor:
    """
    Convert the output DTensor to a replicated DTensor first, and then convert it to a plain Tensor.

    Args:
        output (:class:`DTensor`):
            Output of the module to be converted.
        device_mesh (:class:`DeviceMesh`, optional):
            The device mesh used to replicate the output. It needs to be a 1-D
            ``device_mesh``, and we will throw an exception if a non-1-D
            ``device_mesh`` is passed in. If no ``device_mesh`` is passed in,
            we will reuse the one from ``output``.
            Default: ``None``

    Returns:
        A :class:`torch.Tensor` object converted from the output DTensor.
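
    Example (sketch; assumes the default process group is initialized over 4 ranks)::

        >>> mesh = DeviceMesh("cuda", list(range(4)))
        >>> sharded = DTensor.from_local(torch.randn(8, 16), mesh, [Shard(0)], run_check=False)
        >>> # replicate, then drop the DTensor wrapper to get a local torch.Tensor
        >>> plain = make_output_tensor(sharded, mesh)
        >>> isinstance(plain, torch.Tensor)
        True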
    """
    return make_output_replicate_1d(  # type: ignore[attr-defined]
        output, device_mesh
    ).to_local()  # type: ignore[call-arg]