c10d_rendezvous_backend.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import binascii
  7. import logging
  8. import os
  9. import tempfile
  10. from base64 import b64decode, b64encode
  11. from datetime import timedelta
  12. from typing import Any, Optional, Tuple, cast
  13. from torch.distributed import FileStore, Store, TCPStore
  14. from torch.distributed.elastic.events import (
  15. NodeState,
  16. construct_and_record_rdzv_event,
  17. )
  18. from .api import (
  19. RendezvousConnectionError,
  20. RendezvousError,
  21. RendezvousParameters,
  22. RendezvousStateError,
  23. )
  24. from .dynamic_rendezvous import RendezvousBackend, Token
  25. from .utils import _matches_machine_hostname, parse_rendezvous_endpoint
  26. log = logging.getLogger(__name__)
  27. class C10dRendezvousBackend(RendezvousBackend):
  28. """Represents a C10d-backed rendezvous backend.
  29. Args:
  30. store:
  31. The :py:class:`torch.distributed.Store` instance to use to
  32. communicate with the C10d store.
  33. run_id:
  34. The run id of the rendezvous.
  35. """
  36. # See the explanation in the __init__ method.
  37. _NULL_SENTINEL = "Y2FuaW1hZGFt"
  38. _store: Store
  39. _key: str
  40. def __init__(self, store: Store, run_id: str) -> None:
  41. if not run_id:
  42. raise ValueError("The run id must be a non-empty string.")
  43. self._store = store
  44. self._key = "torch.rendezvous." + run_id
  45. # The read operation of a store blocks the caller until the specified
  46. # key becomes available. This behavior makes it tricky to use a store
  47. # as a regular key-value dictionary.
  48. #
  49. # As a workaround we initially set a sentinel value as the rendezvous
  50. # state. Whenever this value gets returned we treat it as a None.
  51. self._call_store("compare_set", self._key, "", self._NULL_SENTINEL)
  52. @property
  53. def name(self) -> str:
  54. """See base class."""
  55. return "c10d"
  56. def get_state(self) -> Optional[Tuple[bytes, Token]]:
  57. """See base class."""
  58. base64_state: bytes = self._call_store("get", self._key)
  59. return self._decode_state(base64_state)
  60. def set_state(
  61. self, state: bytes, token: Optional[Token] = None
  62. ) -> Optional[Tuple[bytes, Token, bool]]:
  63. """See base class."""
  64. base64_state_str: str = b64encode(state).decode()
  65. if token:
  66. # Shortcut if we know for sure that the token is not valid.
  67. if not isinstance(token, bytes):
  68. result = self.get_state()
  69. if result is not None:
  70. tmp = *result, False
  71. # Python 3.6 does not support tuple unpacking in return
  72. # statements.
  73. return tmp
  74. return None
  75. token = token.decode()
  76. else:
  77. token = self._NULL_SENTINEL
  78. base64_state: bytes = self._call_store("compare_set", self._key, token, base64_state_str)
  79. state_token_pair = self._decode_state(base64_state)
  80. if state_token_pair is None:
  81. return None
  82. new_state, new_token = state_token_pair
  83. # C10d Store's compare_set method does not offer an easy way to find out
  84. # whether our write attempt was successful. As a brute-force solution we
  85. # perform a bitwise comparison of our local state and the remote state.
  86. return new_state, new_token, new_state == state
  87. def _call_store(self, store_op: str, *args, **kwargs) -> Any:
  88. try:
  89. return getattr(self._store, store_op)(*args, **kwargs)
  90. except (ValueError, RuntimeError, TimeoutError) as exc:
  91. raise RendezvousConnectionError(
  92. "The connection to the C10d store has failed. See inner exception for details."
  93. ) from exc
  94. def _decode_state(self, base64_state: bytes) -> Optional[Tuple[bytes, Token]]:
  95. if base64_state == self._NULL_SENTINEL.encode():
  96. return None
  97. try:
  98. state = b64decode(base64_state)
  99. except binascii.Error as exc:
  100. raise RendezvousStateError(
  101. "The state object is corrupt. See inner exception for details."
  102. ) from exc
  103. return state, base64_state
  104. def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
  105. host, port = parse_rendezvous_endpoint(params.endpoint, default_port=29400)
  106. cfg_is_host = params.get_as_bool("is_host")
  107. # If the user has explicitly specified whether our process should host the
  108. # the store, respect it.
  109. if cfg_is_host is not None:
  110. is_host = cfg_is_host
  111. # Otherwise try to determine whether we are the host based on our hostname
  112. # and IP address.
  113. else:
  114. is_host = _matches_machine_hostname(host)
  115. # The timeout
  116. read_timeout = cast(int, params.get_as_int("read_timeout", 60))
  117. if read_timeout <= 0:
  118. raise ValueError("The read timeout must be a positive integer.")
  119. # In specific cases we attempt to instantiate the store twice. For details
  120. # see the explanation in the except clause below.
  121. for is_server in [is_host, False]:
  122. try:
  123. store = TCPStore(
  124. host, port, is_master=is_server, timeout=timedelta(seconds=read_timeout)
  125. )
  126. if is_server:
  127. msg = f"Process {os.getpid()} hosts the TCP store for the C10d rendezvous backend."
  128. construct_and_record_rdzv_event(
  129. run_id=params.run_id, message=msg, node_state=NodeState.INIT
  130. )
  131. log.info(msg)
  132. break
  133. except (ValueError, RuntimeError, TimeoutError) as exc:
  134. # If we heuristically inferred the value of is_host as True and our
  135. # first attempt to instantiate the TCP store has failed, try it one
  136. # more time with is_host set to False. As an edge case there can be
  137. # more than one process that is part of the same rendezvous on this
  138. # machine and only one of them will eventually host the store.
  139. if not is_server or cfg_is_host is not None:
  140. raise RendezvousConnectionError(
  141. "The connection to the C10d store has failed. See inner exception for details."
  142. ) from exc
  143. return store
  144. def _create_file_store(params: RendezvousParameters) -> FileStore:
  145. # If a user specifies an endpoint, we treat it as a path to a file.
  146. if params.endpoint:
  147. path = params.endpoint
  148. else:
  149. try:
  150. # The temporary file is readable and writable only by the user of
  151. # this process.
  152. _, path = tempfile.mkstemp()
  153. except OSError as exc:
  154. raise RendezvousError(
  155. "The file creation for C10d store has failed. See inner exception for details."
  156. ) from exc
  157. try:
  158. store = FileStore(path)
  159. except (ValueError, RuntimeError) as exc:
  160. raise RendezvousConnectionError(
  161. "The connection to the C10d store has failed. See inner exception for details."
  162. ) from exc
  163. return store
  164. def create_backend(params: RendezvousParameters) -> Tuple[C10dRendezvousBackend, Store]:
  165. """Creates a new :py:class:`C10dRendezvousBackend` from the specified
  166. parameters.
  167. +--------------+-----------------------------------------------------------+
  168. | Parameter | Description |
  169. +==============+===========================================================+
  170. | store_type | The type of the C10d store. The currently supported types |
  171. | | are "tcp" and "file" which correspond to |
  172. | | :py:class:`torch.distributed.TCPStore` and |
  173. | | :py:class:`torch.distributed.FileStore`, respectively. |
  174. | | Defaults to "tcp". |
  175. +--------------+-----------------------------------------------------------+
  176. | read_timeout | The read timeout, in seconds, for store operations. |
  177. | | Defaults to 60 seconds. |
  178. | | |
  179. | | Note this only applies to |
  180. | | :py:class:`torch.distributed.TCPStore`. It is not relevant|
  181. | | to :py:class:`torch.distributed.FileStore` which does not |
  182. | | take in timeout as a parameter. |
  183. +--------------+-----------------------------------------------------------+
  184. | is_host | A boolean value indicating whether this backend instance |
  185. | | will host the C10d store. If not specified it will be |
  186. | | inferred heuristically by matching the hostname or the IP |
  187. | | address of this machine against the specified rendezvous |
  188. | | endpoint. Defaults to ``None``. |
  189. | | |
  190. | | Note that this configuration option only applies to |
  191. | | :py:class:`torch.distributed.TCPStore`. In normal |
  192. | | circumstances you can safely skip it; the only time when |
  193. | | it is needed is if its value cannot be correctly |
  194. | | determined (e.g. the rendezvous endpoint has a CNAME as |
  195. | | the hostname or does not match the FQDN of the machine). |
  196. +--------------+-----------------------------------------------------------+
  197. """
  198. # As of today we only support TCPStore and FileStore. Other store types do
  199. # not have the required functionality (e.g. compare_set) yet.
  200. store_type = params.get("store_type", "tcp").strip().lower()
  201. store: Store
  202. try:
  203. if store_type == "file":
  204. store = _create_file_store(params)
  205. elif store_type == "tcp":
  206. store = _create_tcp_store(params)
  207. else:
  208. raise ValueError("Invalid store type given. Currently only supports file and tcp.")
  209. backend = C10dRendezvousBackend(store, params.run_id)
  210. except Exception as e:
  211. construct_and_record_rdzv_event(
  212. message=f"{type(e).__name__}: {str(e)}",
  213. run_id=params.run_id,
  214. node_state=NodeState.FAILED,
  215. )
  216. raise
  217. return backend, store