123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264 |
- # Copyright (c) Facebook, Inc. and its affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the BSD-style license found in the
- # LICENSE file in the root directory of this source tree.
- import binascii
- import logging
- import os
- import tempfile
- from base64 import b64decode, b64encode
- from datetime import timedelta
- from typing import Any, Optional, Tuple, cast
- from torch.distributed import FileStore, Store, TCPStore
- from torch.distributed.elastic.events import (
- NodeState,
- construct_and_record_rdzv_event,
- )
- from .api import (
- RendezvousConnectionError,
- RendezvousError,
- RendezvousParameters,
- RendezvousStateError,
- )
- from .dynamic_rendezvous import RendezvousBackend, Token
- from .utils import _matches_machine_hostname, parse_rendezvous_endpoint
# Logger for this module, named after the module itself.
log = logging.getLogger(name=__name__)
class C10dRendezvousBackend(RendezvousBackend):
    """Represents a C10d-backed rendezvous backend.

    The rendezvous state is stored base64-encoded under a single key
    (``"torch.rendezvous." + run_id``) of a C10d store. The encoded value read
    back from the store also serves as the token for later ``compare_set``
    updates.

    Args:
        store:
            The :py:class:`torch.distributed.Store` instance to use to
            communicate with the C10d store.
        run_id:
            The run id of the rendezvous.
    """

    # Placeholder kept under the key while no real state exists; see the
    # explanation in ``__init__``.
    _NULL_SENTINEL = "Y2FuaW1hZGFt"

    _store: Store
    _key: str

    def __init__(self, store: Store, run_id: str) -> None:
        if not run_id:
            raise ValueError("The run id must be a non-empty string.")

        self._store = store
        self._key = "torch.rendezvous." + run_id

        # Reading a key from a store blocks until that key exists, which makes
        # a store awkward to use as an ordinary key-value map. As a workaround
        # the key is seeded with a sentinel (an empty expected value lets
        # compare_set write only when the key is not set yet), and the
        # sentinel is translated back into ``None`` whenever it is read.
        self._call_store("compare_set", self._key, "", self._NULL_SENTINEL)

    @property
    def name(self) -> str:
        """See base class."""
        return "c10d"

    def get_state(self) -> Optional[Tuple[bytes, Token]]:
        """See base class."""
        return self._decode_state(self._call_store("get", self._key))

    def set_state(
        self, state: bytes, token: Optional[Token] = None
    ) -> Optional[Tuple[bytes, Token, bool]]:
        """See base class."""
        # Encode first so a non-bytes ``state`` fails before any store access.
        encoded_state: str = b64encode(state).decode()

        if not token:
            # No token: only succeed if nothing has been written yet.
            expected: str = self._NULL_SENTINEL
        elif isinstance(token, bytes):
            expected = token.decode()
        else:
            # Tokens issued by this backend are the raw bytes read from the
            # store, so any other token is certainly stale; report the current
            # state together with a failed-write flag without attempting a write.
            current = self.get_state()
            return None if current is None else (*current, False)

        reply: bytes = self._call_store("compare_set", self._key, expected, encoded_state)

        decoded = self._decode_state(reply)
        if decoded is None:
            return None

        remote_state, remote_token = decoded

        # compare_set does not report whether our write won, so compare the
        # state now held by the store byte-for-byte with the one we submitted.
        return remote_state, remote_token, remote_state == state

    def _call_store(self, store_op: str, *args, **kwargs) -> Any:
        """Invokes the store method named ``store_op``, translating store
        failures into :py:class:`RendezvousConnectionError`."""
        try:
            operation = getattr(self._store, store_op)
            return operation(*args, **kwargs)
        except (ValueError, RuntimeError, TimeoutError) as exc:
            raise RendezvousConnectionError(
                "The connection to the C10d store has failed. See inner exception for details."
            ) from exc

    def _decode_state(self, base64_state: bytes) -> Optional[Tuple[bytes, Token]]:
        """Decodes a raw store value into a ``(state, token)`` pair, or returns
        ``None`` if the value is the "no state yet" sentinel."""
        if base64_state == self._NULL_SENTINEL.encode():
            return None

        try:
            decoded_state = b64decode(base64_state)
        except binascii.Error as exc:
            raise RendezvousStateError(
                "The state object is corrupt. See inner exception for details."
            ) from exc

        # The undecoded value is handed out as the token for compare_set.
        return decoded_state, base64_state
def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
    """Creates the :py:class:`torch.distributed.TCPStore` used by the backend.

    Args:
        params:
            The rendezvous parameters. ``endpoint`` gives the store address
            (the port defaults to 29400); the optional ``is_host`` and
            ``read_timeout`` entries are described in :py:func:`create_backend`.

    Returns:
        The TCP store instance.

    Raises:
        ValueError:
            If ``read_timeout`` is not a positive integer.
        RendezvousConnectionError:
            If the store could not be created or connected to.
    """
    host, port = parse_rendezvous_endpoint(params.endpoint, default_port=29400)

    cfg_is_host = params.get_as_bool("is_host")
    # If the user has explicitly specified whether our process should host the
    # store, respect it.
    if cfg_is_host is not None:
        is_host = cfg_is_host
    # Otherwise try to determine whether we are the host based on our hostname
    # and IP address.
    else:
        is_host = _matches_machine_hostname(host)

    # The read timeout, in seconds, applied to store operations.
    read_timeout = cast(int, params.get_as_int("read_timeout", 60))
    if read_timeout <= 0:
        raise ValueError("The read timeout must be a positive integer.")
    timeout = timedelta(seconds=read_timeout)

    # In specific cases we attempt to instantiate the store twice. For details
    # see the explanation in the except clause below.
    for is_server in [is_host, False]:
        # Only the store construction is guarded: an error raised while
        # announcing an already-created store must not be mistaken for a
        # failed attempt and trigger a spurious client-mode retry.
        try:
            store = TCPStore(host, port, is_master=is_server, timeout=timeout)
        except (ValueError, RuntimeError, TimeoutError) as exc:
            # If we heuristically inferred the value of is_host as True and our
            # first attempt to instantiate the TCP store has failed, try it one
            # more time with is_host set to False. As an edge case there can be
            # more than one process that is part of the same rendezvous on this
            # machine and only one of them will eventually host the store.
            if not is_server or cfg_is_host is not None:
                raise RendezvousConnectionError(
                    "The connection to the C10d store has failed. See inner exception for details."
                ) from exc
        else:
            break

    # The loop either breaks after a successful attempt or raises on its final
    # (is_server=False) attempt, so ``store`` and ``is_server`` are bound here.
    if is_server:
        msg = f"Process {os.getpid()} hosts the TCP store for the C10d rendezvous backend."
        construct_and_record_rdzv_event(
            run_id=params.run_id, message=msg, node_state=NodeState.INIT
        )
        log.info(msg)

    return store
def _create_file_store(params: RendezvousParameters) -> FileStore:
    """Creates the :py:class:`torch.distributed.FileStore` used by the backend.

    If ``params.endpoint`` is set it is used as the path of the backing file;
    otherwise a new temporary file is created for the store.

    Raises:
        RendezvousError:
            If the temporary file could not be created.
        RendezvousConnectionError:
            If the file store could not be opened.
    """
    # If a user specifies an endpoint, we treat it as a path to a file.
    if params.endpoint:
        path = params.endpoint
    else:
        try:
            # The temporary file is readable and writable only by the user of
            # this process. mkstemp() also returns an open OS-level file
            # descriptor; FileStore is handed the path, not the descriptor, so
            # close it right away instead of leaking it for the process lifetime.
            fd, path = tempfile.mkstemp()
            os.close(fd)
        except OSError as exc:
            raise RendezvousError(
                "The file creation for C10d store has failed. See inner exception for details."
            ) from exc

    try:
        store = FileStore(path)
    except (ValueError, RuntimeError) as exc:
        raise RendezvousConnectionError(
            "The connection to the C10d store has failed. See inner exception for details."
        ) from exc

    return store
def create_backend(params: RendezvousParameters) -> Tuple[C10dRendezvousBackend, Store]:
    """Builds a :py:class:`C10dRendezvousBackend` and its underlying C10d store
    from the given rendezvous parameters.

    Returns a ``(backend, store)`` pair. Any failure is recorded as a
    ``NodeState.FAILED`` rendezvous event and then re-raised unchanged.

    +--------------+-----------------------------------------------------------+
    | Parameter    | Description                                               |
    +==============+===========================================================+
    | store_type   | The type of the C10d store. The currently supported types |
    |              | are "tcp" and "file" which correspond to                  |
    |              | :py:class:`torch.distributed.TCPStore` and                |
    |              | :py:class:`torch.distributed.FileStore`, respectively.    |
    |              | Defaults to "tcp".                                        |
    +--------------+-----------------------------------------------------------+
    | read_timeout | The read timeout, in seconds, for store operations.       |
    |              | Defaults to 60 seconds.                                   |
    |              |                                                           |
    |              | Note this only applies to                                 |
    |              | :py:class:`torch.distributed.TCPStore`. It is not relevant|
    |              | to :py:class:`torch.distributed.FileStore` which does not |
    |              | take in timeout as a parameter.                           |
    +--------------+-----------------------------------------------------------+
    | is_host      | A boolean value indicating whether this backend instance  |
    |              | will host the C10d store. If not specified it will be     |
    |              | inferred heuristically by matching the hostname or the IP |
    |              | address of this machine against the specified rendezvous  |
    |              | endpoint. Defaults to ``None``.                           |
    |              |                                                           |
    |              | Note that this configuration option only applies to       |
    |              | :py:class:`torch.distributed.TCPStore`. In normal         |
    |              | circumstances you can safely skip it; the only time when  |
    |              | it is needed is if its value cannot be correctly          |
    |              | determined (e.g. the rendezvous endpoint has a CNAME as   |
    |              | the hostname or does not match the FQDN of the machine).  |
    +--------------+-----------------------------------------------------------+
    """
    # Only TCPStore and FileStore are supported for now; other store types
    # still lack functionality the backend depends on (e.g. compare_set).
    store_factories = {
        "file": _create_file_store,
        "tcp": _create_tcp_store,
    }

    store_type = params.get("store_type", "tcp").strip().lower()

    store: Store
    try:
        factory = store_factories.get(store_type)
        if factory is None:
            raise ValueError("Invalid store type given. Currently only supports file and tcp.")

        store = factory(params)
        backend = C10dRendezvousBackend(store, params.run_id)
    except Exception as err:
        # Record the failure for observability, then propagate it untouched.
        construct_and_record_rdzv_event(
            message=f"{type(err).__name__}: {str(err)}",
            run_id=params.run_id,
            node_state=NodeState.FAILED,
        )
        raise

    return backend, store
|