import functools
def async_execution(fn):
    r"""
    A decorator for a function indicating that the return value of the function
    is guaranteed to be a :class:`~torch.futures.Future` object and this
    function can run asynchronously on the RPC callee. More specifically, the
    callee extracts the :class:`~torch.futures.Future` returned by the wrapped
    function and installs subsequent processing steps as a callback to that
    :class:`~torch.futures.Future`. The installed callback will read the value
    from the :class:`~torch.futures.Future` when completed and send the
    value back as the RPC response. That also means the returned
    :class:`~torch.futures.Future` only exists on the callee side and is never
    sent through RPC. This decorator is useful when the wrapped function's
    (``fn``) execution needs to pause and resume due to, e.g., containing
    :meth:`~torch.distributed.rpc.rpc_async` or waiting for other signals.

    .. note:: To enable asynchronous execution, applications must pass the
        function object returned by this decorator to RPC APIs. If RPC detected
        attributes installed by this decorator, it knows that this function
        returns a ``Future`` object and will handle that accordingly.
        However, this does not mean this decorator has to be the outermost one
        when defining a function. For example, when combined with
        ``@staticmethod`` or ``@classmethod``,
        ``@rpc.functions.async_execution`` needs to be the inner decorator to
        allow the target function be recognized as a static or class function.
        This target function can still execute asynchronously because, when
        accessed, the static or class method preserves attributes installed by
        ``@rpc.functions.async_execution``.

    Example::
        The returned :class:`~torch.futures.Future` object can come from
        :meth:`~torch.distributed.rpc.rpc_async`,
        :meth:`~torch.futures.Future.then`, or :class:`~torch.futures.Future`
        constructor. The example below shows directly using the
        :class:`~torch.futures.Future` returned by
        :meth:`~torch.futures.Future.then`.

        >>> from torch.distributed import rpc
        >>>
        >>> # omitting setup and shutdown RPC
        >>>
        >>> # On all workers
        >>> @rpc.functions.async_execution
        >>> def async_add_chained(to, x, y, z):
        >>>     # This function runs on "worker1" and returns immediately when
        >>>     # the callback is installed through the `then(cb)` API. In the
        >>>     # mean time, the `rpc_async` to "worker2" can run concurrently.
        >>>     # When the return value of that `rpc_async` arrives at
        >>>     # "worker1", "worker1" will run the lambda function accordingly
        >>>     # and set the value for the previously returned `Future`, which
        >>>     # will then trigger RPC to send the result back to "worker0".
        >>>     return rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>         lambda fut: fut.wait() + z
        >>>     )
        >>>
        >>> # On worker0
        >>> # xdoctest: +SKIP
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     async_add_chained,
        >>>     args=("worker2", torch.ones(2), 1, 1)
        >>> )
        >>> print(ret)  # prints tensor([3., 3.])

        When combined with TorchScript decorators, this decorator must be the
        outermost one.

        >>> from torch import Tensor
        >>> from torch.futures import Future
        >>> from torch.distributed import rpc
        >>>
        >>> # omitting setup and shutdown RPC
        >>>
        >>> # On all workers
        >>> @torch.jit.script
        >>> def script_add(x: Tensor, y: Tensor) -> Tensor:
        >>>     return x + y
        >>>
        >>> @rpc.functions.async_execution
        >>> @torch.jit.script
        >>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
        >>>     return rpc.rpc_async(to, script_add, (x, y))
        >>>
        >>> # On worker0
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     async_add,
        >>>     args=("worker2", torch.ones(2), 1)
        >>> )
        >>> print(ret)  # prints tensor([2., 2.])

        When combined with static or class method, this decorator must be the
        inner one.

        >>> from torch.distributed import rpc
        >>>
        >>> # omitting setup and shutdown RPC
        >>>
        >>> # On all workers
        >>> class AsyncExecutionClass:
        >>>
        >>>     @staticmethod
        >>>     @rpc.functions.async_execution
        >>>     def static_async_add(to, x, y, z):
        >>>         return rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>             lambda fut: fut.wait() + z
        >>>         )
        >>>
        >>>     @classmethod
        >>>     @rpc.functions.async_execution
        >>>     def class_async_add(cls, to, x, y, z):
        >>>         ret_fut = torch.futures.Future()
        >>>         rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>             lambda fut: ret_fut.set_result(fut.wait() + z)
        >>>         )
        >>>         return ret_fut
        >>>
        >>>     @rpc.functions.async_execution
        >>>     def bound_async_add(self, to, x, y, z):
        >>>         return rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>             lambda fut: fut.wait() + z
        >>>         )
        >>>
        >>> # On worker0
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     AsyncExecutionClass.static_async_add,
        >>>     args=("worker2", torch.ones(2), 1, 2)
        >>> )
        >>> print(ret)  # prints tensor([4., 4.])
        >>>
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     AsyncExecutionClass.class_async_add,
        >>>     args=("worker2", torch.ones(2), 1, 2)
        >>> )
        >>> print(ret)  # prints tensor([4., 4.])

        This decorator also works with RRef helpers, i.e.,
        :meth:`torch.distributed.rpc.RRef.rpc_sync`,
        :meth:`torch.distributed.rpc.RRef.rpc_async`, and
        :meth:`torch.distributed.rpc.RRef.remote`.

        >>> from torch.distributed import rpc
        >>>
        >>> # reuse the AsyncExecutionClass class above
        >>> rref = rpc.remote("worker1", AsyncExecutionClass)
        >>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2)
        >>> print(ret)  # prints tensor([4., 4.])
        >>>
        >>> rref = rpc.remote("worker1", AsyncExecutionClass)
        >>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait()
        >>> print(ret)  # prints tensor([4., 4.])
        >>>
        >>> rref = rpc.remote("worker1", AsyncExecutionClass)
        >>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here()
        >>> print(ret)  # prints tensor([4., 4.])
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Plain pass-through: the RPC layer detects the marker attribute below
        # and treats the returned Future specially; the wrapper itself adds no
        # behavior.
        return fn(*args, **kwargs)

    # Can't declare and use attributes of function objects (mypy#2087)
    wrapper._wrapped_async_rpc_function = fn  # type: ignore[attr-defined]
    return wrapper