- __all__ = ["init_backend", "backend_registered", "construct_rpc_backend_options", "register_backend", "BackendType", "BackendValue"]
import collections
import enum
from typing import cast, Dict, List, Set, Tuple

import torch
import torch.distributed as dist

from ._utils import _group_membership_management, _update_group_membership
from . import api
from . import constants as rpc_constants

__all__ = [
    "backend_registered",
    "register_backend",
    "construct_rpc_backend_options",
    "init_backend",
    "BackendValue",
    "BackendType",
]

BackendValue = collections.namedtuple(
    "BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"]
)


def _backend_type_repr(self):
    return "BackendType." + self.name
- _backend_type_doc = """
- An enum class of available backends.
- PyTorch ships with a builtin ``BackendType.TENSORPIPE`` backend.
- Additional ones can be registered using the
- :func:`~torch.distributed.rpc.backend_registry.register_backend` function.
- """

# Create an enum type, `BackendType`, with empty members.
# Can't handle Function Enum API (mypy bug #9079)
BackendType = enum.Enum(value="BackendType", names=dict())  # type: ignore[misc]
# Unable to assign a function to a method (mypy bug #2427)
BackendType.__repr__ = _backend_type_repr  # type: ignore[assignment]

if BackendType.__doc__:
    BackendType.__doc__ = _backend_type_doc
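
# Illustrative (not executed) usage sketch: once `register_backend` below has
# populated the enum, members can be looked up by name, e.g.
#
#   import torch.distributed.rpc as rpc
#   backend = rpc.backend_registry.BackendType["TENSORPIPE"]
#   print(repr(backend))  # -> BackendType.TENSORPIPE
#
# The exact member set depends on which backends have been registered at runtime.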


def backend_registered(backend_name):
    """
    Checks if backend_name is registered as an RPC backend.

    Args:
        backend_name (str): string to identify the RPC backend.
    Returns:
        True if the backend has been registered with ``register_backend``, else
        False.
    """
    return backend_name in BackendType.__members__.keys()
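
# Illustrative sketch: the module-level `register_backend("TENSORPIPE", ...)` call at
# the bottom of this file runs on import, so after importing torch.distributed.rpc:
#
#   import torch.distributed.rpc as rpc
#   rpc.backend_registry.backend_registered("TENSORPIPE")   # True
#   rpc.backend_registry.backend_registered("NONEXISTENT")  # False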


def register_backend(
    backend_name, construct_rpc_backend_options_handler, init_backend_handler
):
    """Registers a new RPC backend.

    Args:
        backend_name (str): backend string to identify the handler.
        construct_rpc_backend_options_handler (function):
            Handler that is invoked when
            rpc_backend.construct_rpc_backend_options(**dict) is called.
        init_backend_handler (function): Handler that is invoked when the
            `_init_rpc_backend()` function is called with a backend.
            This returns the agent.
    """
    global BackendType
    if backend_registered(backend_name):
        raise RuntimeError("RPC backend {}: already registered".format(backend_name))
    # Create a new enum type, `BackendType`, with extended members.
    existing_enum_dict = {member.name: member.value for member in BackendType}
    extended_enum_dict = dict(
        {
            backend_name: BackendValue(
                construct_rpc_backend_options_handler=construct_rpc_backend_options_handler,
                init_backend_handler=init_backend_handler,
            )
        },
        **existing_enum_dict
    )
    # Can't handle Function Enum API (mypy bug #9079)
    BackendType = enum.Enum(value="BackendType", names=extended_enum_dict)  # type: ignore[misc]
    # Unable to assign a function to a method (mypy bug #2427)
    BackendType.__repr__ = _backend_type_repr  # type: ignore[assignment]
    if BackendType.__doc__:
        BackendType.__doc__ = _backend_type_doc
    return BackendType[backend_name]
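
# Illustrative sketch of registering a custom backend (hypothetical names;
# `MyRpcBackendOptions` and `MyAgent` are stand-ins for whatever the backend provides).
# The two handlers mirror the TensorPipe handlers defined below:
#
#   def _my_construct_rpc_backend_options_handler(rpc_timeout, init_method, **kwargs):
#       return MyRpcBackendOptions(rpc_timeout=rpc_timeout, init_method=init_method)
#
#   def _my_init_backend_handler(store, name, rank, world_size, rpc_backend_options):
#       agent = MyAgent(store, name, rank, world_size, rpc_backend_options)
#       api._init_rpc_states(agent)
#       return agent
#
#   MY_BACKEND = register_backend(
#       "MY_BACKEND",
#       _my_construct_rpc_backend_options_handler,
#       _my_init_backend_handler,
#   )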


def construct_rpc_backend_options(
    backend,
    rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC,
    init_method=rpc_constants.DEFAULT_INIT_METHOD,
    **kwargs
):
    return backend.value.construct_rpc_backend_options_handler(
        rpc_timeout, init_method, **kwargs
    )
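
# Illustrative sketch: for the built-in TensorPipe backend this dispatches to
# `_tensorpipe_construct_rpc_backend_options_handler` below and returns a
# `TensorPipeRpcBackendOptions`, e.g.
#
#   opts = construct_rpc_backend_options(
#       BackendType.TENSORPIPE,
#       rpc_timeout=60,
#       init_method="tcp://localhost:29500",
#       num_worker_threads=8,
#   )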


def init_backend(backend, *args, **kwargs):
    return backend.value.init_backend_handler(*args, **kwargs)


def _init_process_group(store, rank, world_size):
    # Initialize ProcessGroup.
    process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT

    # We're using a bunch of private APIs here since `new_group` requires the
    # default group to be initialized.
    group = dist.ProcessGroupGloo(store, rank, world_size, process_group_timeout)

    assert group is not None, "Failed to initialize default ProcessGroup."

    if (rank != -1) and (rank != group.rank()):
        raise RuntimeError(
            "rank argument {} doesn't match pg rank {}".format(rank, group.rank())
        )
    if (world_size != -1) and (world_size != group.size()):
        raise RuntimeError(
            "world_size argument {} doesn't match pg size {}".format(
                world_size, group.size()
            )
        )
    return group
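
# Illustrative sketch, assuming a reachable TCPStore (host, port, and ranks are
# placeholders, not values used by this module):
#
#   store = dist.TCPStore("localhost", 29500, world_size=2, is_master=(rank == 0))
#   group = _init_process_group(store, rank=rank, world_size=2)
#   group.barrier().wait()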


def _tensorpipe_construct_rpc_backend_options_handler(
    rpc_timeout,
    init_method,
    num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS,
    _transports=None,
    _channels=None,
    **kwargs
):
    from . import TensorPipeRpcBackendOptions

    return TensorPipeRpcBackendOptions(
        rpc_timeout=rpc_timeout,
        init_method=init_method,
        num_worker_threads=num_worker_threads,
        _transports=_transports,
        _channels=_channels,
    )
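
# Illustrative sketch: user code typically never calls the handler above directly;
# it is reached through `torch.distributed.rpc.init_rpc`, e.g.
#
#   import torch.distributed.rpc as rpc
#   rpc.init_rpc(
#       "worker0",
#       rank=0,
#       world_size=2,
#       backend=rpc.BackendType.TENSORPIPE,
#       rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
#           init_method="tcp://localhost:29500",
#           num_worker_threads=16,
#       ),
#   )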


def _tensorpipe_validate_devices(devices, device_count):
    return all(
        d.type == "cpu" or (d.type == "cuda" and 0 <= d.index < device_count)
        for d in devices
    )
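
# Illustrative sketch: CPU devices always validate; CUDA devices must have an index
# below the local device count. For example, with device_count=2:
#
#   _tensorpipe_validate_devices([torch.device("cpu"), torch.device("cuda:1")], 2)  # True
#   _tensorpipe_validate_devices([torch.device("cuda:3")], 2)                       # False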


# Detect if any worker has invalid device_map configurations, and return
# reverse device maps.
def _tensorpipe_exchange_and_check_all_device_maps(
    my_name, my_device_count, my_device_maps, my_devices, group
):
    gathered: List[Tuple[
        str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]
    ]] = [("", 0, {}, []) for _ in range(group.size())]
    dist.all_gather_object(
        gathered, (my_name, my_device_count, my_device_maps, my_devices), group
    )
    all_names = [name for name, _, _, _ in gathered]
    all_device_counts = {name: count for name, count, _, _ in gathered}
    all_device_maps = {name: map_ for name, _, map_, _ in gathered}
    all_devices = {name: devices for name, _, _, devices in gathered}

    _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices)

    # Passed all checks; construct reverse mapping and get list of devices handled by this agent.
    reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)
    my_devices = _create_device_list(my_devices, my_device_maps, reverse_device_maps)
    return reverse_device_maps, my_devices
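
# Illustrative sketch of the exchanged data: `my_device_maps` is keyed by target
# worker name and maps local source devices to the target's devices, e.g.
#
#   my_device_maps = {
#       "worker1": {torch.device("cuda:0"): torch.device("cuda:1")},
#   }
#
# `_create_reverse_mapping` below inverts these maps from the perspective of the
# local worker so that responses can be routed back to the right source devices.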


def _validate_device_maps(
    all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True
):
    for node in all_names:
        devices = all_devices[node]
        if len(set(devices)) != len(devices):
            raise ValueError(
                f"Node {node} has duplicated devices\n"
                f"devices = {devices}"
            )
        if not _tensorpipe_validate_devices(devices, all_device_counts[node]):
            raise ValueError(
                f"Node {node} has devices with invalid indices\n"
                f"devices = {devices}\n"
                f"device count = {all_device_counts[node]}"
            )

    for source_node in all_names:
        # For a dynamic (non-static) group do not check the target node name since it may not have joined yet.
        if is_static_group and not set(all_device_maps[source_node].keys()).issubset(
            all_names
        ):
            raise ValueError(
                f"Node {source_node} has invalid target node names in its device maps\n"
                f"device maps = {all_device_maps[source_node].keys()}\n"
                f"node names = {all_names}"
            )
        for target_node, map_ in all_device_maps[source_node].items():
            if len(set(map_.values())) != len(map_):
                raise ValueError(
                    f"Node {source_node} has duplicated target devices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}"
                )
            if all_devices[source_node]:
                if not set(map_.keys()).issubset(all_devices[source_node]):
                    raise ValueError(
                        f"Node {source_node} has unexpected source devices "
                        f"in its device map for {target_node}\n"
                        f"device map = {map_}\n"
                        f"devices = {all_devices[source_node]}"
                    )
            elif not _tensorpipe_validate_devices(
                map_.keys(), all_device_counts[source_node]
            ):
                raise ValueError(
                    f"Node {source_node} has source devices with invalid indices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}\n"
                    f"device count = {all_device_counts[source_node]}"
                )
            if all_devices.get(target_node, []):
                if not set(map_.values()).issubset(all_devices[target_node]):
                    raise ValueError(
                        f"Node {source_node} has unexpected target devices "
                        f"in its device map for {target_node}\n"
                        f"device map = {map_}\n"
                        f"devices = {all_devices[target_node]}"
                    )
            elif target_node in all_device_counts and not _tensorpipe_validate_devices(
                map_.values(), all_device_counts[target_node]
            ):
                raise ValueError(
                    f"Node {source_node} has target devices with invalid indices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}\n"
                    f"device count = {all_device_counts[target_node]}"
                )


def _create_device_list(my_devices, my_device_maps, reverse_device_maps):
    if not my_devices:
        devices_set: Set[torch.device] = set()
        for _, map_ in my_device_maps.items():
            devices_set.update(map_.keys())
        for _, map_ in reverse_device_maps.items():
            devices_set.update(map_.keys())
        devices_set.discard(torch.device("cpu"))
        my_devices = list(devices_set)
    my_devices = sorted(my_devices, key=lambda d: d.index)
    return my_devices
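
# Illustrative sketch: with no explicit device list, the devices are inferred from the
# forward and reverse maps (CPU excluded) and sorted by index, e.g.
#
#   my_device_maps = {"worker1": {torch.device("cuda:0"): torch.device("cuda:1")}}
#   reverse_device_maps = {"worker1": {torch.device("cuda:1"): torch.device("cuda:0")}}
#   _create_device_list([], my_device_maps, reverse_device_maps)
#   # -> [torch.device("cuda:0"), torch.device("cuda:1")]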


def _create_reverse_mapping(my_name, all_names, all_device_maps):
    reverse_device_maps: Dict[str, Dict[torch.device, torch.device]] = {}
    for node in all_names:
        if my_name in all_device_maps[node]:
            reverse_device_maps[node] = {
                v: k for k, v in all_device_maps[node][my_name].items()
            }
    return reverse_device_maps
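
# Illustrative sketch: if "worker1" declared {cuda:0 -> cuda:1} towards "worker0",
# then on "worker0" the reverse map for "worker1" is {cuda:1 -> cuda:0}:
#
#   all_device_maps = {
#       "worker1": {"worker0": {torch.device("cuda:0"): torch.device("cuda:1")}},
#   }
#   _create_reverse_mapping("worker0", ["worker1"], all_device_maps)
#   # -> {"worker1": {torch.device("cuda:1"): torch.device("cuda:0")}}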


def _get_device_infos():
    from . import TensorPipeAgent

    agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
    opts = agent._get_backend_options()
    device_count = torch.cuda.device_count()
    if torch.cuda.is_available() and opts.devices:
        torch.cuda.init()
    return device_count, opts.device_maps, opts.devices


def _set_devices_and_reverse_device_map(agent):
    from . import TensorPipeAgent

    agent = cast(TensorPipeAgent, agent)
    # Group state is retrieved from the local agent.
    # On initialization, the TensorPipe agent retrieves information from all existing
    # workers, so the group state is valid.
    my_worker_info = agent.get_worker_info()
    my_name = my_worker_info.name
    all_worker_infos = agent.get_worker_infos()

    # One round to get the device_maps of all workers and construct reverse device maps.
    all_device_counts, all_device_maps, all_devices, all_names = {}, {}, {}, []
    for worker_info in all_worker_infos:
        worker_name = worker_info.name
        if worker_name != my_name:
            # TODO: make async?
            device_count, device_map, devices = api.rpc_sync(
                worker_name, _get_device_infos
            )
        else:
            opts = agent._get_backend_options()
            device_count, device_map, devices = (
                torch.cuda.device_count(),
                opts.device_maps,
                opts.devices,
            )
        all_device_counts[worker_name] = device_count
        all_device_maps[worker_name] = device_map
        all_devices[worker_name] = devices
        all_names.append(worker_name)

    _validate_device_maps(
        all_names,
        all_device_counts,
        all_device_maps,
        all_devices,
        is_static_group=False,
    )
    reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)

    # Perform an RPC call to all workers, including itself, to include the newly joined
    # worker information and device maps.
    for worker_name in all_names:
        # Set the device list for each worker.
        all_devices[worker_name] = _create_device_list(
            all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps
        )
        api.rpc_sync(
            worker_name,
            _update_group_membership,
            args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True),
        )


def _tensorpipe_init_backend_handler(
    store, name, rank, world_size, rpc_backend_options
):
    from . import TensorPipeAgent
    from . import TensorPipeRpcBackendOptions

    if not isinstance(store, dist.Store):
        raise TypeError("`store` must be a c10d::Store. {}".format(store))

    if not isinstance(rpc_backend_options, TensorPipeRpcBackendOptions):
        raise TypeError(
            "`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {}".format(
                rpc_backend_options
            )
        )

    device_count = torch.cuda.device_count()

    is_static_group = True if world_size else False
    # world_size is specified, so this is a static group (ranks cannot join and leave).
    if is_static_group:
        # The agent's join method is required to behave like a barrier and perform
        # collective operations, for which it relies on a process group, instead of
        # re-implementing this on top of RPCs.
        group = _init_process_group(store, rank, world_size)

        reverse_device_maps, devices = _tensorpipe_exchange_and_check_all_device_maps(
            name,
            device_count,
            rpc_backend_options.device_maps,
            rpc_backend_options.devices,
            group,
        )

        if torch.cuda.is_available() and devices:
            # It's necessary to initialize PyTorch CUDA states here (e.g.,
            # CUDACachingAllocator). If this is missing, we could hit errors like
            # "allocator not initialized", because other processes might send
            # CUDA-related RPC requests to this process before user code in this
            # process initializes its PyTorch CUDA states.
            torch.cuda.init()

        # TODO: add try-except and destroy _agent in all processes if any fails.
        agent = TensorPipeAgent(
            store,
            name,
            rank,
            world_size,
            rpc_backend_options,
            reverse_device_maps,
            devices,
        )

        api._init_rpc_states(agent)

        # Run one dummy round of RPC to initialize channels/transports. Without
        # this, it's easy to hit timeout in rpc.shutdown() if there is no other RPC
        # on that process before rpc.shutdown(), as the agent initialization can
        # take longer than 5s.
        api._all_gather(None, timeout=rpc_backend_options.rpc_timeout)
        # Need a barrier here to make sure no peers leave before rank 0 finishes
        # _all_gather.
        group.barrier().wait()

        return agent
    # Initialization for dynamic RPC (ranks can join and leave).
    else:
        with _group_membership_management(store, name, True):
            # Construct the TensorPipeAgent with an empty reverse_device_map and devices;
            # these properties will be updated after initialization.
            agent = TensorPipeAgent(
                store,
                name,
                rank,
                world_size,
                rpc_backend_options,
                {},
                [],
            )
            api._init_rpc_states(agent)

            try:
                # Notify all workers in the group that this rank has joined and set
                # devices and reverse_device_map. This is a synchronous operation that
                # completes once all existing ranks are updated.
                _set_devices_and_reverse_device_map(agent)
            except Exception:
                api.shutdown()
                raise

            return agent


register_backend(
    "TENSORPIPE",
    _tensorpipe_construct_rpc_backend_options_handler,
    _tensorpipe_init_backend_handler,
)