from dataclasses import dataclass
from functools import partial
from typing import Any, List, Optional, Tuple

import torch
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental.proxy_tensor import (
    _ProxyTensor,
    get_innermost_proxy_mode,
    fetch_tensor_proxy,
    get_proxy_slot,
    set_proxy_slot,
    track_tensor_tree,
)
from torch.utils._mode_utils import no_dispatch
from torch.utils._pytree import (
    tree_flatten,
    tree_map,
    tree_map_only,
)


@dataclass
class _CommResult:
    # a custom type wrapping both inplace output tensor and work handle
    _tensor: torch.Tensor
    _work: torch.distributed._Work


def _wait_comm(comm_result: _CommResult):
    # This function is only used by tracing mode as a call_function node right
    # before consuming a collective result tensor.
    comm_result._work.wait()
    return comm_result._tensor


def _wrap_comm_result(result: Tuple[Any, Any]) -> Tuple[Any, Any]:
    def wrap(work, e):
        assert isinstance(e, torch.Tensor), (
            "Expecting a collection of tensors as the first element in the "
            "return value of communication operations."
        )

        return _CommResult(e, work)

    # E.g.,
    # allreduce_ returns ([tensor], work)
    # allgather_ returns ([[tensor1, tensor2]], work)
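    # After wrapping (illustrative only), those results become
    #   ([_CommResult(tensor, work)], work)
    #   ([[_CommResult(tensor1, work), _CommResult(tensor2, work)]], work)
    # so that downstream graph nodes consume _CommResult instead of the raw
    # inplace-mutated tensors.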
    work = result[1]

    return (tree_map(partial(wrap, work), result[0]), work)


def _get_tracer() -> Optional[torch.fx.Tracer]:
    mode = get_innermost_proxy_mode()
    if mode is None:
        return None
    return mode.tracer


class CommTensor(torch.Tensor):
    r"""
    A Tensor subclass to wrap input tensors for collective communications. This
    Tensor subclass works in both eager and tracing mode.

    In eager mode, it records whether an inplace collective communication has
    been launched using this Tensor and remembers the corresponding work
    handle. If so, it explicitly calls wait() in the ``__torch_dispatch__``
    function before subsequent operations consume the value of the Tensor.

    In tracing mode, ``CommTensor`` inserts two nodes into the graph using the
    ``__torch_dispatch__`` function.

    1. The first node is inserted right after the communication, wrapping both
    the inplace output tensor and the returned work handle into a custom
    ``_CommResult`` type. We have to do this because ``ProxyTorchDispatchMode``
    only handles ``torch.Tensor``, ``_ProxyTensor``, and ``torch.nn.Parameter``
    objects, and would treat the work handle as a constant and embed it into
    the graph. As a result, during execution the graph would reuse the work
    handle created during tracing, leading to wrong results. The solution here
    is to manually create a proxy on the return value of ``allreduce_``, which
    is ``([tensor], work)``, and wrap that into
    ``([_CommResult(tensor, work)], work)``. In this way, subsequent nodes can
    directly consume ``_CommResult``.

    2. The second node is inserted right before any subsequent node reads from
    ``_CommResult``. It calls ``wait()`` on the stashed work handle to ensure
    that computation waits for communication.
    """

    _supported_comms: List[str] = [
        "_allgather_base_",
        "_reduce_scatter_base_",
        "allreduce_",
        "allgather_",
        "alltoall_",
        "broadcast_",
        "reduce_scatter_",
        "scatter_",
    ]

    _tensor: torch.Tensor
    _work: Optional[torch.distributed._Work]

    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        t = tensor._tensor if isinstance(tensor, CommTensor) else tensor

        if get_innermost_proxy_mode() is None:
            # noop for eager mode
            return tensor

        # Use the non-CommTensor to avoid nested CommTensor wrapping
        r = torch.Tensor._make_subclass(cls, t, require_grad=t.requires_grad)
        # The tensor object wrapped by this CommTensor
        # NB: THIS CAN BE A CommTensor; see test_nested_comm_tensor_wrapping
        r._tensor = tensor  # type: ignore[attr-defined]
        # Record the LAST `work` object returned by collective communication
        # operations. If this is None, it means no collective has been called
        # since the last time a tensor was wrapped by CommTensor.
        r._work = None  # type: ignore[attr-defined]
        return r

    def __repr__(self):
        return f"CommTensor({self._tensor}, work={self._work})"

    # disable __torch_function__ so that CommTensor can recursively dispatch
    # with ProxyTorchDispatchMode in make_fx
    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def _is_supported(cls, op_name):
        return any(comm in op_name for comm in cls._supported_comms)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # shared states when unwrapping args
        tracer: Optional[torch.fx.Tracer] = None
        work: Optional[torch.distributed._Work] = None

        # unwrap to ._tensor if this is a CommTensor, and insert/call wait()
        # if communication has been launched on this tensor.
        def unwrap(e: Any):
            if isinstance(e, CommTensor):
                nonlocal tracer, work

                work = e._work
                # TODO(ezyang): I don't really understand what's going on
                # here, but it seems that tracer doesn't reflect whether or
                # not there is ambient tracing going on, but rather, whether
                # or not we will trace THIS particular invocation. If we
                # have a nested CommTensor, the outer layer doesn't actually
                # trace and we only trace the inner layer.
                if not isinstance(e._tensor, CommTensor):
                    tracer = _get_tracer()

                if work is not None:
                    if tracer is not None:
                        # insert a wait_comm node into the traced graph.
                        proxy_res = tracer.create_proxy(  # type: ignore[union-attr]
                            "call_function",
                            _wait_comm,
                            (get_proxy_slot(e._tensor, tracer).proxy,),
                            {},
                            name="wait_comm",
                        )
                        # HACK: update the proxy for the inplace output
                        set_proxy_slot(e._tensor, tracer, proxy_res)
                    # For eager mode, simply wait.
                    # During tracing, we still need to wait here to make sure
                    # the execution during tracing is correct.
                    work.wait()

                # communication has been waited on, stop propagating CommTensor
                return e._tensor
            else:
                return e

        def wrap(e: Any):
            return CommTensor(e) if isinstance(e, torch.Tensor) else e

        def set_work(work: torch.distributed._Work, e: Any):
            if isinstance(e, CommTensor):
                e._work = work  # type: ignore[attr-defined]
            elif isinstance(e, torch.Tensor):
                raise RuntimeError(
                    "Type of output tensors from collective communication during "
                    "tracing should always be CommTensor instead of torch.Tensor"
                )
            return e

        unwrapped_args = tree_map(unwrap, args)
        unwrapped_kwargs = tree_map(unwrap, kwargs)

        if cls._is_supported(func.__name__):
            if tracer is not None:
                # in tracing mode, get proxies for args
                proxy_args, proxy_kwargs = tree_map_only(
                    _ProxyTensor,
                    lambda e: e.proxy,
                    tree_map_only(
                        torch.Tensor,
                        fetch_tensor_proxy(tracer),
                        (unwrapped_args, unwrapped_kwargs),
                    ),
                )

                # get proxy for output tuple
                proxy_res = func(*proxy_args, **proxy_kwargs)
                assert isinstance(proxy_res, torch.fx.Proxy)
                # insert a node that wraps the output tuple into
                # _CommResult(tensor, work)
                comm_result_proxy = tracer.create_proxy(  # type: ignore[union-attr]
                    "call_function",
                    _wrap_comm_result,
                    (proxy_res,),
                    {},
                    name="comm_result",
                )

                with no_dispatch():
                    # disable dispatch to avoid triggering ProxyTorchDispatchMode logic
                    out = func(*unwrapped_args, **unwrapped_kwargs)

                # wrap the output with the proxy of _CommResult, so that
                # subsequent ops can link to it.
                track_tensor_tree(out, comm_result_proxy, constant=None, tracer=tracer)

                # N.B.: we still need to remember the work handle here, and wait
                # for it later to make sure the execution during tracing is
                # correct. Also, remember that the comm has already been launched.
                # args[0] is always the collection of output tensors.
                tree_map(partial(set_work, out[1]), args[0])

                # HACK: update the proxy on the input argument as this is an
                # inplace collective communication.
                flat_args, args_spec = tree_flatten(unwrapped_args[0])
                flat_out, out_spec = tree_flatten(out[0])
                for a, o in zip(flat_args, flat_out):
                    set_proxy_slot(a, tracer, get_proxy_slot(o, tracer))

                return out
            else:
                # in eager mode, simply remember the work handle as an attribute
                out = func(*unwrapped_args, **unwrapped_kwargs)
                tree_map(partial(set_work, out[1]), args[0])
                return out
        else:
            if work is not None:
                return func(*unwrapped_args, **unwrapped_kwargs)
            else:
                # we need to keep propagating the CommTensor wrapping until the
                # first subsequent operation has waited for the communication.
                return tree_map(wrap, func(*unwrapped_args, **unwrapped_kwargs))
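

# Usage sketch (illustrative only, not part of the original module): assuming a
# default process group has already been initialized via
# torch.distributed.init_process_group, wrapping a tensor in CommTensor and
# tracing an in-place collective with make_fx should produce a graph containing
# the "comm_result" and "wait_comm" nodes described above. The helper name
# `_example_trace_allreduce` is hypothetical and only shows intended usage.
def _example_trace_allreduce(x: torch.Tensor) -> torch.fx.GraphModule:
    import torch.distributed as dist
    from torch.fx.experimental.proxy_tensor import make_fx

    def fn(t: torch.Tensor) -> torch.Tensor:
        # wrap the input so that __torch_dispatch__ above can intercept the
        # collective and the first subsequent read of its result
        ct = CommTensor(t.clone())
        # launch the in-place collective; CommTensor stashes the work handle
        # and inserts a comm_result node during tracing
        dist.all_reduce(ct)
        # the multiplication is the first read of the collective result; during
        # tracing, a wait_comm node is inserted before it
        return ct * 2

    return make_fx(fn)(x)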