from typing import Dict, List, Optional, Union

import torch
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase

from . import constants as rpc_constants

DeviceType = Union[int, str, torch.device]

__all__ = ["TensorPipeRpcBackendOptions"]

def _to_device(device: DeviceType) -> torch.device:
    # Normalize an int/str/torch.device spec into a torch.device and require
    # it to be a CUDA device.
    device = torch.device(device)
    if device.type != "cuda":
        raise ValueError(
            "`set_devices` expects a list of CUDA devices, but got "
            f"device type {device.type}."
        )
    return device

def _to_device_map(
    device_map: Dict[DeviceType, DeviceType]
) -> Dict[torch.device, torch.device]:
    full_device_map: Dict[torch.device, torch.device] = {}
    # reverse_map is only used to detect many-to-one mappings, which would
    # make the device map non-invertible.
    reverse_map: Dict[torch.device, torch.device] = {}
    for k, v in device_map.items():
        k, v = torch.device(k), torch.device(v)
        if v in reverse_map:
            raise ValueError(
                "`device_map` only supports 1-to-1 mapping, "
                f"trying to map {k} and {reverse_map[v]} to {v}"
            )
        full_device_map[k] = v
        reverse_map[v] = k
    return full_device_map

def _to_device_list(devices: List[DeviceType]) -> List[torch.device]:
    return list(map(_to_device, devices))

class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
    r"""
    The backend options for
    :class:`~torch.distributed.rpc.TensorPipeAgent`, derived from
    :class:`~torch.distributed.rpc.RpcBackendOptions`.

    Args:
        num_worker_threads (int, optional): The number of threads in the
            thread-pool used by
            :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
            requests (default: 16).
        rpc_timeout (float, optional): The default timeout, in seconds,
            for RPC requests (default: 60 seconds). If the RPC has not
            completed in this timeframe, an exception indicating so will
            be raised. Callers can override this timeout for individual
            RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and
            :meth:`~torch.distributed.rpc.rpc_async` if necessary.
        init_method (str, optional): The URL to initialize the distributed
            store used for rendezvous. It takes any value accepted for the
            same argument of :meth:`~torch.distributed.init_process_group`
            (default: ``env://``).
        device_maps (Dict[str, Dict], optional): Device placement mappings from
            this worker to the callee. The key is the callee worker name and
            the value is a dictionary (``Dict`` of ``int``, ``str``, or
            ``torch.device``) that maps this worker's devices to the callee
            worker's devices. (default: ``None``)
        devices (List[int, str, or ``torch.device``], optional): all local
            CUDA devices used by the RPC agent. By default, it will be
            initialized to all local devices from its own ``device_maps`` and
            corresponding devices from its peers' ``device_maps``. When
            processing CUDA RPC requests, the agent will properly synchronize
            CUDA streams for all devices in this ``List``.
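
    Example:
        A minimal construction sketch; the worker names (``"worker0"`` and
        ``"worker1"``) and the device indices below are illustrative rather
        than required by the API.

        >>> # xdoctest: +SKIP("distributed")
        >>> import torch.distributed.rpc as rpc
        >>> options = rpc.TensorPipeRpcBackendOptions(
        >>>     num_worker_threads=8,
        >>>     rpc_timeout=20,  # seconds
        >>>     # map this worker's cuda:0 to worker1's cuda:1
        >>>     device_maps={"worker1": {0: 1}},
        >>> )
        >>> rpc.init_rpc(
        >>>     "worker0",
        >>>     rank=0,
        >>>     world_size=2,
        >>>     rpc_backend_options=options,
        >>> )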
- """
    def __init__(
        self,
        *,
        num_worker_threads: int = rpc_constants.DEFAULT_NUM_WORKER_THREADS,
        rpc_timeout: float = rpc_constants.DEFAULT_RPC_TIMEOUT_SEC,
        init_method: str = rpc_constants.DEFAULT_INIT_METHOD,
        device_maps: Optional[Dict[str, Dict[DeviceType, DeviceType]]] = None,
        devices: Optional[List[DeviceType]] = None,
        _transports: Optional[List] = None,
        _channels: Optional[List] = None,
    ):
        full_device_maps = (
            {}
            if device_maps is None
            else {k: _to_device_map(v) for k, v in device_maps.items()}
        )
        full_device_list = [] if devices is None else _to_device_list(devices)
        super().__init__(
            num_worker_threads,
            _transports,
            _channels,
            rpc_timeout,
            init_method,
            full_device_maps,
            full_device_list,
        )

    def set_device_map(self, to: str, device_map: Dict[DeviceType, DeviceType]):
        r"""
        Set the device mapping between each RPC caller and callee pair. This
        function can be called multiple times to incrementally add
        device placement configurations.

        Args:
            to (str): Callee name.
            device_map (Dict of int, str, or torch.device): Device placement
                mappings from this worker to the callee. This map must be
                invertible.

        Example:
            >>> # xdoctest: +SKIP("distributed")
            >>> # both workers
            >>> def add(x, y):
            >>>     print(x)  # tensor([1., 1.], device='cuda:1')
            >>>     return x + y, (x + y).to(2)
            >>>
            >>> # on worker 0
            >>> options = TensorPipeRpcBackendOptions(
            >>>     num_worker_threads=8,
            >>>     device_maps={"worker1": {0: 1}}
            >>>     # maps worker0's cuda:0 to worker1's cuda:1
            >>> )
            >>> options.set_device_map("worker1", {1: 2})
            >>> # maps worker0's cuda:1 to worker1's cuda:2
            >>>
            >>> rpc.init_rpc(
            >>>     "worker0",
            >>>     rank=0,
            >>>     world_size=2,
            >>>     backend=rpc.BackendType.TENSORPIPE,
            >>>     rpc_backend_options=options
            >>> )
            >>>
            >>> x = torch.ones(2)
            >>> rets = rpc.rpc_sync("worker1", add, args=(x.to(0), 1))
            >>> # The first argument will be moved to cuda:1 on worker1. When
            >>> # sending the return value back, it will follow the inverse of
            >>> # the device map, and hence will be moved back to cuda:0 and
            >>> # cuda:1 on worker0
            >>> print(rets[0])  # tensor([2., 2.], device='cuda:0')
            >>> print(rets[1])  # tensor([2., 2.], device='cuda:1')
        """
        full_device_map = _to_device_map(device_map)
        curr_device_maps = super().device_maps
        if to in curr_device_maps:
            # A source device that is already mapped for this callee may not
            # be remapped to a different target device.
            for k, v in full_device_map.items():
                if k in curr_device_maps[to] and v != curr_device_maps[to][k]:
                    raise ValueError(
                        "`set_device_map` only supports 1-to-1 mapping, trying"
                        f" to map {k} to {v} and {curr_device_maps[to][k]}"
                    )
        super()._set_device_map(to, full_device_map)

    def set_devices(self, devices: List[DeviceType]):
        r"""
        Set the local devices used by the TensorPipe RPC agent. When processing
        CUDA RPC requests, the TensorPipe RPC agent will properly synchronize
        CUDA streams for all devices in this ``List``.

        Args:
            devices (List of int, str, or torch.device): local devices used by
                the TensorPipe RPC agent.
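
        Example:
            A brief usage sketch; the device specifications below are
            illustrative and may be given as indices, strings, or
            ``torch.device`` objects.

            >>> # xdoctest: +SKIP("distributed")
            >>> options = TensorPipeRpcBackendOptions(num_worker_threads=8)
            >>> options.set_devices(["cuda:0", torch.device("cuda:1")])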
- """
- self.devices = _to_device_list(devices)