etcd_server.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. #!/usr/bin/env python3
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the BSD-style license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. import atexit
  8. import logging
  9. import os
  10. import shlex
  11. import shutil
  12. import socket
  13. import subprocess
  14. import tempfile
  15. import time
  16. from typing import Optional, TextIO, Union
  17. try:
  18. import etcd # type: ignore[import]
  19. except ModuleNotFoundError:
  20. pass
  21. log = logging.getLogger(__name__)
  22. def find_free_port():
  23. """
  24. Finds a free port and binds a temporary socket to it so that
  25. the port can be "reserved" until used.
  26. .. note:: the returned socket must be closed before using the port,
  27. otherwise a ``address already in use`` error will happen.
  28. The socket should be held and closed as close to the
  29. consumer of the port as possible since otherwise, there
  30. is a greater chance of race-condition where a different
  31. process may see the port as being free and take it.
  32. Returns: a socket binded to the reserved free port
  33. Usage::
  34. sock = find_free_port()
  35. port = sock.getsockname()[1]
  36. sock.close()
  37. use_port(port)
  38. """
  39. addrs = socket.getaddrinfo(
  40. host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
  41. )
  42. for addr in addrs:
  43. family, type, proto, _, _ = addr
  44. try:
  45. s = socket.socket(family, type, proto)
  46. s.bind(("localhost", 0))
  47. s.listen(0)
  48. return s
  49. except OSError as e:
  50. s.close()
  51. print(f"Socket creation attempt failed: {e}")
  52. raise RuntimeError("Failed to create a socket")
  53. def stop_etcd(subprocess, data_dir: Optional[str] = None):
  54. if subprocess and subprocess.poll() is None:
  55. log.info("stopping etcd server")
  56. subprocess.terminate()
  57. subprocess.wait()
  58. if data_dir:
  59. log.info(f"deleting etcd data dir: {data_dir}")
  60. shutil.rmtree(data_dir, ignore_errors=True)
  61. class EtcdServer:
  62. """
  63. .. note:: tested on etcd server v3.4.3
  64. Starts and stops a local standalone etcd server on a random free
  65. port. Useful for single node, multi-worker launches or testing,
  66. where a sidecar etcd server is more convenient than having to
  67. separately setup an etcd server.
  68. This class registers a termination handler to shutdown the etcd
  69. subprocess on exit. This termination handler is NOT a substitute for
  70. calling the ``stop()`` method.
  71. The following fallback mechanism is used to find the etcd binary:
  72. 1. Uses env var TORCHELASTIC_ETCD_BINARY_PATH
  73. 2. Uses ``<this file root>/bin/etcd`` if one exists
  74. 3. Uses ``etcd`` from ``PATH``
  75. Usage
  76. ::
  77. server = EtcdServer("/usr/bin/etcd", 2379, "/tmp/default.etcd")
  78. server.start()
  79. client = server.get_client()
  80. # use client
  81. server.stop()
  82. Args:
  83. etcd_binary_path: path of etcd server binary (see above for fallback path)
  84. """
  85. def __init__(self, data_dir: Optional[str] = None):
  86. self._port = -1
  87. self._host = "localhost"
  88. root = os.path.dirname(__file__)
  89. default_etcd_bin = os.path.join(root, "bin/etcd")
  90. self._etcd_binary_path = os.environ.get(
  91. "TORCHELASTIC_ETCD_BINARY_PATH", default_etcd_bin
  92. )
  93. if not os.path.isfile(self._etcd_binary_path):
  94. self._etcd_binary_path = "etcd"
  95. self._base_data_dir = (
  96. data_dir if data_dir else tempfile.mkdtemp(prefix="torchelastic_etcd_data")
  97. )
  98. self._etcd_cmd = None
  99. self._etcd_proc: Optional[subprocess.Popen] = None
  100. def _get_etcd_server_process(self) -> subprocess.Popen:
  101. if not self._etcd_proc:
  102. raise RuntimeError(
  103. "No etcd server process started. Call etcd_server.start() first"
  104. )
  105. else:
  106. return self._etcd_proc
  107. def get_port(self) -> int:
  108. """
  109. Returns:
  110. the port the server is running on.
  111. """
  112. return self._port
  113. def get_host(self) -> str:
  114. """
  115. Returns:
  116. the host the server is running on.
  117. """
  118. return self._host
  119. def get_endpoint(self) -> str:
  120. """
  121. Returns:
  122. the etcd server endpoint (host:port)
  123. """
  124. return f"{self._host}:{self._port}"
  125. def start(
  126. self,
  127. timeout: int = 60,
  128. num_retries: int = 3,
  129. stderr: Union[int, TextIO, None] = None,
  130. ) -> None:
  131. """
  132. Starts the server, and waits for it to be ready. When this function
  133. returns the sever is ready to take requests.
  134. Args:
  135. timeout: time (in seconds) to wait for the server to be ready
  136. before giving up.
  137. num_retries: number of retries to start the server. Each retry
  138. will wait for max ``timeout`` before considering it as failed.
  139. stderr: the standard error file handle. Valid values are
  140. `subprocess.PIPE`, `subprocess.DEVNULL`, an existing file
  141. descriptor (a positive integer), an existing file object, and
  142. `None`.
  143. Raises:
  144. TimeoutError: if the server is not ready within the specified timeout
  145. """
  146. curr_retries = 0
  147. while True:
  148. try:
  149. data_dir = os.path.join(self._base_data_dir, str(curr_retries))
  150. os.makedirs(data_dir, exist_ok=True)
  151. return self._start(data_dir, timeout, stderr)
  152. except Exception as e:
  153. curr_retries += 1
  154. stop_etcd(self._etcd_proc)
  155. log.warning(
  156. f"Failed to start etcd server, got error: {str(e)}, retrying"
  157. )
  158. if curr_retries >= num_retries:
  159. shutil.rmtree(self._base_data_dir, ignore_errors=True)
  160. raise
  161. atexit.register(stop_etcd, self._etcd_proc, self._base_data_dir)
  162. def _start(
  163. self, data_dir: str, timeout: int = 60, stderr: Union[int, TextIO, None] = None
  164. ) -> None:
  165. sock = find_free_port()
  166. sock_peer = find_free_port()
  167. self._port = sock.getsockname()[1]
  168. peer_port = sock_peer.getsockname()[1]
  169. etcd_cmd = shlex.split(
  170. " ".join(
  171. [
  172. self._etcd_binary_path,
  173. "--enable-v2",
  174. "--data-dir",
  175. data_dir,
  176. "--listen-client-urls",
  177. f"http://{self._host}:{self._port}",
  178. "--advertise-client-urls",
  179. f"http://{self._host}:{self._port}",
  180. "--listen-peer-urls",
  181. f"http://{self._host}:{peer_port}",
  182. ]
  183. )
  184. )
  185. log.info(f"Starting etcd server: [{etcd_cmd}]")
  186. sock.close()
  187. sock_peer.close()
  188. self._etcd_proc = subprocess.Popen(etcd_cmd, close_fds=True, stderr=stderr)
  189. self._wait_for_ready(timeout)
  190. def get_client(self):
  191. """
  192. Returns:
  193. An etcd client object that can be used to make requests to
  194. this server.
  195. """
  196. return etcd.Client(
  197. host=self._host, port=self._port, version_prefix="/v2", read_timeout=10
  198. )
  199. def _wait_for_ready(self, timeout: int = 60) -> None:
  200. client = etcd.Client(
  201. host=f"{self._host}", port=self._port, version_prefix="/v2", read_timeout=5
  202. )
  203. max_time = time.time() + timeout
  204. while time.time() < max_time:
  205. if self._get_etcd_server_process().poll() is not None:
  206. # etcd server process finished
  207. exitcode = self._get_etcd_server_process().returncode
  208. raise RuntimeError(
  209. f"Etcd server process exited with the code: {exitcode}"
  210. )
  211. try:
  212. log.info(f"etcd server ready. version: {client.version}")
  213. return
  214. except Exception:
  215. time.sleep(1)
  216. raise TimeoutError("Timed out waiting for etcd server to be ready!")
  217. def stop(self) -> None:
  218. """
  219. Stops the server and cleans up auto generated resources (e.g. data dir)
  220. """
  221. log.info("EtcdServer stop method called")
  222. stop_etcd(self._etcd_proc, self._base_data_dir)