comm_tensor.py

from dataclasses import dataclass
from functools import partial
from typing import Any, List, Optional, Tuple

import torch
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental.proxy_tensor import (
    _ProxyTensor,
    get_innermost_proxy_mode,
    fetch_tensor_proxy,
    get_proxy_slot,
    set_proxy_slot,
    track_tensor_tree,
)
from torch.utils._mode_utils import no_dispatch
from torch.utils._pytree import (
    tree_flatten,
    tree_map,
    tree_map_only,
)


@dataclass
class _CommResult:
    # a custom type wrapping both inplace output tensor and work handle
    _tensor: torch.Tensor
    _work: torch.distributed._Work


def _wait_comm(comm_result: _CommResult):
    # This function is only used by tracing mode as a call_function node right
    # before consuming a collective result tensor.
    comm_result._work.wait()
    return comm_result._tensor


def _wrap_comm_result(result: Tuple[Any, Any]) -> Tuple[Any, Any]:
    def wrap(work, e):
        assert isinstance(e, torch.Tensor), (
            "Expecting collection of tensors as the first element in the "
            "return value of communication operations."
        )

        return _CommResult(e, work)

    # E.g.,
    # allreduce_ returns ([tensor], work)
    # allgather_ returns ([[tensor1, tensor2]], work)
    work = result[1]

    return (tree_map(partial(wrap, work), result[0]), work)
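

# For illustration (shapes and names are not from the original comments):
# wrapping the return value of an inplace allreduce_ turns
#     ([tensor], work)
# into
#     ([_CommResult(tensor, work)], work)
# so that downstream graph nodes consume the _CommResult, and therefore the
# work handle produced at execution time rather than one baked into the graph
# as a constant during tracing.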


def _get_tracer() -> Optional[torch.fx.Tracer]:
    mode = get_innermost_proxy_mode()
    if mode is None:
        return None
    return mode.tracer


class CommTensor(torch.Tensor):
    r"""
    A Tensor subclass to wrap input tensors for collective communications. This
    Tensor subclass works for both eager and tracing mode.

    In eager mode, it will record whether the inplace collective communication
    has been launched using this Tensor and remember the corresponding work
    handle. If yes, it will explicitly call wait() in the ``__torch_dispatch__``
    function before subsequent operations consume the value of the Tensor.

    In tracing mode, ``CommTensor`` inserts two nodes into the graph using the
    ``__torch_dispatch__`` function.

    1. The first node is inserted right after the communication, wrapping both
    the inplace output tensor and the returned work handle into a custom
    ``_CommResult`` type. We have to do this because ``ProxyTorchDispatchMode``
    only handles ``torch.Tensor``, ``_ProxyTensor``, and ``torch.nn.Parameter``
    objects, and would treat the work handle as a constant and embed it into
    the graph. As a result, during execution it would use the work handle
    created during tracing, leading to wrong results. The solution in this test
    is to manually create a proxy on the return value of ``allreduce_``, which
    is ``([tensor], work)``, and wrap that into
    ``([_CommResult(tensor, work)], work)``. In this way, subsequent nodes can
    directly consume ``_CommResult``.
    2. The second node is inserted right before any subsequent node reads from
    ``_CommResult``. It will call ``wait()`` on the stashed work handle to
    ensure that computation waits for communication.

    A minimal usage sketch is included at the bottom of this file.
    """

    _supported_comms: List[str] = [
        "_allgather_base_",
        "_reduce_scatter_base_",
        "allreduce_",
        "allgather_",
        "alltoall_",
        "broadcast_",
        "reduce_scatter_",
        "scatter_",
    ]

    _tensor: torch.Tensor
    _work: Optional[torch.distributed._Work]

    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        t = tensor._tensor if isinstance(tensor, CommTensor) else tensor
        if get_innermost_proxy_mode() is None:
            # noop for eager mode
            return tensor

        # Use the non-CommTensor to avoid nested CommTensor wrapping.
        r = torch.Tensor._make_subclass(cls, t, require_grad=t.requires_grad)
        # The tensor object wrapped by this CommTensor
        # NB: THIS CAN BE A CommTensor; see test_nested_comm_tensor_wrapping
        r._tensor = tensor  # type: ignore[attr-defined]
        # Record the LAST `work` object returned by collective communication
        # operations. If this is None, it means no collectives have been called
        # since the last time a tensor was wrapped by CommTensor.
        r._work = None  # type: ignore[attr-defined]
        return r

    def __repr__(self):
        return f"CommTensor({self._tensor}, work={self._work})"

    # disable __torch_function__ so that CommTensor can recursively dispatch
    # with ProxyTorchDispatchMode in make_fx
    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def _is_supported(cls, op_name):
        # substring match against the known inplace collective op names
        return any(comm in op_name for comm in cls._supported_comms)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # shared states when unwrapping args
        tracer: Optional[torch.fx.Tracer] = None
        work: Optional[torch.distributed._Work] = None

        # unwrap to ._tensor if this is a CommTensor, and insert/call wait()
        # if communication has been launched on this tensor.
        def unwrap(e: Any):
            if isinstance(e, CommTensor):
                nonlocal tracer, work
                work = e._work
                # TODO(ezyang): I don't really understand what's going on
                # here, but it seems that tracer doesn't reflect whether or
                # not there is ambient tracing going on, but rather, whether
                # or not we will trace THIS particular invocation. If we
                # have a nested CommTensor, the outer layer doesn't actually
                # trace and we only trace the inner layer.
                if not isinstance(e._tensor, CommTensor):
                    tracer = _get_tracer()

                if work is not None:
                    if tracer is not None:
                        # insert a node into the traced graph.
                        proxy_res = tracer.create_proxy(  # type: ignore[union-attr]
                            "call_function",
                            _wait_comm,
                            (get_proxy_slot(e._tensor, tracer).proxy,),
                            {},
                            name="wait_comm",
                        )
                        # HACK: update the proxy for the inplace output
                        set_proxy_slot(e._tensor, tracer, proxy_res)
                    # For eager mode, simply wait.
                    # During tracing, still need to wait here, to make sure the
                    # execution during tracing is correct.
                    work.wait()

                # communication has been waited on, stop propagating CommTensor
                return e._tensor
            else:
                return e

        def wrap(e: Any):
            return CommTensor(e) if isinstance(e, torch.Tensor) else e

        def set_work(work: torch.distributed._Work, e: Any):
            if isinstance(e, CommTensor):
                e._work = work  # type: ignore[attr-defined]
            elif isinstance(e, torch.Tensor):
                raise RuntimeError(
                    "Type of output tensors from collective communication during "
                    "tracing should always be CommTensor instead of torch.Tensor"
                )
            return e

        unwrapped_args = tree_map(unwrap, args)
        unwrapped_kwargs = tree_map(unwrap, kwargs)

        if cls._is_supported(func.__name__):
            if tracer is not None:
                # in tracing mode, get proxies for args
                proxy_args, proxy_kwargs = tree_map_only(
                    _ProxyTensor,
                    lambda e: e.proxy,
                    tree_map_only(
                        torch.Tensor,
                        fetch_tensor_proxy(tracer),
                        (unwrapped_args, unwrapped_kwargs),
                    ),
                )

                # get proxy for output tuple
                proxy_res = func(*proxy_args, **proxy_kwargs)
                assert isinstance(proxy_res, torch.fx.Proxy)
                # insert a node that wraps the output tuple into
                # _CommResult(tensor, work)
                comm_result_proxy = tracer.create_proxy(  # type: ignore[union-attr]
                    "call_function",
                    _wrap_comm_result,
                    (proxy_res,),
                    {},
                    name="comm_result",
                )

                with no_dispatch():
                    # disable dispatch to avoid triggering ProxyTorchDispatchMode logic
                    out = func(*unwrapped_args, **unwrapped_kwargs)

                # wrap output with the proxy of _CommResult, so that subsequent
                # ops can link to it.
                track_tensor_tree(out, comm_result_proxy, constant=None, tracer=tracer)

                # N.B.: we still need to remember the work handle here, and wait
                # for it later to make sure the execution during tracing is
                # correct. Also, remember that comm is already launched.
                # args[0] is always the collection of output tensors.
                tree_map(partial(set_work, out[1]), args[0])

                # HACK: update the proxy on the input argument as this is an
                # inplace collective communication.
                flat_args, args_spec = tree_flatten(unwrapped_args[0])
                flat_out, out_spec = tree_flatten(out[0])
                for a, o in zip(flat_args, flat_out):
                    set_proxy_slot(a, tracer, get_proxy_slot(o, tracer))

                return out
            else:
                # in eager mode, simply remember work handle as an attribute
                out = func(*unwrapped_args, **unwrapped_kwargs)
                tree_map(partial(set_work, out[1]), args[0])
                return out
        else:
            if work is not None:
                return func(*unwrapped_args, **unwrapped_kwargs)
            else:
                # we need to propagate CommTensor wrapping until the first
                # subsequent operation has waited for it.
                return tree_map(wrap, func(*unwrapped_args, **unwrapped_kwargs))
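

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original module. It assumes the
    # "gloo" backend is available and spins up a single-rank process group so
    # it can run on one machine; the function name `allreduce_then_scale` and
    # the address/port choice are purely illustrative.
    import os

    import torch.distributed as dist
    from torch.fx.experimental.proxy_tensor import make_fx

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    def allreduce_then_scale(x: torch.Tensor) -> torch.Tensor:
        # Wrap the input so the work handle of the inplace collective is
        # recorded (eager mode) or traced (under make_fx).
        y = CommTensor(x.clone())
        dist.all_reduce(y)
        # The multiplication is the first consumer of the collective output:
        # __torch_dispatch__ waits on the stashed work handle in eager mode,
        # or inserts a wait_comm node in tracing mode, before it runs.
        return y * 2

    # Eager mode: CommTensor.__new__ is a no-op wrapper here.
    print(allreduce_then_scale(torch.ones(4)))

    # Tracing mode: the captured graph should contain comm_result and
    # wait_comm call_function nodes around the collective.
    traced = make_fx(allreduce_then_scale)(torch.ones(4))
    print(traced.graph)

    dist.destroy_process_group()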