functions.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import functools
  2. def async_execution(fn):
  3. r"""
  4. A decorator for a function indicating that the return value of the function
  5. is guaranteed to be a :class:`~torch.futures.Future` object and this
  6. function can run asynchronously on the RPC callee. More specifically, the
  7. callee extracts the :class:`~torch.futures.Future` returned by the wrapped
  8. function and installs subsequent processing steps as a callback to that
  9. :class:`~torch.futures.Future`. The installed callback will read the value
  10. from the :class:`~torch.futures.Future` when completed and send the
  11. value back as the RPC response. That also means the returned
  12. :class:`~torch.futures.Future` only exists on the callee side and is never
  13. sent through RPC. This decorator is useful when the wrapped function's
  14. (``fn``) execution needs to pause and resume due to, e.g., containing
  15. :meth:`~torch.distributed.rpc.rpc_async` or waiting for other signals.
  16. .. note:: To enable asynchronous execution, applications must pass the
  17. function object returned by this decorator to RPC APIs. If RPC detected
  18. attributes installed by this decorator, it knows that this function
  19. returns a ``Future`` object and will handle that accordingly.
  20. However, this does not mean this decorator has to be outmost one when
  21. defining a function. For example, when combined with ``@staticmethod``
  22. or ``@classmethod``, ``@rpc.functions.async_execution`` needs to be the
  23. inner decorator to allow the target function be recognized as a static
  24. or class function. This target function can still execute asynchronously
  25. because, when accessed, the static or class method preserves attributes
  26. installed by ``@rpc.functions.async_execution``.
  27. Example::
  28. The returned :class:`~torch.futures.Future` object can come from
  29. :meth:`~torch.distributed.rpc.rpc_async`,
  30. :meth:`~torch.futures.Future.then`, or :class:`~torch.futures.Future`
  31. constructor. The example below shows directly using the
  32. :class:`~torch.futures.Future` returned by
  33. :meth:`~torch.futures.Future.then`.
  34. >>> from torch.distributed import rpc
  35. >>>
  36. >>> # omitting setup and shutdown RPC
  37. >>>
  38. >>> # On all workers
  39. >>> @rpc.functions.async_execution
  40. >>> def async_add_chained(to, x, y, z):
  41. >>> # This function runs on "worker1" and returns immediately when
  42. >>> # the callback is installed through the `then(cb)` API. In the
  43. >>> # mean time, the `rpc_async` to "worker2" can run concurrently.
  44. >>> # When the return value of that `rpc_async` arrives at
  45. >>> # "worker1", "worker1" will run the lambda function accordingly
  46. >>> # and set the value for the previously returned `Future`, which
  47. >>> # will then trigger RPC to send the result back to "worker0".
  48. >>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
  49. >>> lambda fut: fut.wait() + z
  50. >>> )
  51. >>>
  52. >>> # On worker0
  53. >>> # xdoctest: +SKIP
  54. >>> ret = rpc.rpc_sync(
  55. >>> "worker1",
  56. >>> async_add_chained,
  57. >>> args=("worker2", torch.ones(2), 1, 1)
  58. >>> )
  59. >>> print(ret) # prints tensor([3., 3.])
  60. When combined with TorchScript decorators, this decorator must be the
  61. outmost one.
  62. >>> from torch import Tensor
  63. >>> from torch.futures import Future
  64. >>> from torch.distributed import rpc
  65. >>>
  66. >>> # omitting setup and shutdown RPC
  67. >>>
  68. >>> # On all workers
  69. >>> @torch.jit.script
  70. >>> def script_add(x: Tensor, y: Tensor) -> Tensor:
  71. >>> return x + y
  72. >>>
  73. >>> @rpc.functions.async_execution
  74. >>> @torch.jit.script
  75. >>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
  76. >>> return rpc.rpc_async(to, script_add, (x, y))
  77. >>>
  78. >>> # On worker0
  79. >>> ret = rpc.rpc_sync(
  80. >>> "worker1",
  81. >>> async_add,
  82. >>> args=("worker2", torch.ones(2), 1)
  83. >>> )
  84. >>> print(ret) # prints tensor([2., 2.])
  85. When combined with static or class method, this decorator must be the
  86. inner one.
  87. >>> from torch.distributed import rpc
  88. >>>
  89. >>> # omitting setup and shutdown RPC
  90. >>>
  91. >>> # On all workers
  92. >>> class AsyncExecutionClass:
  93. >>>
  94. >>> @staticmethod
  95. >>> @rpc.functions.async_execution
  96. >>> def static_async_add(to, x, y, z):
  97. >>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
  98. >>> lambda fut: fut.wait() + z
  99. >>> )
  100. >>>
  101. >>> @classmethod
  102. >>> @rpc.functions.async_execution
  103. >>> def class_async_add(cls, to, x, y, z):
  104. >>> ret_fut = torch.futures.Future()
  105. >>> rpc.rpc_async(to, torch.add, args=(x, y)).then(
  106. >>> lambda fut: ret_fut.set_result(fut.wait() + z)
  107. >>> )
  108. >>> return ret_fut
  109. >>>
  110. >>> @rpc.functions.async_execution
  111. >>> def bound_async_add(self, to, x, y, z):
  112. >>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
  113. >>> lambda fut: fut.wait() + z
  114. >>> )
  115. >>>
  116. >>> # On worker0
  117. >>> ret = rpc.rpc_sync(
  118. >>> "worker1",
  119. >>> AsyncExecutionClass.static_async_add,
  120. >>> args=("worker2", torch.ones(2), 1, 2)
  121. >>> )
  122. >>> print(ret) # prints tensor([4., 4.])
  123. >>>
  124. >>> ret = rpc.rpc_sync(
  125. >>> "worker1",
  126. >>> AsyncExecutionClass.class_async_add,
  127. >>> args=("worker2", torch.ones(2), 1, 2)
  128. >>> )
  129. >>> print(ret) # prints tensor([4., 4.])
  130. This decorator also works with RRef helpers, i.e., .
  131. :meth:`torch.distributed.rpc.RRef.rpc_sync`,
  132. :meth:`torch.distributed.rpc.RRef.rpc_async`, and
  133. :meth:`torch.distributed.rpc.RRef.remote`.
  134. >>> from torch.distributed import rpc
  135. >>>
  136. >>> # reuse the AsyncExecutionClass class above
  137. >>> rref = rpc.remote("worker1", AsyncExecutionClass)
  138. >>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2)
  139. >>> print(ret) # prints tensor([4., 4.])
  140. >>>
  141. >>> rref = rpc.remote("worker1", AsyncExecutionClass)
  142. >>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait()
  143. >>> print(ret) # prints tensor([4., 4.])
  144. >>>
  145. >>> rref = rpc.remote("worker1", AsyncExecutionClass)
  146. >>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here()
  147. >>> print(ret) # prints tensor([4., 4.])
  148. """
  149. @functools.wraps(fn)
  150. def wrapper(*args, **kwargs):
  151. return fn(*args, **kwargs)
  152. # Can't declare and use attributes of function objects (mypy#2087)
  153. wrapper._wrapped_async_rpc_function = fn # type: ignore[attr-defined]
  154. return wrapper