utils.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import ipaddress
  7. import random
  8. import re
  9. import socket
  10. import time
  11. import weakref
  12. from datetime import timedelta
  13. from threading import Event, Thread
  14. from typing import Any, Callable, Dict, Optional, Tuple, Union
  15. __all__ = ['parse_rendezvous_endpoint']
  16. def _parse_rendezvous_config(config_str: str) -> Dict[str, str]:
  17. """Extracts key-value pairs from a rendezvous configuration string.
  18. Args:
  19. config_str:
  20. A string in format <key1>=<value1>,...,<keyN>=<valueN>.
  21. """
  22. config: Dict[str, str] = {}
  23. config_str = config_str.strip()
  24. if not config_str:
  25. return config
  26. key_values = config_str.split(",")
  27. for kv in key_values:
  28. key, *values = kv.split("=", 1)
  29. key = key.strip()
  30. if not key:
  31. raise ValueError(
  32. "The rendezvous configuration string must be in format "
  33. "<key1>=<value1>,...,<keyN>=<valueN>."
  34. )
  35. value: Optional[str]
  36. if values:
  37. value = values[0].strip()
  38. else:
  39. value = None
  40. if not value:
  41. raise ValueError(
  42. f"The rendezvous configuration option '{key}' must have a value specified."
  43. )
  44. config[key] = value
  45. return config
  46. def _try_parse_port(port_str: str) -> Optional[int]:
  47. """Tries to extract the port number from ``port_str``."""
  48. if port_str and re.match(r"^[0-9]{1,5}$", port_str):
  49. return int(port_str)
  50. return None
  51. def parse_rendezvous_endpoint(endpoint: Optional[str], default_port: int) -> Tuple[str, int]:
  52. """Extracts the hostname and the port number from a rendezvous endpoint.
  53. Args:
  54. endpoint:
  55. A string in format <hostname>[:<port>].
  56. default_port:
  57. The port number to use if the endpoint does not include one.
  58. Returns:
  59. A tuple of hostname and port number.
  60. """
  61. if endpoint is not None:
  62. endpoint = endpoint.strip()
  63. if not endpoint:
  64. return ("localhost", default_port)
  65. # An endpoint that starts and ends with brackets represents an IPv6 address.
  66. if endpoint[0] == "[" and endpoint[-1] == "]":
  67. host, *rest = endpoint, *[]
  68. else:
  69. host, *rest = endpoint.rsplit(":", 1)
  70. # Sanitize the IPv6 address.
  71. if len(host) > 1 and host[0] == "[" and host[-1] == "]":
  72. host = host[1:-1]
  73. if len(rest) == 1:
  74. port = _try_parse_port(rest[0])
  75. if port is None or port >= 2 ** 16:
  76. raise ValueError(
  77. f"The port number of the rendezvous endpoint '{endpoint}' must be an integer "
  78. "between 0 and 65536."
  79. )
  80. else:
  81. port = default_port
  82. if not re.match(r"^[\w\.:-]+$", host):
  83. raise ValueError(
  84. f"The hostname of the rendezvous endpoint '{endpoint}' must be a dot-separated list of "
  85. "labels, an IPv4 address, or an IPv6 address."
  86. )
  87. return host, port
  88. def _matches_machine_hostname(host: str) -> bool:
  89. """Indicates whether ``host`` matches the hostname of this machine.
  90. This function compares ``host`` to the hostname as well as to the IP
  91. addresses of this machine. Note that it may return a false negative if this
  92. machine has CNAME records beyond its FQDN or IP addresses assigned to
  93. secondary NICs.
  94. """
  95. if host == "localhost":
  96. return True
  97. try:
  98. addr = ipaddress.ip_address(host)
  99. except ValueError:
  100. addr = None
  101. if addr and addr.is_loopback:
  102. return True
  103. try:
  104. host_addr_list = socket.getaddrinfo(
  105. host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
  106. )
  107. except (ValueError, socket.gaierror) as _:
  108. host_addr_list = []
  109. host_ip_list = [
  110. host_addr_info[4][0]
  111. for host_addr_info in host_addr_list
  112. ]
  113. this_host = socket.gethostname()
  114. if host == this_host:
  115. return True
  116. addr_list = socket.getaddrinfo(
  117. this_host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
  118. )
  119. for addr_info in addr_list:
  120. # If we have an FQDN in the addr_info, compare it to `host`.
  121. if addr_info[3] and addr_info[3] == host:
  122. return True
  123. # Otherwise if `host` represents an IP address, compare it to our IP
  124. # address.
  125. if addr and addr_info[4][0] == str(addr):
  126. return True
  127. # If the IP address matches one of the provided host's IP addresses
  128. if addr_info[4][0] in host_ip_list:
  129. return True
  130. return False
  131. def _delay(seconds: Union[float, Tuple[float, float]]) -> None:
  132. """Suspends the current thread for ``seconds``.
  133. Args:
  134. seconds:
  135. Either the delay, in seconds, or a tuple of a lower and an upper
  136. bound within which a random delay will be picked.
  137. """
  138. if isinstance(seconds, tuple):
  139. seconds = random.uniform(*seconds)
  140. # Ignore delay requests that are less than 10 milliseconds.
  141. if seconds >= 0.01:
  142. time.sleep(seconds)
  143. class _PeriodicTimer:
  144. """Represents a timer that periodically runs a specified function.
  145. Args:
  146. interval:
  147. The interval, in seconds, between each run.
  148. function:
  149. The function to run.
  150. """
  151. # The state of the timer is hold in a separate context object to avoid a
  152. # reference cycle between the timer and the background thread.
  153. class _Context:
  154. interval: float
  155. function: Callable[..., None]
  156. args: Tuple[Any, ...]
  157. kwargs: Dict[str, Any]
  158. stop_event: Event
  159. _name: Optional[str]
  160. _thread: Optional[Thread]
  161. _finalizer: Optional[weakref.finalize]
  162. # The context that is shared between the timer and the background thread.
  163. _ctx: _Context
  164. def __init__(
  165. self,
  166. interval: timedelta,
  167. function: Callable[..., None],
  168. *args: Any,
  169. **kwargs: Any,
  170. ) -> None:
  171. self._name = None
  172. self._ctx = self._Context()
  173. self._ctx.interval = interval.total_seconds()
  174. self._ctx.function = function # type: ignore[assignment]
  175. self._ctx.args = args or ()
  176. self._ctx.kwargs = kwargs or {}
  177. self._ctx.stop_event = Event()
  178. self._thread = None
  179. self._finalizer = None
  180. @property
  181. def name(self) -> Optional[str]:
  182. """Gets the name of the timer."""
  183. return self._name
  184. def set_name(self, name: str) -> None:
  185. """Sets the name of the timer.
  186. The specified name will be assigned to the background thread and serves
  187. for debugging and troubleshooting purposes.
  188. """
  189. if self._thread:
  190. raise RuntimeError("The timer has already started.")
  191. self._name = name
  192. def start(self) -> None:
  193. """Start the timer."""
  194. if self._thread:
  195. raise RuntimeError("The timer has already started.")
  196. self._thread = Thread(
  197. target=self._run, name=self._name or "PeriodicTimer", args=(self._ctx,), daemon=True
  198. )
  199. # We avoid using a regular finalizer (a.k.a. __del__) for stopping the
  200. # timer as joining a daemon thread during the interpreter shutdown can
  201. # cause deadlocks. The weakref.finalize is a superior alternative that
  202. # provides a consistent behavior regardless of the GC implementation.
  203. self._finalizer = weakref.finalize(
  204. self, self._stop_thread, self._thread, self._ctx.stop_event
  205. )
  206. # We do not attempt to stop our background thread during the interpreter
  207. # shutdown. At that point we do not even know whether it still exists.
  208. self._finalizer.atexit = False
  209. self._thread.start()
  210. def cancel(self) -> None:
  211. """Stop the timer at the next opportunity."""
  212. if self._finalizer:
  213. self._finalizer()
  214. @staticmethod
  215. def _run(ctx) -> None:
  216. while not ctx.stop_event.wait(ctx.interval):
  217. ctx.function(*ctx.args, **ctx.kwargs)
  218. @staticmethod
  219. def _stop_thread(thread, stop_event):
  220. stop_event.set()
  221. thread.join()