etcd_store.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import datetime
  7. import random
  8. import time
  9. from base64 import b64decode, b64encode
  10. from typing import Optional
  11. import etcd # type: ignore[import]
  12. # pyre-ignore[21]: Could not find name `Store` in `torch.distributed`.
  13. from torch.distributed import Store
  # Randomized back-off used between compare-and-swap retries against etcd.
def cas_delay():
    """
    Sleep for a random duration drawn uniformly from [0, 0.1] seconds.

    Spreading retries out in time reduces the chance that concurrent writers
    collide again on the next compare-and-swap. It does not affect
    correctness; it only reduces the number of requests sent to the etcd
    server.
    """
    backoff_seconds = random.uniform(0, 0.1)
    time.sleep(backoff_seconds)
  # pyre-fixme[11]: Annotation `Store` is not defined as a type.
class EtcdStore(Store):
    """
    Implements a c10 Store interface by piggybacking on the rendezvous etcd
    instance. This is the store object returned by ``EtcdRendezvous``.

    Keys and values are base64-encoded before being written, and every key
    lives directly under ``self.prefix`` (which always ends with ``/``), so
    arbitrary ``str``/``bytes`` data can be stored.
    """

    def __init__(
        self,
        etcd_client,
        etcd_store_prefix,
        # Default timeout same as in c10d/Store.hpp
        timeout: Optional[datetime.timedelta] = None,
    ):
        """
        Args:
            etcd_client: client used for all etcd reads, writes and watches.
            etcd_store_prefix: etcd key prefix (directory) holding this
                store's keys; a trailing ``/`` is appended if missing.
            timeout: if given, passed to ``set_timeout``; otherwise the base
                ``Store`` timeout is left unchanged.
        """
        super().__init__()  # required for pybind trampoline.

        self.client = etcd_client
        self.prefix = etcd_store_prefix

        if timeout is not None:
            self.set_timeout(timeout)

        # Full etcd keys are built as ``prefix + encoded_key``, so the prefix
        # must end with a path separator.
        if not self.prefix.endswith("/"):
            self.prefix += "/"

    def set(self, key, value):
        """
        Write a key/value pair into ``EtcdStore``.

        Both key and value may be either Python ``str`` or ``bytes``.
        The value is stored base64-encoded under ``prefix + base64(key)``.
        """
        self.client.set(key=self.prefix + self._encode(key), value=self._encode(value))

    def get(self, key) -> bytes:
        """
        Get a value by key, possibly doing a blocking wait.

        If key is not immediately present, will do a blocking wait
        for at most ``timeout`` duration or until the key is published.

        Returns:
            value ``(bytes)``

        Raises:
            LookupError - If key still not published after timeout
        """
        b64_key = self.prefix + self._encode(key)
        kvs = self._try_wait_get([b64_key])

        if kvs is None:
            raise LookupError(f"Key {key} not found in EtcdStore")

        return self._decode(kvs[b64_key])

    def add(self, key, num: int) -> int:
        """
        Atomically increment a value by an integer amount. The integer is
        represented as a string using base 10. If key is not present,
        a default value of ``0`` will be assumed.

        Returns:
            the new (incremented) value
        """
        b64_key = self._encode(key)
        # c10d Store assumes value is an integer represented as a decimal string
        try:
            # Assume default value "0", if this key didn't yet exist:
            # create it with value ``num`` only if it is absent.
            node = self.client.write(
                key=self.prefix + b64_key,
                value=self._encode(str(num)),  # i.e. 0 + num
                prevExist=False,
            )
            return int(self._decode(node.value))
        except etcd.EtcdAlreadyExist:
            pass

        # The key already exists: retry read-modify-write with
        # compare-and-swap until no concurrent writer interferes.
        while True:
            # Note: c10d Store does not have a method to delete keys, so we
            # can be sure it's still there.
            node = self.client.get(key=self.prefix + b64_key)
            new_value = self._encode(str(int(self._decode(node.value)) + num))
            try:
                node = self.client.test_and_set(
                    key=node.key, value=new_value, prev_value=node.value
                )
                return int(self._decode(node.value))
            except etcd.EtcdCompareFailed:
                # Someone else updated the value first; back off and retry.
                cas_delay()

    def wait(self, keys, override_timeout: Optional[datetime.timedelta] = None):
        """
        Waits until all of the keys are published, or until timeout.

        Args:
            keys: iterable of ``str``/``bytes`` keys.
            override_timeout: if given, used instead of the store's timeout.

        Raises:
            LookupError - if timeout occurs
        """
        b64_keys = [self.prefix + self._encode(key) for key in keys]
        kvs = self._try_wait_get(b64_keys, override_timeout)

        if kvs is None:
            raise LookupError("Timeout while waiting for keys in EtcdStore")

        # No return value on success

    def check(self, keys) -> bool:
        """
        Check if all of the keys are immediately present (without waiting).

        A 1 microsecond timeout is used so the key directory is effectively
        read once.
        """
        b64_keys = [self.prefix + self._encode(key) for key in keys]
        kvs = self._try_wait_get(
            b64_keys,
            override_timeout=datetime.timedelta(microseconds=1),  # as if no wait
        )
        return kvs is not None

    #
    # Encode key/value data in base64, so we can store arbitrary binary data
    # in EtcdStore. Input can be `str` or `bytes`.
    # In case of `str`, utf-8 encoding is assumed.
    #
    def _encode(self, value) -> str:
        if isinstance(value, bytes):
            return b64encode(value).decode()
        if isinstance(value, str):
            return b64encode(value.encode()).decode()
        raise ValueError("Value must be of type str or bytes")

    #
    # Decode a base64 string (of type `str` or `bytes`).
    # Return type is `bytes`, which is more convenient with the Store interface.
    #
    def _decode(self, value) -> bytes:
        if isinstance(value, bytes):
            return b64decode(value)
        if isinstance(value, str):
            return b64decode(value.encode())
        raise ValueError("Value must be of type str or bytes")

    def _try_wait_get(self, b64_keys, override_timeout=None):
        """
        Get all of the (base64-encoded) etcd keys at once, or wait until all
        the keys are published or timeout occurs.

        This is a helper method for the public interface methods.

        Args:
            b64_keys: full etcd keys (``prefix + base64(key)``) to wait for.
                Duplicates are allowed and ignored.
            override_timeout: if given, used instead of ``self.timeout``.

        Returns:
            On success, a dictionary of {etcd key -> etcd value}.
            On timeout, None.
        """
        timeout = self.timeout if override_timeout is None else override_timeout  # type: ignore[attr-defined]
        # Monotonic clock: wall-clock adjustments (NTP, manual changes) must
        # not shorten or extend the wait.
        deadline = time.monotonic() + timeout.total_seconds()
        # De-duplicate the requested keys. The dict built below holds each
        # etcd key at most once, so comparing its size against a list with
        # duplicates could never succeed. A set also makes the membership
        # test O(1) per child.
        wanted = set(b64_keys)

        while True:
            # Read whole directory (of keys), filter only the ones waited for
            all_nodes = self.client.get(key=self.prefix)
            req_nodes = {
                node.key: node.value
                for node in all_nodes.children
                if node.key in wanted
            }

            if len(req_nodes) == len(wanted):
                # All keys are available
                return req_nodes

            watch_timeout = deadline - time.monotonic()
            if watch_timeout <= 0:
                return None

            try:
                # Block until something under the prefix changes after the
                # directory read above (or the watch times out), then re-read.
                self.client.watch(
                    key=self.prefix,
                    recursive=True,
                    timeout=watch_timeout,
                    index=all_nodes.etcd_index + 1,
                )
            except etcd.EtcdWatchTimedOut:
                if time.monotonic() >= deadline:
                    return None
                continue
            except etcd.EtcdEventIndexCleared:
                # Retry from a fresh directory read.
                continue