api.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from abc import ABC, abstractmethod
  7. from typing import Any, Callable, Dict, Optional, Tuple
  8. from torch.distributed import Store
  9. class RendezvousError(Exception):
  10. """Represents the base type for rendezvous errors."""
  11. class RendezvousClosedError(RendezvousError):
  12. """Raised when a rendezvous is closed."""
  13. class RendezvousTimeoutError(RendezvousError):
  14. """Raised when a rendezvous did not complete on time."""
  15. class RendezvousConnectionError(RendezvousError):
  16. """Raised when the connection to a rendezvous backend has failed."""
  17. class RendezvousStateError(RendezvousError):
  18. """Raised when the state of a rendezvous is corrupt."""
  19. class RendezvousHandler(ABC):
  20. """Main rendezvous interface.
  21. Note:
  22. Distributed Torch users normally **do not** need to implement their own
  23. ``RendezvousHandler``. An implementation based on C10d Store is already
  24. provided, and is recommended for most users.
  25. """
  26. @abstractmethod
  27. def get_backend(self) -> str:
  28. """Returns the name of the rendezvous backend."""
  29. @abstractmethod
  30. def next_rendezvous(
  31. self,
  32. ) -> Tuple[Store, int, int]:
  33. """Main entry-point into the rendezvous barrier.
  34. Blocks until the rendezvous is complete and the current process is
  35. included in the formed worker group, or a timeout occurs, or the
  36. rendezvous was marked closed.
  37. Returns:
  38. A tuple of :py:class:`torch.distributed.Store`, ``rank``, and
  39. ``world size``.
  40. Raises:
  41. RendezvousClosedError:
  42. The rendezvous is closed.
  43. RendezvousConnectionError:
  44. The connection to the rendezvous backend has failed.
  45. RendezvousStateError:
  46. The rendezvous state is corrupt.
  47. RendezvousTimeoutError:
  48. The rendezvous did not complete on time.
  49. """
  50. @abstractmethod
  51. def is_closed(self) -> bool:
  52. """Checks whether the rendezvous has been closed.
  53. A closed rendezvous means all future attempts to re-rendezvous within
  54. same job will fail.
  55. ``is_closed()`` and :py:meth:`set_closed` have semantics of eventual
  56. propagation and should not be used for synchronization. The intention is
  57. that if at least one node decides the job is finished, it will close the
  58. rendezvous, and other nodes will soon observe this and stop running as
  59. well.
  60. """
  61. @abstractmethod
  62. def set_closed(self):
  63. """Marks the rendezvous as closed."""
  64. @abstractmethod
  65. def num_nodes_waiting(self) -> int:
  66. """Returns the number of nodes who arrived late at the rendezvous
  67. barrier, hence were not included in the current worker group.
  68. Callers should periodically call this method to check whether new
  69. nodes are waiting to join the job and if so admit them by calling
  70. :py:meth:`next_rendezvous()` (re-rendezvous).
  71. """
  72. @abstractmethod
  73. def get_run_id(self) -> str:
  74. """Returns the run id of the rendezvous.
  75. The run id is a user-defined id that uniquely identifies an instance of
  76. a distributed application. It typically maps to a job id and is used to
  77. allow nodes to join the correct distributed application.
  78. """
  79. def shutdown(self) -> bool:
  80. """Closes all resources that were open for the rendezvous.
  81. Example::
  82. rdzv_handler = ...
  83. try:
  84. store, rank, world_size = rdzv_handler.next_rendezvous()
  85. finally:
  86. rdzv_handler.shutdown()
  87. """
  88. class RendezvousParameters:
  89. """Holds the parameters to construct a :py:class:`RendezvousHandler`.
  90. Args:
  91. backend:
  92. The name of the backend to use to handle the rendezvous.
  93. endpoint:
  94. The endpoint of the rendezvous, usually in form <hostname>[:<port>].
  95. run_id:
  96. The id of the rendezvous.
  97. min_nodes:
  98. The minimum number of nodes to admit to the rendezvous.
  99. max_nodes:
  100. The maximum number of nodes to admit to the rendezvous.
  101. local_addr:
  102. The address of the local node.
  103. **kwargs:
  104. Additional parameters for the specified backend.
  105. """
  106. def __init__(
  107. self,
  108. backend: str,
  109. endpoint: str,
  110. run_id: str,
  111. min_nodes: int,
  112. max_nodes: int,
  113. local_addr: Optional[str] = None,
  114. **kwargs,
  115. ):
  116. if not backend:
  117. raise ValueError("The rendezvous backend name must be a non-empty string.")
  118. if min_nodes < 1:
  119. raise ValueError(
  120. f"The minimum number of rendezvous nodes ({min_nodes}) must be greater than zero."
  121. )
  122. if max_nodes < min_nodes:
  123. raise ValueError(
  124. f"The maximum number of rendezvous nodes ({max_nodes}) must be greater than or "
  125. f"equal to the minimum number of rendezvous nodes ({min_nodes})."
  126. )
  127. self.backend = backend
  128. self.endpoint = endpoint
  129. self.run_id = run_id
  130. self.min_nodes = min_nodes
  131. self.max_nodes = max_nodes
  132. self.config = kwargs
  133. self.local_addr = local_addr
  134. def get(self, key: str, default: Any = None) -> Any:
  135. """Returns the value for ``key`` if ``key`` exists, else ``default``."""
  136. return self.config.get(key, default)
  137. def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool]:
  138. """Returns the value for ``key`` as a ``bool``."""
  139. value = self.get(key, default)
  140. if value is None or isinstance(value, bool):
  141. return value
  142. if isinstance(value, int):
  143. if value == 1:
  144. return True
  145. if value == 0:
  146. return False
  147. elif isinstance(value, str):
  148. if value.lower() in ["1", "true", "t", "yes", "y"]:
  149. return True
  150. if value.lower() in ["0", "false", "f", "no", "n"]:
  151. return False
  152. raise ValueError(
  153. f"The rendezvous configuration option '{key}' does not represent a valid boolean value."
  154. )
  155. def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]:
  156. """Returns the value for ``key`` as an ``int``."""
  157. value = self.get(key, default)
  158. if value is None:
  159. return value
  160. try:
  161. return int(value)
  162. except ValueError as e:
  163. raise ValueError(
  164. f"The rendezvous configuration option '{key}' does not represent a valid integer "
  165. "value."
  166. ) from e
  167. RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]
  168. class RendezvousHandlerRegistry:
  169. """Represents a registry of :py:class:`RendezvousHandler` backends."""
  170. _registry: Dict[str, RendezvousHandlerCreator]
  171. def __init__(self) -> None:
  172. self._registry = {}
  173. def register(self, backend: str, creator: RendezvousHandlerCreator) -> None:
  174. """Registers a new rendezvous backend.
  175. Args:
  176. backend:
  177. The name of the backend.
  178. creator:
  179. The callback to invoke to construct the
  180. :py:class:`RendezvousHandler`.
  181. """
  182. if not backend:
  183. raise ValueError("The rendezvous backend name must be a non-empty string.")
  184. current_creator: Optional[RendezvousHandlerCreator]
  185. try:
  186. current_creator = self._registry[backend]
  187. except KeyError:
  188. current_creator = None
  189. if current_creator is not None and current_creator != creator:
  190. raise ValueError(
  191. f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it "
  192. f"is already registered with '{current_creator}'."
  193. )
  194. self._registry[backend] = creator
  195. def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
  196. """Creates a new :py:class:`RendezvousHandler`."""
  197. try:
  198. creator = self._registry[params.backend]
  199. except KeyError as e:
  200. raise ValueError(
  201. f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
  202. f"to call `{self.register.__name__}`?"
  203. ) from e
  204. handler = creator(params)
  205. # Do some sanity check.
  206. if handler.get_backend() != params.backend:
  207. raise RuntimeError(
  208. f"The rendezvous backend '{handler.get_backend()}' does not match the requested "
  209. f"backend '{params.backend}'."
  210. )
  211. return handler
  212. # The default global registry instance used by launcher scripts to instantiate
  213. # rendezvous handlers.
  214. rendezvous_handler_registry = RendezvousHandlerRegistry()