etcd_rendezvous_backend.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import binascii
  7. from base64 import b64decode, b64encode
  8. from typing import Optional, Tuple, cast
  9. import urllib3.exceptions # type: ignore[import]
  10. from etcd import Client as EtcdClient # type: ignore[import]
  11. from etcd import (
  12. EtcdAlreadyExist,
  13. EtcdCompareFailed,
  14. EtcdException,
  15. EtcdKeyNotFound,
  16. EtcdResult,
  17. )
  18. from torch.distributed import Store
  19. from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError
  20. from .dynamic_rendezvous import RendezvousBackend, Token
  21. from .etcd_store import EtcdStore
  22. from .utils import parse_rendezvous_endpoint
  23. class EtcdRendezvousBackend(RendezvousBackend):
  24. """Represents an etcd-based rendezvous backend.
  25. Args:
  26. client:
  27. The ``etcd.Client`` instance to use to communicate with etcd.
  28. run_id:
  29. The run id of the rendezvous.
  30. key_prefix:
  31. The path under which to store the rendezvous state in etcd.
  32. ttl:
  33. The TTL of the rendezvous state. If not specified, defaults to two hours.
  34. """
  35. _DEFAULT_TTL = 7200 # 2 hours
  36. _client: EtcdClient
  37. _key: str
  38. _ttl: int
  39. def __init__(
  40. self,
  41. client: EtcdClient,
  42. run_id: str,
  43. key_prefix: Optional[str] = None,
  44. ttl: Optional[int] = None,
  45. ) -> None:
  46. if not run_id:
  47. raise ValueError("The run id must be a non-empty string.")
  48. self._client = client
  49. if key_prefix:
  50. self._key = key_prefix + "/" + run_id
  51. else:
  52. self._key = run_id
  53. if ttl and ttl > 0:
  54. self._ttl = ttl
  55. else:
  56. self._ttl = self._DEFAULT_TTL
  57. @property
  58. def name(self) -> str:
  59. """See base class."""
  60. return "etcd-v2"
  61. def get_state(self) -> Optional[Tuple[bytes, Token]]:
  62. """See base class."""
  63. try:
  64. result = self._client.read(self._key)
  65. except EtcdKeyNotFound:
  66. return None
  67. except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
  68. raise RendezvousConnectionError(
  69. "The connection to etcd has failed. See inner exception for details."
  70. ) from exc
  71. return self._decode_state(result)
  72. def set_state(
  73. self, state: bytes, token: Optional[Token] = None
  74. ) -> Optional[Tuple[bytes, Token, bool]]:
  75. """See base class."""
  76. base64_state = b64encode(state).decode()
  77. kwargs = {}
  78. def get_state():
  79. result = self.get_state()
  80. if result is not None:
  81. tmp = *result, False
  82. # Python 3.6 does not support tuple unpacking in return
  83. # statements.
  84. return tmp
  85. return None
  86. if token:
  87. try:
  88. token = int(token)
  89. except ValueError:
  90. return get_state()
  91. if token:
  92. kwargs["prevIndex"] = token
  93. else:
  94. kwargs["prevExist"] = False
  95. try:
  96. result = self._client.write(self._key, base64_state, self._ttl, **kwargs)
  97. except (EtcdAlreadyExist, EtcdCompareFailed):
  98. result = None
  99. except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
  100. raise RendezvousConnectionError(
  101. "The connection to etcd has failed. See inner exception for details."
  102. ) from exc
  103. if result is None:
  104. return get_state()
  105. tmp = *self._decode_state(result), True
  106. return tmp
  107. def _decode_state(self, result: EtcdResult) -> Tuple[bytes, Token]:
  108. base64_state = result.value.encode()
  109. try:
  110. state = b64decode(base64_state)
  111. except binascii.Error as exc:
  112. raise RendezvousStateError(
  113. "The state object is corrupt. See inner exception for details."
  114. ) from exc
  115. return state, result.modifiedIndex
  116. def _create_etcd_client(params: RendezvousParameters) -> EtcdClient:
  117. host, port = parse_rendezvous_endpoint(params.endpoint, default_port=2379)
  118. # The timeout
  119. read_timeout = cast(int, params.get_as_int("read_timeout", 60))
  120. if read_timeout <= 0:
  121. raise ValueError("The read timeout must be a positive integer.")
  122. # The communication protocol
  123. protocol = params.get("protocol", "http").strip().lower()
  124. if protocol != "http" and protocol != "https":
  125. raise ValueError("The protocol must be HTTP or HTTPS.")
  126. # The SSL client certificate
  127. ssl_cert = params.get("ssl_cert")
  128. if ssl_cert:
  129. ssl_cert_key = params.get("ssl_cert_key")
  130. if ssl_cert_key:
  131. # The etcd client expects the certificate key as the second element
  132. # of the `cert` tuple.
  133. ssl_cert = (ssl_cert, ssl_cert_key)
  134. # The root certificate
  135. ca_cert = params.get("ca_cert")
  136. try:
  137. return EtcdClient(
  138. host,
  139. port,
  140. read_timeout=read_timeout,
  141. protocol=protocol,
  142. cert=ssl_cert,
  143. ca_cert=ca_cert,
  144. allow_reconnect=True,
  145. )
  146. except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
  147. raise RendezvousConnectionError(
  148. "The connection to etcd has failed. See inner exception for details."
  149. ) from exc
  150. def create_backend(params: RendezvousParameters) -> Tuple[EtcdRendezvousBackend, Store]:
  151. """Creates a new :py:class:`EtcdRendezvousBackend` from the specified
  152. parameters.
  153. +--------------+-----------------------------------------------------------+
  154. | Parameter | Description |
  155. +==============+===========================================================+
  156. | read_timeout | The read timeout, in seconds, for etcd operations. |
  157. | | Defaults to 60 seconds. |
  158. +--------------+-----------------------------------------------------------+
  159. | protocol | The protocol to use to communicate with etcd. Valid |
  160. | | values are "http" and "https". Defaults to "http". |
  161. +--------------+-----------------------------------------------------------+
  162. | ssl_cert | The path to the SSL client certificate to use along with |
  163. | | HTTPS. Defaults to ``None``. |
  164. +--------------+-----------------------------------------------------------+
  165. | ssl_cert_key | The path to the private key of the SSL client certificate |
  166. | | to use along with HTTPS. Defaults to ``None``. |
  167. +--------------+-----------------------------------------------------------+
  168. | ca_cert | The path to the rool SSL authority certificate. Defaults |
  169. | | to ``None``. |
  170. +--------------+-----------------------------------------------------------+
  171. """
  172. client = _create_etcd_client(params)
  173. backend = EtcdRendezvousBackend(client, params.run_id, key_prefix="/torch/elastic/rendezvous")
  174. store = EtcdStore(client, "/torch/elastic/store")
  175. return backend, store