- __all__ = ["shutdown", "get_worker_info", "remote", "rpc_sync",
- "rpc_async", "RRef", "AllGatherStates", "method_factory", "new_method"]
- import collections
- import contextlib
- import functools
- import inspect
- import logging
- import threading
- from typing import Dict, Generic, TypeVar, Set, Any
- import torch
- from torch.futures import Future
- from torch._C._distributed_rpc import (
- PyRRef,
- RemoteProfilerManager,
- WorkerInfo,
- TensorPipeAgent,
- get_rpc_timeout,
- _cleanup_python_rpc_handler,
- _delete_all_user_and_unforked_owner_rrefs,
- _destroy_rref_context,
- _get_current_rpc_agent,
- _invoke_remote_builtin,
- _invoke_remote_python_udf,
- _invoke_remote_torchscript,
- _invoke_rpc_builtin,
- _invoke_rpc_python_udf,
- _invoke_rpc_torchscript,
- _is_current_rpc_agent_set,
- _reset_current_rpc_agent,
- _set_and_start_rpc_agent,
- )
- from .internal import (
- PythonUDF,
- RPCExecMode,
- _internal_rpc_pickler,
- _build_rpc_profiling_key,
- )
- from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT
- from ._utils import _group_membership_management, _update_group_membership
- logger = logging.getLogger(__name__)
- # NB: Ignoring RRef leaks during shutdown. Without this, applications would have
- # to make sure there are no references to any RRef in the application code and
- # that Python GC has already deleted those RRefs. This could lead to a poor
- # debugging experience, especially for large applications. Therefore, by
- # default, we ignore RRef leaks during shutdown. This is usually fine, as
- # shutdown means applications have finished training and no longer care about
- # states.
- #
- # To enable RRef leak checking, set `_ignore_rref_leak` to False.
- _ignore_rref_leak = True
- _default_pickler = _internal_rpc_pickler
- @contextlib.contextmanager
- def _use_rpc_pickler(rpc_pickler):
- r"""
- rpc_pickler (.internal._InternalRPCPickler): Overrides the default RPC pickler.
- """
- global _default_pickler
- _default_pickler = rpc_pickler
- try:
- yield
- finally:
- _default_pickler = _internal_rpc_pickler
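
# A minimal usage sketch (illustration only, not part of the module API):
# `my_pickler` is a hypothetical object implementing the same serialize/deserialize
# interface as `_internal_rpc_pickler`. While the block is active, Python UDFs sent
# over RPC are serialized with `my_pickler`; the default pickler is restored on exit.
def _example_use_rpc_pickler(dst, my_pickler):
    with _use_rpc_pickler(my_pickler):
        return rpc_sync(dst, min, args=(1, 2))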
- def _require_initialized(func):
- @functools.wraps(func)
- def wrapper(*args, **kwargs):
- if not _is_current_rpc_agent_set():
- raise RuntimeError(
- "RPC has not been initialized. Call "
- "torch.distributed.rpc.init_rpc first."
- )
- return func(*args, **kwargs)
- return wrapper
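
# Hedged sketch: any helper decorated with `@_require_initialized` raises a
# RuntimeError if it is called before `torch.distributed.rpc.init_rpc()`. The
# helper below is hypothetical and exists only to illustrate the decorator.
@_require_initialized
def _example_requires_rpc():
    # Only reachable once init_rpc() has installed and started an agent.
    return _get_current_rpc_agent().get_worker_info().name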
- class AllGatherStates:
- def __init__(self):
- # Each `gathered_objects` is an empty dict at the beginning.
- # The leader worker is elected as the first worker in a sorted worker
- # name list. Whenever a worker enters `_all_gather()`, it runs
- # `_gather_to_leader()` on the leader to add its own name and data
- # object to this dict. The leader also adds its own name to the dict
- # when calling `_all_gather()`.
- # Once `set(gathered_objects.keys()) == _ALL_WORKER_NAMES`, the leader
- # will broadcast the gathered dict to all follower workers and set their
- # `gathered_objects` field and the `proceed_signal` field.
- self.gathered_objects = {}
- # Each worker waits on this signal until it has received all gathered
- # objects.
- self.proceed_signal = threading.Event()
- # States used by `_all_gather()`.
- # `_ALL_WORKER_NAMES` is initialized when the RPC layer is initialized.
- _ALL_WORKER_NAMES: Set[Any] = set()
- _all_gather_dict_lock = threading.RLock()
- _all_gather_sequence_id: Dict[str, int] = {}
- _all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates)
- def _init_rpc_states(agent):
- worker_infos = agent.get_worker_infos()
- global _ALL_WORKER_NAMES
- _ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos}
- # NB: backend implementation might have already set the rpc_agent.
- if not _is_current_rpc_agent_set():
- _set_and_start_rpc_agent(agent)
- def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None):
- with _all_gather_dict_lock:
- if not worker_names:
- worker_names = _ALL_WORKER_NAMES
- assert (
- worker_name in worker_names
- ), f"{worker_name} is not expected by leader."
- states = _all_gather_sequence_id_to_states[sequence_id]
- assert (
- worker_name not in states.gathered_objects
- ), f"{worker_name} reported intent for sequence id {sequence_id} twice."
- states.gathered_objects[worker_name] = obj
- if worker_names == set(states.gathered_objects.keys()):
- states.proceed_signal.set()
- def _broadcast_to_followers(sequence_id, objects_map):
- with _all_gather_dict_lock:
- states = _all_gather_sequence_id_to_states[sequence_id]
- assert (
- not states.proceed_signal.is_set()
- ), f"Termination signal for sequence id {sequence_id} got set twice."
- states.gathered_objects = objects_map
- states.proceed_signal.set()
- _thread_local_var = threading.local()
- @contextlib.contextmanager
- def _wait_all():
- r"""
- A context manager that collects all futures returned by ``rpc_async`` and
- waits on them at the context manager's exit, relieving the user of the need
- to explicitly call wait.
- Example::
- >>> # xdoctest: +SKIP("distributed")
- >>> # On worker 0:
- >>> import torch
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker0", rank=0, world_size=2)
- >>> with rpc._wait_all():
- >>> fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
- >>> fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
- >>> # fut_1 and fut_2 are waited on
- """
- _thread_local_var.future_list = []
- try:
- yield
- finally:
- try:
- torch.futures.wait_all(_thread_local_var.future_list)
- finally:
- del _thread_local_var.future_list
- @_require_initialized
- def _all_gather(obj, worker_names=None, timeout=UNSET_RPC_TIMEOUT):
- r"""
- This is similar to torch.distributed.all_gather(), but uses RPC. It
- picks the worker with the smallest name (in alphabetical order) as the leader.
- All followers then send their data ``obj`` to the leader. After the leader
- has received all objects, it broadcasts the results back to all followers.
- This function blocks until all workers have received the gathered results.
- """
- if not worker_names:
- assert (
- _ALL_WORKER_NAMES is not None
- ), "`_ALL_WORKER_NAMES` is not initialized for `_all_gather`."
- worker_names = _ALL_WORKER_NAMES
- leader_name = min(worker_names)
- self_name = _get_current_rpc_agent().get_worker_info().name
- with _all_gather_dict_lock:
- concat_names = "".join(sorted(worker_names))
- sequence_num = _all_gather_sequence_id.get(concat_names, 0)
- _all_gather_sequence_id[concat_names] = sequence_num + 1
- sequence_id = concat_names + str(sequence_num)
- is_leader = leader_name == self_name
- if timeout == UNSET_RPC_TIMEOUT:
- # Timeout is specified by agent for RPC calls
- rpc_timeout = get_rpc_timeout()
- # No timeout for signal
- signal_timeout = None
- elif timeout == DEFAULT_SHUTDOWN_TIMEOUT:
- # No timeout for RPC
- rpc_timeout = timeout
- # No timeout for signal
- signal_timeout = None
- else:
- # Signal and RPC timeout use the same timeout
- signal_timeout = rpc_timeout = timeout
- # Phase 1: Followers send their objects to the leader
- if is_leader:
- _gather_to_leader(sequence_id, self_name, obj, worker_names)
- else:
- rpc_sync(
- leader_name,
- _gather_to_leader,
- args=(sequence_id, self_name, obj, worker_names),
- timeout=rpc_timeout,
- )
- with _all_gather_dict_lock:
- states = _all_gather_sequence_id_to_states[sequence_id]
- # Timeout is either set by function parameter or None (which is indefinite)
- states.proceed_signal.wait(timeout=signal_timeout)
- # Phase 2: Leader broadcasts the gathered results to all followers
- # Leader's signal is the first to be unblocked, after receiving all
- # followers' data objects.
- if is_leader:
- worker_name_to_response_future_dict = {}
- for follower_name in worker_names - {leader_name}:
- fut = rpc_async(
- follower_name,
- _broadcast_to_followers,
- args=(sequence_id, states.gathered_objects),
- timeout=rpc_timeout
- )
- worker_name_to_response_future_dict[follower_name] = fut
- errors = []
- for follower_name, fut in worker_name_to_response_future_dict.items():
- try:
- fut.wait()
- except RuntimeError as ex:
- errors.append((follower_name, ex))
- if errors:
- raise RuntimeError(
- f"Followers {[e[0] for e in errors]} timed out in _all_gather "
- f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}"
- )
- # Clean up the states stored under this sequence_id
- with _all_gather_dict_lock:
- states = _all_gather_sequence_id_to_states.pop(sequence_id)
- return states.gathered_objects
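
# Hedged usage sketch of the internal gather protocol above: if every worker calls
# `_all_gather` with its own payload, each call returns the same dict mapping worker
# name -> payload once all workers have checked in. Names in the comment are illustrative.
@_require_initialized
def _example_all_gather_names():
    my_name = _get_current_rpc_agent().get_worker_info().name
    # Blocks until every worker in `_ALL_WORKER_NAMES` has contributed its object.
    gathered = _all_gather(my_name)
    return gathered  # e.g. {"worker0": "worker0", "worker1": "worker1"}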
- @_require_initialized
- def _barrier(worker_names):
- r"""
- Synchronizes local and remote RPC processes.
- This will block until all local and remote RPC processes specified under worker_names
- reach this method to wait for all outstanding work to complete.
- Args:
- worker_names (List[str]): The set of workers to synchronize.
- """
- try:
- _all_gather(None, set(worker_names))
- except RuntimeError as ex:
- logger.error(
- f"Failed to complete barrier, got error {ex}"
- )
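
# Hedged sketch: a barrier across an explicit subset of workers. The worker names
# are illustrative; every listed worker must make the same call for it to return.
def _example_barrier_subset():
    _barrier(["worker0", "worker1"])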
- @_require_initialized
- def _wait_all_workers(timeout=DEFAULT_SHUTDOWN_TIMEOUT):
- r"""
- Block until all local and remote RPC processes reach this method and wait
- for all outstanding work to complete. Every RPC process must call this
- method before exit to perform a graceful shutdown. This should be used to
- terminate the RPC framework, and there is no guarantee that the RPC
- framework will work after this method returns.
- """
- try:
- _all_gather(None, timeout=timeout)
- except RuntimeError as ex:
- logger.error(
- f"Failed to respond to 'Shutdown Proceed' in time, got error {ex}"
- )
- raise ex
- @_require_initialized
- def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
- r"""
- Perform a shutdown of the RPC agent, and then destroy the RPC agent. This
- stops the local agent from accepting outstanding requests, and shuts
- down the RPC framework by terminating all RPC threads. If ``graceful=True``,
- this will block until all local and remote RPC processes reach this method
- and wait for all outstanding work to complete. Otherwise, if
- ``graceful=False``, this is a local shutdown, and it does not wait for other
- RPC processes to reach this method.
- .. warning::
- For :class:`~torch.futures.Future` objects returned by
- :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not
- be called after ``shutdown()``.
- Args:
- graceful (bool): Whether to do a graceful shutdown or not. If True,
- this will 1) wait until there is no pending system
- messages for ``UserRRefs`` and delete them; 2) block
- until all local and remote RPC processes have reached
- this method and wait for all outstanding work to
- complete.
- Example::
- Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
- on both workers. Refer to :meth:`~torch.distributed.init_process_group`
- API for more details. For example,
- export MASTER_ADDR=localhost
- export MASTER_PORT=5678
- Then run the following code in two different processes:
- >>> # xdoctest: +SKIP
- >>> # On worker 0:
- >>> import torch
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker0", rank=0, world_size=2)
- >>> # do some work
- >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
- >>> # ready to shutdown
- >>> rpc.shutdown()
- >>> # On worker 1:
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker1", rank=1, world_size=2)
- >>> # wait for worker 0 to finish work, and then shutdown.
- >>> rpc.shutdown()
- """
- if graceful:
- try:
- agent = _get_current_rpc_agent()
- if not isinstance(agent, TensorPipeAgent) or agent.is_static_group:
- _wait_all_workers(timeout)
- _delete_all_user_and_unforked_owner_rrefs()
- agent.join(shutdown=True, timeout=timeout)
- else:
- # This is a dynamic group so we need to grab the token for the operation
- my_worker_info = agent.get_worker_info()
- my_name = my_worker_info.name
- with _group_membership_management(agent.store, my_name, False):
- all_worker_infos = agent.get_worker_infos()
- for worker in all_worker_infos:
- if worker.name != my_name:
- rpc_sync(worker.name, _update_group_membership, args=(my_worker_info, [], {}, False))
- agent.join(shutdown=True, timeout=timeout)
- finally:
- # In case of errors, continue to complete the local shutdown.
- _finalize_shutdown()
- else:
- _finalize_shutdown()
- def _finalize_shutdown():
- try:
- # This raises a `TORCH_CHECK()` exception on RRef leak detected.
- _destroy_rref_context(_ignore_rref_leak)
- finally:
- _get_current_rpc_agent().shutdown()
- # Clean up the Python RPC handler in shutdown(); see comments in
- # PythonRpcHandler::cleanup(). Call it from the Python API because the
- # cleanup() function has a Python dependency: it assumes the Python
- # interpreter exists.
- # Whether or not an RRef leak exception is raised, this clean-up code
- # must run to avoid a destruction segfault in Python 3.5.
- #
- # future.wait() should not be called after shutdown().
- # pythonRpcHandler is cleaned up in shutdown(); after shutdown(), Python
- # objects returned from RPC Python calls can no longer be resolved.
- _cleanup_python_rpc_handler()
- _reset_current_rpc_agent()
- @_require_initialized
- def get_worker_info(worker_name=None):
- r"""
- Get :class:`~torch.distributed.rpc.WorkerInfo` of a given worker name.
- Use this :class:`~torch.distributed.rpc.WorkerInfo` to avoid passing an
- expensive string on every invocation.
- Args:
- worker_name (str): the string name of a worker. If ``None``, return the
- ``WorkerInfo`` of the current worker. (default ``None``)
- Returns:
- :class:`~torch.distributed.rpc.WorkerInfo` instance for the given
- ``worker_name`` or :class:`~torch.distributed.rpc.WorkerInfo` of the
- current worker if ``worker_name`` is ``None``.
- """
- if worker_name is not None:
- return _get_current_rpc_agent().get_worker_info(worker_name)
- else:
- return _get_current_rpc_agent().get_worker_info()
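
# Hedged sketch: resolve a `WorkerInfo` once and reuse it for later calls, instead
# of passing the worker name string on every invocation. "worker1" is illustrative.
@_require_initialized
def _example_cached_worker_info():
    info = get_worker_info("worker1")
    return rpc_sync(info, torch.add, args=(torch.ones(2), 1))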
- def _to_worker_info(to):
- if isinstance(to, WorkerInfo):
- return to
- elif isinstance(to, (str, int)):
- return get_worker_info(to)
- else:
- raise ValueError(f"Cannot get WorkerInfo from name or rank: {to}")
- def _rref_typeof_on_owner(rref, blocking=True):
- rref_type = type(rref.local_value())
- if blocking:
- return rref_type
- else:
- # Wrap result into a completed Future. This is so that if `blocking=False`
- # is specified, we return a future regardless of whether this call is on
- # the user or the owner.
- future = Future[type]()
- future.set_result(rref_type)
- return future
- def _rref_typeof_on_user(rref, timeout=UNSET_RPC_TIMEOUT, blocking=True):
- fut = rpc_async(
- rref.owner(),
- _rref_typeof_on_owner,
- args=(rref,),
- timeout=timeout
- )
- if blocking:
- return fut.wait()
- else:
- return fut
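
# Hedged sketch: querying the type of the value held by an RRef from the user side.
# With `blocking=False` the helper returns a `Future[type]` instead of the type.
def _example_rref_value_type(rref):
    fut = _rref_typeof_on_user(rref, blocking=False)
    return fut.wait()  # e.g. torch.Tensor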
- T = TypeVar("T")
- GenericWithOneTypeVar = Generic[T]
- try:
- # Combine the implementation class and the type class.
- class RRef(PyRRef, Generic[T]):
- pass
- except TypeError:
- # TypeError: metaclass conflict: the metaclass of a derived class
- # must be a (non-strict) subclass of the metaclasses of all its bases
- # Mypy doesn't understand __class__ (mypy bug #4177)
- class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__): # type: ignore[name-defined, misc, valid-type]
- pass
- # Combine the implementation class and the type class.
- # Types for classes expecting a certain generic parameter (mypy bug #7791)
- class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta): # type: ignore[misc, no-redef, valid-type]
- pass
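
# Hedged sketch: the `Generic` parameter only affects static typing; at runtime
# `RRef[torch.Tensor]` and `RRef` behave identically. The helper and worker name
# are illustrative.
def _example_typed_rref(dst: str) -> "RRef[torch.Tensor]":
    return remote(dst, torch.add, args=(torch.ones(2), 1))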
- # Install docstrings from `PyRRef` to `RRef`.
- #
- # This is needed because pybind11 generates the parameter
- # `self` as type `rpc.PyRRef`, so a `:inherited-members:`
- # under `.. autoclass:: RRef` does not work.
- # We have to do the following process to replace `rpc.PyRRef` with `rpc.RRef`.
- #
- def method_factory(method_name, docstring):
- def method(self, *args, **kwargs):
- return getattr(super(RRef, self), method_name)(*args, **kwargs)
- if method.__doc__:
- method.__doc__ = docstring
- return method
- for method_name, method in inspect.getmembers(PyRRef):
- # Ignore magic methods, except "__str__".
- if method_name.startswith("_") and method_name != "__str__":
- continue
- # Get pybind11 generated docstring.
- # It looks like this:
- """
- to_here(self: torch.distributed.rpc.PyRRef, timeout: float=-1.0) -> object
- Blocking call that copies the value of the RRef from the owner
- to the local node and returns it. If the current node is the
- owner, returns a reference to the local value.
- """
- docstring = getattr(method, "__doc__", None)
- assert docstring is not None, "RRef user-facing methods should all have docstrings."
- # Do surgery on pybind11 generated docstrings.
- docstring = docstring.replace("torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef")
- # Attach user-facing RRef method with modified docstring.
- new_method = method_factory(method_name, docstring)
- setattr(RRef, method_name, new_method)
- @_require_initialized
- def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
- r"""
- Make a remote call to run ``func`` on worker ``to`` and return an
- :class:`~torch.distributed.rpc.RRef` to the result value immediately.
- Worker ``to`` will be the owner of the returned
- :class:`~torch.distributed.rpc.RRef`, and the worker calling ``remote`` is
- a user. The owner manages the global reference count of its
- :class:`~torch.distributed.rpc.RRef`, and the owner
- :class:`~torch.distributed.rpc.RRef` is only destructed when globally there
- are no living references to it.
- Args:
- to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
- func (Callable): a callable function, such as Python callables, builtin
- operators (e.g. :meth:`~torch.add`) and annotated
- TorchScript functions.
- args (tuple): the argument tuple for the ``func`` invocation.
- kwargs (dict): a dictionary of keyword arguments for the ``func``
- invocation.
- timeout (float, optional): timeout in seconds for this remote call. If the
- creation of this
- :class:`~torch.distributed.rpc.RRef` on worker
- ``to`` is not successfully processed on this
- worker within this timeout, then the next time
- there is an attempt to use the RRef (such as
- ``to_here()``), a timeout will be raised
- indicating this failure. A value of 0 indicates
- an infinite timeout, i.e. a timeout error will
- never be raised. If not provided, the default
- value set during initialization or with
- ``_set_rpc_timeout`` is used.
- Returns:
- A user :class:`~torch.distributed.rpc.RRef` instance to the result
- value. Use the blocking API :meth:`torch.distributed.rpc.RRef.to_here`
- to retrieve the result value locally.
- .. warning ::
- The ``remote`` API does not copy storages of argument tensors until
- sending them over the wire, which could be done by a different thread
- depending on the RPC backend type. The caller should make sure that the
- contents of those tensors stay intact until the returned RRef is
- confirmed by the owner, which can be checked using the
- :meth:`torch.distributed.rpc.RRef.confirmed_by_owner` API.
- .. warning ::
- Errors such as timeouts for the ``remote`` API are handled on a
- best-effort basis. This means that when remote calls initiated by
- ``remote`` fail, such as with a timeout error, errors are handled and
- set on the resulting RRef asynchronously. If the RRef has not been
- used by the application before this handling (such as ``to_here`` or
- fork call), then future uses of the ``RRef`` will appropriately raise
- errors. However, it is possible that the user application will use the
- ``RRef`` before the errors are handled. In this case, errors may not be
- raised as they have not yet been handled.
- Example::
- Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
- on both workers. Refer to :meth:`~torch.distributed.init_process_group`
- API for more details. For example,
- export MASTER_ADDR=localhost
- export MASTER_PORT=5678
- Then run the following code in two different processes:
- >>> # xdoctest: +SKIP
- >>> # On worker 0:
- >>> import torch
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker0", rank=0, world_size=2)
- >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
- >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
- >>> x = rref1.to_here() + rref2.to_here()
- >>> rpc.shutdown()
- >>> # On worker 1:
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker1", rank=1, world_size=2)
- >>> rpc.shutdown()
- Below is an example of running a TorchScript function using RPC.
- >>> # On both workers:
- >>> @torch.jit.script
- >>> def my_script_add(t1, t2):
- >>> return torch.add(t1, t2)
- >>> # On worker 0:
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker0", rank=0, world_size=2)
- >>> rref = rpc.remote("worker1", my_script_add, args=(torch.ones(2), 3))
- >>> rref.to_here()
- >>> rpc.shutdown()
- >>> # On worker 1:
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker1", rank=1, world_size=2)
- >>> rpc.shutdown()
- """
- torch._C._log_api_usage_once("torch.distributed.rpc_remote")
- qualified_name = torch.jit._builtins._find_builtin(func)
- dst_worker_info = _to_worker_info(to)
- should_profile = _get_should_profile()
- ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info)
- with ctx_manager as rf:
- args = args if args else ()
- kwargs = kwargs if kwargs else {}
- is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
- if is_async_exec:
- wrapped = func._wrapped_async_rpc_function
- if isinstance(wrapped, torch.jit.ScriptFunction):
- func = wrapped
- if qualified_name is not None:
- rref = _invoke_remote_builtin(dst_worker_info, qualified_name, timeout, *args, **kwargs)
- elif isinstance(func, torch.jit.ScriptFunction):
- rref = _invoke_remote_torchscript(
- dst_worker_info.name,
- torch._jit_internal._qualified_name(func),
- timeout,
- is_async_exec,
- *args,
- **kwargs,
- )
- else:
- (pickled_python_udf, tensors) = _default_pickler.serialize(
- PythonUDF(func, args, kwargs)
- )
- rref = _invoke_remote_python_udf(
- dst_worker_info,
- pickled_python_udf,
- tensors,
- timeout,
- is_async_exec
- )
- # attach profiling information
- if should_profile:
- assert torch.autograd._profiler_enabled()
- assert rf is not None
- fut = rf._call_end_callbacks_on_future(rref._get_future())
- rref._set_profiling_future(fut)
- return rref
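
# Hedged sketch of the async-execution branch above: a function wrapped with
# `torch.distributed.rpc.functions.async_execution` must return a Future, and the
# Future's value becomes the RRef's value. Worker names are illustrative.
#
# >>> # xdoctest: +SKIP
# >>> import torch.distributed.rpc as rpc
# >>> @rpc.functions.async_execution
# >>> def async_add(to, x, y):
# >>>     return rpc.rpc_async(to, torch.add, args=(x, y))
# >>> rref = rpc.remote("worker1", async_add, args=("worker2", torch.ones(2), 1))
# >>> rref.to_here()  # tensor([2., 2.])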
- def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RPC_TIMEOUT):
- if not callable(func):
- raise TypeError("function should be callable.")
- qualified_name = torch.jit._builtins._find_builtin(func)
- dst_worker_info = _to_worker_info(to)
- should_profile = _get_should_profile()
- ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info)
- with ctx_manager as rf:
- args = args if args else ()
- kwargs = kwargs if kwargs else {}
- is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
- if is_async_exec:
- wrapped = func._wrapped_async_rpc_function
- if isinstance(wrapped, torch.jit.ScriptFunction):
- func = wrapped
- if qualified_name is not None:
- fut = _invoke_rpc_builtin(
- dst_worker_info,
- qualified_name,
- rpc_timeout,
- *args,
- **kwargs
- )
- elif isinstance(func, torch.jit.ScriptFunction):
- fut = _invoke_rpc_torchscript(
- dst_worker_info.name,
- torch._jit_internal._qualified_name(func),
- args,
- kwargs,
- rpc_timeout,
- is_async_exec
- )
- else:
- (pickled_python_udf, tensors) = _default_pickler.serialize(
- PythonUDF(func, args, kwargs)
- )
- fut = _invoke_rpc_python_udf(
- dst_worker_info,
- pickled_python_udf,
- tensors,
- rpc_timeout,
- is_async_exec
- )
- if should_profile:
- assert torch.autograd._profiler_enabled()
- assert rf is not None
- # Schedule profiling callbacks to run when the future completes.
- # This returns a future that is completed when the original future
- # completes and the profiling callbacks have been completed as well,
- # to guarantee that fut.wait() completes the profiling. This new
- # future will contain the same value as the original future.
- fut = rf._call_end_callbacks_on_future(fut)
- return fut
- @_require_initialized
- def rpc_sync(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
- r"""
- Make a blocking RPC call to run function ``func`` on worker ``to``. RPC
- messages are sent and received in parallel with the execution of Python code. This
- method is thread-safe.
- Args:
- to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
- func (Callable): a callable function, such as Python callables, builtin
- operators (e.g. :meth:`~torch.add`) and annotated
- TorchScript functions.
- args (tuple): the argument tuple for the ``func`` invocation.
- kwargs (dict): a dictionary of keyword arguments for the ``func``
- invocation.
- timeout (float, optional): timeout in seconds to use for this RPC. If
- the RPC does not complete in this amount of
- time, an exception indicating it has
- timed out will be raised. A value of 0
- indicates an infinite timeout, i.e. a timeout
- error will never be raised. If not provided,
- the default value set during initialization
- or with ``_set_rpc_timeout`` is used.
- Returns:
- Returns the result of running ``func`` with ``args`` and ``kwargs``.
- Example::
- Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
- on both workers. Refer to :meth:`~torch.distributed.init_process_group`
- API for more details. For example,
- export MASTER_ADDR=localhost
- export MASTER_PORT=5678
- Then run the following code in two different processes:
- >>> # xdoctest: +SKIP
- >>> # On worker 0:
- >>> import torch
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker0", rank=0, world_size=2)
- >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
- >>> rpc.shutdown()
- >>> # On worker 1:
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker1", rank=1, world_size=2)
- >>> rpc.shutdown()
- Below is an example of running a TorchScript function using RPC.
- >>> # On both workers:
- >>> @torch.jit.script
- >>> def my_script_add(t1, t2):
- >>> return torch.add(t1, t2)
- >>> # On worker 0:
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker0", rank=0, world_size=2)
- >>> ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3))
- >>> rpc.shutdown()
- >>> # On worker 1:
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker1", rank=1, world_size=2)
- >>> rpc.shutdown()
- """
- torch._C._log_api_usage_once("torch.distributed.rpc_sync")
- fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs, timeout)
- return fut.wait()
- @_require_initialized
- def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
- r"""
- Make a non-blocking RPC call to run function ``func`` on worker ``to``. RPC
- messages are sent and received in parallel with the execution of Python code. This
- method is thread-safe. This method will immediately return a
- :class:`~torch.futures.Future` that can be awaited on.
- Args:
- to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
- func (Callable): a callable function, such as Python callables, builtin
- operators (e.g. :meth:`~torch.add`) and annotated
- TorchScript functions.
- args (tuple): the argument tuple for the ``func`` invocation.
- kwargs (dict): a dictionary of keyword arguments for the ``func``
- invocation.
- timeout (float, optional): timeout in seconds to use for this RPC. If
- the RPC does not complete in this amount of
- time, an exception indicating it has
- timed out will be raised. A value of 0
- indicates an infinite timeout, i.e. a timeout
- error will never be raised. If not provided,
- the default value set during initialization
- or with ``_set_rpc_timeout`` is used.
- Returns:
- Returns a :class:`~torch.futures.Future` object that can be waited
- on. When completed, the return value of ``func`` on ``args`` and
- ``kwargs`` can be retrieved from the :class:`~torch.futures.Future`
- object.
- .. warning ::
- Using GPU tensors as arguments or return values of ``func`` is not
- supported since we don't support sending GPU tensors over the wire. You
- need to explicitly copy GPU tensors to CPU before using them as
- arguments or return values of ``func``.
- .. warning ::
- The ``rpc_async`` API does not copy storages of argument tensors until
- sending them over the wire, which could be done by a different thread
- depending on the RPC backend type. The caller should make sure that the
- contents of those tensors stay intact until the returned
- :class:`~torch.futures.Future` completes.
- Example::
- Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
- on both workers. Refer to :meth:`~torch.distributed.init_process_group`
- API for more details. For example,
- export MASTER_ADDR=localhost
- export MASTER_PORT=5678
- Then run the following code in two different processes:
- >>> # xdoctest: +SKIP
- >>> # On worker 0:
- >>> import torch
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker0", rank=0, world_size=2)
- >>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
- >>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
- >>> result = fut1.wait() + fut2.wait()
- >>> rpc.shutdown()
- >>> # On worker 1:
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker1", rank=1, world_size=2)
- >>> rpc.shutdown()
- Below is an example of running a TorchScript function using RPC.
- >>> # On both workers:
- >>> @torch.jit.script
- >>> def my_script_add(t1, t2):
- >>> return torch.add(t1, t2)
- >>> # On worker 0:
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker0", rank=0, world_size=2)
- >>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3))
- >>> ret = fut.wait()
- >>> rpc.shutdown()
- >>> # On worker 1:
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker1", rank=1, world_size=2)
- >>> rpc.shutdown()
- """
- torch._C._log_api_usage_once("torch.distributed.rpc_async")
- fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout)
- if hasattr(_thread_local_var, "future_list"):
- _thread_local_var.future_list.append(fut)
- return fut
- def _get_should_profile():
- # The legacy profiler must be enabled. RPC profiling is not supported with
- # the Kineto profiler.
- ActiveProfilerType = torch._C._profiler.ActiveProfilerType
- return (
- torch.autograd._profiler_enabled() and
- torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined]
- )
- def _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info):
- ctx_manager = contextlib.suppress()
- if should_profile:
- # Create appropriate string representation based on type of func
- # (builtin, script, python)
- if qualified_name is None:
- func_name = (
- torch._jit_internal._qualified_name(func)
- if isinstance(func, torch.jit.ScriptFunction)
- else func.__qualname__
- )
- else:
- func_name = qualified_name
- # Build RPC profiling key.
- rpc_profiling_key = _build_rpc_profiling_key(
- rpc_type,
- func_name,
- get_worker_info().name,
- dst_worker_info.name,
- )
- RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
- # Mypy doesn't support re-def of a variable not in the same block (#1174)
- ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key) # type: ignore[assignment]
- return ctx_manager
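
# Hedged sketch tying the two helpers above together: RPC profiling requires the
# legacy autograd profiler (`_get_should_profile()` returns False under Kineto).
# The `use_kineto=False` argument and the destination worker are illustrative assumptions.
def _example_profile_rpc(dst):
    with torch.autograd.profiler.profile(use_kineto=False) as prof:
        rpc_sync(dst, torch.add, args=(torch.ones(2), 1))
    # Events for the RPC appear under the key built by `_build_rpc_profiling_key` above.
    return prof.key_averages()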