# _utils.py

import functools
from typing import Callable, Optional, Union

import torch
from torch.distributed._tensor import DeviceMesh, DTensor

_PrepareInputType = Callable[
    [Union[torch.Tensor, DTensor], Optional[DeviceMesh], Optional[int]], DTensor
]

_PrepareOutputType = Callable[
    [DTensor, Optional[DeviceMesh], Optional[int]], Union[torch.Tensor, DTensor]
]
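
# Both aliases describe funcs of the form f(input, device_mesh, dim): a
# prepare-input func always returns a DTensor, while a prepare-output func
# may unwrap the result back into a plain torch.Tensor.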


def _prepare_input_validate(
    _prepare_input_func: _PrepareInputType,
) -> _PrepareInputType:
    """
    Inject common validation logic into ``_prepare_input`` funcs via this
    decorator: verify that the input is either a :class:`Tensor` or a
    :class:`DTensor` and that only a 1D :class:`DeviceMesh` is passed in.

    Args:
        _prepare_input_func (Callable): The func we want to inject the
            validation into.

    Returns:
        func (Callable): Same input function with validation logic added.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> @_prepare_input_validate
        >>> def make_input_shard_1d(args, kwargs):
        >>>     ...
        >>>
        >>> # xdoctest: +SKIP(failing)
        >>> input = torch.rand(...)
        >>> dtensor = make_input_shard_1d(input, device_mesh, 1)
        >>> # This will call '_prepare_input_validate' first
    """
    @functools.wraps(_prepare_input_func)
    def wrapper(*args, **kwargs):  # pyre-ignore[2, 3]
        assert len(args) >= 1, "_prepare_input needs at least one arg."
        # If the input comes wrapped in a list or tuple, validate its first
        # element.
        input = args[0]
        if isinstance(input, (list, tuple)):
            input = input[0]
            args = (input, *args[1:])

        # Infer the device mesh from a DTensor input when none is passed in.
        device_mesh = None if len(args) < 2 else args[1]
        if device_mesh is None:
            if isinstance(input, DTensor):
                device_mesh = input.device_mesh
                args = (*args[:1], device_mesh, *args[2:])  # pyre-ignore[60]
            else:
                raise RuntimeError(
                    "device_mesh is not passed nor can it be inferred"
                )
        if device_mesh.ndim != 1:
            raise RuntimeError(
                f"device_mesh has dims {device_mesh.ndim} but expected to be 1"
                " for input."
            )
        return _prepare_input_func(*args, **kwargs)

    return wrapper
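

# A minimal sketch of how the decorator composes with a prepare-input func.
# The function below is hypothetical (not the library's actual
# implementation); it assumes `Shard` and `distribute_tensor` from
# torch.distributed._tensor:
#
#   from torch.distributed._tensor import Shard, distribute_tensor
#
#   @_prepare_input_validate
#   def _shard_input_1d(input, device_mesh=None, dim=0):
#       # By the time this body runs, `device_mesh` is guaranteed to be a
#       # 1D DeviceMesh, inferred from `input` if it was not passed in.
#       if isinstance(input, DTensor):
#           return input
#       return distribute_tensor(input, device_mesh, [Shard(dim)])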


def _prepare_output_validate(
    _prepare_output_func: _PrepareOutputType,
) -> _PrepareOutputType:
    """
    Inject common validation logic into ``_prepare_output`` funcs via this
    decorator: verify that the output is a :class:`DTensor` and that only a
    1D :class:`DeviceMesh` is passed in.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> @_prepare_output_validate
        >>> def make_output_shard_1d(args, kwargs):
        >>>     ...
        >>>
        >>> # xdoctest: +SKIP(failing)
        >>> dt = distribute_tensor(tensor, device_mesh, [Shard(0)])
        >>> make_output_shard_1d(dt, device_mesh, 1)
        >>> # This will call '_prepare_output_validate' first

    Args:
        _prepare_output_func (Callable): The func we want to inject the
            validation into.

    Returns:
        func (Callable): Same input func with validation logic added.
    """
    @functools.wraps(_prepare_output_func)
    def wrapper(*args, **kwargs):  # pyre-ignore[2, 3]
        assert len(args) >= 1, "_prepare_output needs at least one arg."
        output = args[0]
        assert isinstance(output, DTensor), (
            "Expect output of Tensor Parallel to be a DTensor, but found"
            f" {type(output)}."
        )
        # Infer the device mesh from the DTensor output when none is passed in.
        if len(args) < 2 or args[1] is None:
            device_mesh = output.device_mesh
            args = (*args[:1], device_mesh, *args[2:])  # pyre-ignore[60]
        else:
            device_mesh = args[1]
        assert device_mesh.ndim == 1, (
            f"device_mesh has dims {device_mesh.ndim} but expected to be 1 for"
            " output."
        )
        return _prepare_output_func(*args, **kwargs)

    return wrapper


def _create_1d_device_mesh(device_mesh: DeviceMesh, tp_mesh_dim: int = 0) -> DeviceMesh:
    """
    Convert an N-D ``device_mesh`` into a 1D ``device_mesh`` for 1D Tensor
    Parallelism.

    Args:
        device_mesh (DeviceMesh):
            :class:`DeviceMesh` object which describes the mesh topology
            of devices for the DTensor.
        tp_mesh_dim (int):
            the dimension of ``device_mesh`` on which we perform
            Tensor Parallelism.

    Returns:
        device_mesh (DeviceMesh): 1D :class:`DeviceMesh` object on which
            Tensor Parallelism operates.
    """
    assert tp_mesh_dim < device_mesh.ndim and tp_mesh_dim >= -device_mesh.ndim, (
        f"Expect tp_mesh_dim within range [{-device_mesh.ndim},"
        f" {device_mesh.ndim}), but found {tp_mesh_dim}."
    )

    if device_mesh.ndim == 1:
        return device_mesh

    # Swap the current dim to the last dim, then reshape to flatten out the
    # other dims, so we can just extract the list of ranks containing cur_rank.
    cur_rank = device_mesh.get_rank()
    pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, tp_mesh_dim).reshape(
        -1, device_mesh.mesh.size(tp_mesh_dim)
    )
    # Select the single row that contains the current rank; that row is the
    # 1D mesh Tensor Parallelism runs on.
    dim_mesh_1d = pg_ranks_by_dim[torch.any(pg_ranks_by_dim == cur_rank, 1), :]
    sub_pg = device_mesh.get_dim_groups()[tp_mesh_dim]
    return DeviceMesh(device_mesh.device_type, dim_mesh_1d.squeeze(), [sub_pg])
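

# Illustrative example (a sketch, not part of the library): with 8 ranks laid
# out as a 2 x 4 mesh, `_create_1d_device_mesh(mesh, tp_mesh_dim=1)` returns
# the 1D row mesh that contains the current rank:
#
#   mesh = DeviceMesh("cuda", torch.arange(8).reshape(2, 4))
#   # mesh.mesh == [[0, 1, 2, 3],
#   #               [4, 5, 6, 7]]
#   tp_mesh = _create_1d_device_mesh(mesh, tp_mesh_dim=1)
#   # On rank 5, tp_mesh.mesh == tensor([4, 5, 6, 7]).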