- from __future__ import annotations
- from typing import cast, Callable, Generic, List, Optional, Type, TypeVar, Union
- import torch
- __all__ = ['Future', 'collect_all', 'wait_all']
T = TypeVar("T")
S = TypeVar("S")


class _PyFutureMeta(type(torch._C.Future), type(Generic)):  # type: ignore[misc, no-redef]
    # Combined metaclass so that ``Future`` can inherit from both the C++-backed
    # ``torch._C.Future`` and ``Generic[T]`` without a metaclass conflict.
    pass


class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
    r"""
    Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
    execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
    also exposes a set of APIs to add callback functions and set results.

    .. warning:: GPU support is a beta feature, subject to changes.
    """

    def __init__(self, *, devices: Optional[List[Union[int, str, torch.device]]] = None):
        r"""
        Create an empty unset ``Future``. If the future is intended to hold
        values containing CUDA tensors, (a superset of) their CUDA devices must
        be specified at construction. (This is only supported if
        ``torch.cuda.is_available()`` returns ``True``). This is needed to
        ensure proper CUDA stream synchronization. The child futures, returned
        by the ``then`` method, will inherit these devices.

        Args:
            devices(``List[Union[int, str, torch.device]]``, optional): the set
                of devices on which tensors contained in this future's value are
                allowed to reside and on which callbacks are allowed to operate.
        """
        if devices is None:
            devices = []
        super().__init__([torch.device(d) for d in devices])

    def done(self) -> bool:
        r"""
        Return ``True`` if this ``Future`` is done. A ``Future`` is done if it
        has a result or an exception.

        If the value contains tensors that reside on GPUs, ``Future.done()``
        will return ``True`` even if the asynchronous kernels that are
        populating those tensors haven't yet completed running on the device,
        because at such stage the result is already usable, provided one
        performs the appropriate synchronizations (see :meth:`wait`).
        """
        return super().done()

    def wait(self) -> T:
        r"""
        Block until the value of this ``Future`` is ready.

        If the value contains tensors that reside on GPUs, then an additional
        synchronization is performed with the kernels (executing on the device)
        which may be asynchronously populating those tensors. Such sync is
        non-blocking, which means that ``wait()`` will insert the necessary
        instructions in the current streams to ensure that further operations
        enqueued on those streams will be properly scheduled after the async
        kernels but, once that is done, ``wait()`` will return, even if those
        kernels are still running. No further synchronization is required when
        accessing and using the values, as long as one doesn't change streams.

        Returns:
            The value held by this ``Future``. If the function (callback or RPC)
            creating the value has thrown an error, this ``wait`` method will
            also throw an error.
        """
        return super().wait()

    def value(self) -> T:
        r"""
        Obtain the value of an already-completed future.

        This method should only be called after a call to :meth:`wait` has
        completed, or inside a callback function passed to :meth:`then`. In
        other cases this ``Future`` may not yet hold a value and calling
        ``value()`` could fail.

        If the value contains tensors that reside on GPUs, then this method will
        *not* perform any additional synchronization. This should be done
        beforehand, separately, through a call to :meth:`wait` (except within
        callbacks, for which it's already being taken care of by :meth:`then`).

        Returns:
            The value held by this ``Future``. If the function (callback or RPC)
            creating the value has thrown an error, this ``value()`` method will
            also throw an error.
        """
        return super().value()

    def then(self, callback: Callable[[Future[T]], S]) -> Future[S]:
        r"""
        Append the given callback function to this ``Future``, which will be run
        when the ``Future`` is completed. Multiple callbacks can be added to
        the same ``Future``, but the order in which they will be executed cannot
        be guaranteed (to enforce a certain order consider chaining:
        ``fut.then(cb1).then(cb2)``). The callback must take one argument, which
        is the reference to this ``Future``. The callback function can use the
        :meth:`value` method to get the value. Note that if this ``Future`` is
        already completed, the given callback will be run immediately inline.

        If the ``Future``'s value contains tensors that reside on GPUs, the
        callback might be invoked while the async kernels that are populating
        those tensors haven't yet finished executing on the device. However, the
        callback will be invoked with some dedicated streams set as current
        (fetched from a global pool) which will be synchronized with those
        kernels. Hence any operation performed by the callback on these tensors
        will be scheduled on the device after the kernels complete. In other
        words, as long as the callback doesn't switch streams, it can safely
        manipulate the result without any additional synchronization. This is
        similar to the non-blocking behavior of :meth:`wait`.

        Similarly, if the callback returns a value that contains tensors that
        reside on a GPU, it can do so even if the kernels that are producing
        these tensors are still running on the device, as long as the callback
        didn't change streams during its execution. If one wants to change
        streams, one must be careful to re-synchronize them with the original
        streams, that is, those that were current when the callback was invoked.

        Args:
            callback(``Callable``): a ``Callable`` that takes this ``Future`` as
                the only argument.

        Returns:
            A new ``Future`` object that holds the return value of the
            ``callback`` and will be marked as completed when the given
            ``callback`` finishes.

        .. note:: Note that if the callback function throws, either
            through the original future being completed with an exception and
            calling ``fut.wait()``, or through other code in the callback, the
            future returned by ``then`` will be marked appropriately with the
            encountered error. However, if this callback later completes
            additional futures, those futures are not marked as completed with
            an error and the user is responsible for handling completion/waiting
            on those futures independently.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> def callback(fut):
            ...     print(f"RPC return value is {fut.wait()}.")
            >>> fut = torch.futures.Future()
            >>> # The inserted callback will print the return value when
            >>> # receiving the response from "worker1"
            >>> cb_fut = fut.then(callback)
            >>> chain_cb_fut = cb_fut.then(
            ...     lambda x : print(f"Chained cb done. {x.wait()}")
            ... )
            >>> fut.set_result(5)
            RPC return value is 5.
            Chained cb done. None
        """
        return cast(Future[S], super().then(callback))

    def add_done_callback(self, callback: Callable[[Future[T]], None]) -> None:
        r"""
        Append the given callback function to this ``Future``, which will be run
        when the ``Future`` is completed. Multiple callbacks can be added to
        the same ``Future``, but the order in which they will be executed cannot
        be guaranteed. The callback must take one argument, which is the
        reference to this ``Future``. The callback function can use the
        :meth:`value` method to get the value. Note that if this ``Future`` is
        already completed, the given callback will be run inline.

        We recommend that you use the :meth:`then` method as it provides a way
        to synchronize after your callback has completed. ``add_done_callback``
        can be cheaper if your callback does not return anything. But both
        :meth:`then` and ``add_done_callback`` use the same callback
        registration API under the hood.

        With respect to GPU tensors, this method behaves in the same way as
        :meth:`then`.

        Args:
            callback(``Future``): a ``Callable`` that takes in one argument,
                which is the reference to this ``Future``.

        .. note:: Note that if the callback function throws, either
            through the original future being completed with an exception and
            calling ``fut.wait()``, or through other code in the callback,
            error handling must be carefully taken care of. For example, if
            this callback later completes additional futures, those futures are
            not marked as completed with an error and the user is responsible
            for handling completion/waiting on those futures independently.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> def callback(fut):
            ...     print("This will run after the future has finished.")
            ...     print(fut.wait())
            >>> fut = torch.futures.Future()
            >>> fut.add_done_callback(callback)
            >>> fut.set_result(5)
            This will run after the future has finished.
            5
        """
        super().add_done_callback(callback)

    def set_result(self, result: T) -> None:
        r"""
        Set the result for this ``Future``, which will mark this ``Future`` as
        completed and trigger all attached callbacks. Note that a ``Future``
        cannot be marked completed twice.

        If the result contains tensors that reside on GPUs, this method can be
        called even if the asynchronous kernels that are populating those
        tensors haven't yet completed running on the device, provided that the
        streams on which those kernels were enqueued are set as the current ones
        when this method is called. Put simply, it's safe to call this method
        immediately after launching those kernels, without any additional
        synchronization, as long as one doesn't change streams in between. This
        method will record events on all the relevant current streams and will
        use them to ensure proper scheduling for all the consumers of this
        ``Future``.

        Args:
            result (object): the result object of this ``Future``.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> import threading
            >>> import time
            >>> def slow_set_future(fut, value):
            ...     time.sleep(0.5)
            ...     fut.set_result(value)
            >>> fut = torch.futures.Future()
            >>> t = threading.Thread(
            ...     target=slow_set_future,
            ...     args=(fut, torch.ones(2) * 3)
            ... )
            >>> t.start()
            >>> print(fut.wait())
            tensor([3., 3.])
            >>> t.join()
        """
        super().set_result(result)

    def set_exception(self, result: T) -> None:
        r"""
        Set an exception for this ``Future``, which will mark this ``Future`` as
        completed with an error and trigger all attached callbacks. Note that
        when calling wait()/value() on this ``Future``, the exception set here
        will be raised inline.

        Args:
            result (BaseException): the exception for this ``Future``.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> fut = torch.futures.Future()
            >>> fut.set_exception(ValueError("foo"))
            >>> fut.wait()
            Traceback (most recent call last):
            ...
            ValueError: foo
        """
        assert isinstance(result, Exception), f"{result} is of type {type(result)}, not an Exception."

        # The C++ Future has no native "set exception" entry point; instead we
        # store the exception as the result and install an unwrap function that
        # re-raises it whenever the value is accessed (wait()/value()).
        def raise_error(fut_result):
            raise fut_result

        super()._set_unwrap_func(raise_error)
        self.set_result(result)  # type: ignore[arg-type]
def collect_all(futures: List[Future]) -> Future[List[Future]]:
    r"""
    Collects the provided :class:`~torch.futures.Future` objects into a single
    combined :class:`~torch.futures.Future` that is completed when all of the
    sub-futures are completed.

    Args:
        futures (list): a list of :class:`~torch.futures.Future` objects.

    Returns:
        Returns a :class:`~torch.futures.Future` object to a list of the passed
        in Futures.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
        >>> fut0 = torch.futures.Future()
        >>> fut1 = torch.futures.Future()
        >>> fut = torch.futures.collect_all([fut0, fut1])
        >>> fut0.set_result(0)
        >>> fut1.set_result(1)
        >>> fut_list = fut.wait()
        >>> print(f"fut0 result = {fut_list[0].wait()}")
        fut0 result = 0
        >>> print(f"fut1 result = {fut_list[1].wait()}")
        fut1 result = 1
    """
    # The target type is given as a string so that the subscripted generic
    # (``Future[List[Future]]``) is not rebuilt at runtime on every call;
    # ``typing.cast`` accepts forward-reference strings and is a no-op at runtime.
    return cast("Future[List[Future]]", torch._C._collect_all(cast(List[torch._C.Future], futures)))
def wait_all(futures: List[Future]) -> List:
    r"""
    Waits for all provided futures to be complete, and returns
    the list of completed values. If any of the futures encounters an error,
    the method will exit early and report the error not waiting for other
    futures to complete.

    Args:
        futures (list): a list of :class:`~torch.futures.Future` object.

    Returns:
        A list of the completed :class:`~torch.futures.Future` results. This
        method will throw an error if ``wait`` on any
        :class:`~torch.futures.Future` throws.
    """
    # ``_collect_all(...).wait()`` blocks until every sub-future is done and
    # yields the sub-futures themselves; each ``fut.wait()`` then extracts the
    # value (or re-raises the sub-future's exception, aborting the whole list).
    return [fut.wait() for fut in torch._C._collect_all(cast(List[torch._C.Future], futures)).wait()]
|