123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279 |
- # Copyright (c) Facebook, Inc. and its affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the BSD-style license found in the
- # LICENSE file in the root directory of this source tree.
- import ipaddress
- import random
- import re
- import socket
- import time
- import weakref
- from datetime import timedelta
- from threading import Event, Thread
- from typing import Any, Callable, Dict, Optional, Tuple, Union
- __all__ = ['parse_rendezvous_endpoint']
def _parse_rendezvous_config(config_str: str) -> Dict[str, str]:
    """Parses a rendezvous configuration string into a dictionary.

    Args:
        config_str:
            A string in format <key1>=<value1>,...,<keyN>=<valueN>. Leading
            and trailing whitespace around the whole string, the keys, and the
            values is ignored.

    Returns:
        A dictionary mapping each key to its value. An empty or
        whitespace-only string yields an empty dictionary. If a key occurs
        more than once, the last value wins.

    Raises:
        ValueError:
            If an entry has an empty key, or if an entry has no value (either
            no ``=`` at all or nothing but whitespace after it).
    """
    config: Dict[str, str] = {}

    stripped = config_str.strip()
    if stripped:
        for entry in stripped.split(","):
            # Only the first "=" separates key from value; any further "="
            # characters belong to the value.
            key, _, raw_value = entry.partition("=")

            key = key.strip()
            if not key:
                raise ValueError(
                    "The rendezvous configuration string must be in format "
                    "<key1>=<value1>,...,<keyN>=<valueN>."
                )

            # When there is no "=", `raw_value` is empty and is rejected here
            # together with whitespace-only values.
            setting = raw_value.strip()
            if not setting:
                raise ValueError(
                    f"The rendezvous configuration option '{key}' must have a value specified."
                )

            config[key] = setting

    return config
def _try_parse_port(port_str: str) -> Optional[int]:
    """Tries to extract the port number from ``port_str``.

    Args:
        port_str:
            The candidate port string.

    Returns:
        The integer value if ``port_str`` consists solely of one to five ASCII
        digits; otherwise ``None``. No range check is performed here, so
        values up to 99999 are returned as-is.
    """
    # `fullmatch` is used instead of `match` with a `$` anchor because `$`
    # also matches right before a trailing newline, which would let a string
    # such as "80\n" through.
    if port_str and re.fullmatch(r"[0-9]{1,5}", port_str):
        return int(port_str)
    return None
def parse_rendezvous_endpoint(endpoint: Optional[str], default_port: int) -> Tuple[str, int]:
    """Extracts the hostname and the port number from a rendezvous endpoint.

    Args:
        endpoint:
            A string in format <hostname>[:<port>]. Surrounding whitespace is
            ignored. An IPv6 address should be enclosed in square brackets
            (e.g. ``[::1]`` or ``[::1]:29400``); an unbracketed value is split
            at its last colon.
        default_port:
            The port number to use if the endpoint does not include one.

    Returns:
        A tuple of hostname and port number. If ``endpoint`` is ``None`` or
        blank, the hostname is ``"localhost"`` and the port is
        ``default_port``. Brackets around an IPv6 address are removed from the
        returned hostname.

    Raises:
        ValueError:
            If the port is not an integer below 65536, or if the hostname
            contains characters other than word characters, dots, colons, and
            hyphens.
    """
    if endpoint is not None:
        endpoint = endpoint.strip()

    if not endpoint:
        return ("localhost", default_port)

    # An endpoint that starts and ends with brackets represents an IPv6 address
    # without a port.
    if endpoint[0] == "[" and endpoint[-1] == "]":
        host, rest = endpoint, []
    else:
        host, *rest = endpoint.rsplit(":", 1)

    # Sanitize the IPv6 address.
    if len(host) > 1 and host[0] == "[" and host[-1] == "]":
        host = host[1:-1]

    if len(rest) == 1:
        port = _try_parse_port(rest[0])
        # The upper bound is exclusive: 65535 is the largest accepted port.
        if port is None or port >= 2 ** 16:
            raise ValueError(
                f"The port number of the rendezvous endpoint '{endpoint}' must be an integer "
                "between 0 and 65536."
            )
    else:
        port = default_port

    # `fullmatch` is used instead of `match` with a `$` anchor because `$` also
    # matches right before a trailing newline, which would accept a hostname
    # such as "host\n".
    if not re.fullmatch(r"[\w\.:-]+", host):
        raise ValueError(
            f"The hostname of the rendezvous endpoint '{endpoint}' must be a dot-separated list of "
            "labels, an IPv4 address, or an IPv6 address."
        )

    return host, port
def _matches_machine_hostname(host: str) -> bool:
    """Indicates whether ``host`` matches the hostname of this machine.

    This function compares ``host`` to the hostname as well as to the IP
    addresses of this machine. Note that it may return a false negative if this
    machine has CNAME records beyond its FQDN or IP addresses assigned to
    secondary NICs.

    Args:
        host:
            A hostname or an IP address literal.

    Returns:
        ``True`` if ``host`` is ``"localhost"``, a loopback address, the name
        returned by ``socket.gethostname()``, or resolves to a canonical name
        or IP address that this machine's hostname also resolves to;
        ``False`` otherwise.

    Raises:
        socket.gaierror:
            If this machine's own hostname cannot be resolved. Resolution
            failures for ``host`` are tolerated and treated as "no addresses".
    """
    if host == "localhost":
        return True

    try:
        addr = ipaddress.ip_address(host)
    except ValueError:
        addr = None

    if addr and addr.is_loopback:
        return True

    # Compare against the local hostname before doing any name resolution so
    # that the common exact-match case needs no DNS lookup.
    this_host = socket.gethostname()
    if host == this_host:
        return True

    try:
        host_addr_list = socket.getaddrinfo(
            host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
        )
    except (ValueError, socket.gaierror):
        host_addr_list = []

    # The first element of the socket address tuple is the IP address.
    host_ips = {host_addr_info[4][0] for host_addr_info in host_addr_list}

    # Normalized textual form of `host` when it is an IP address literal.
    addr_str = str(addr) if addr else None

    addr_list = socket.getaddrinfo(
        this_host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
    )
    for addr_info in addr_list:
        # If we have an FQDN in the addr_info, compare it to `host`.
        if addr_info[3] and addr_info[3] == host:
            return True

        # Otherwise if `host` represents an IP address, compare it to our IP
        # address.
        if addr_str is not None and addr_info[4][0] == addr_str:
            return True

        # If the IP address matches one of the provided host's IP addresses
        if addr_info[4][0] in host_ips:
            return True

    return False
def _delay(seconds: Union[float, Tuple[float, float]]) -> None:
    """Puts the calling thread to sleep for the requested duration.

    Args:
        seconds:
            Either a fixed delay, in seconds, or a tuple holding a lower and
            an upper bound; in the latter case the actual delay is drawn
            uniformly at random from within those bounds.
    """
    duration = random.uniform(*seconds) if isinstance(seconds, tuple) else seconds

    # Delays shorter than 10 milliseconds are not worth sleeping for and are
    # silently skipped.
    if duration >= 0.01:
        time.sleep(duration)
class _PeriodicTimer:
    """Represents a timer that periodically runs a specified function.

    The function runs on a daemon background thread. The first run happens
    one ``interval`` after :meth:`start` is called. A timer can be started
    only once; it cannot be restarted after :meth:`cancel`.

    Args:
        interval:
            The interval between each run.
        function:
            The function to run.
        *args:
            Positional arguments passed to ``function`` on every run.
        **kwargs:
            Keyword arguments passed to ``function`` on every run.
    """

    # The state of the timer is held in a separate context object to avoid a
    # reference cycle between the timer and the background thread.
    class _Context:
        interval: float
        function: Callable[..., None]
        args: Tuple[Any, ...]
        kwargs: Dict[str, Any]
        stop_event: Event

    _name: Optional[str]
    _thread: Optional[Thread]
    _finalizer: Optional[weakref.finalize]

    # The context that is shared between the timer and the background thread.
    _ctx: _Context

    def __init__(
        self,
        interval: timedelta,
        function: Callable[..., None],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self._name = None

        self._ctx = self._Context()
        self._ctx.interval = interval.total_seconds()
        self._ctx.function = function  # type: ignore[assignment]
        self._ctx.args = args or ()
        self._ctx.kwargs = kwargs or {}
        self._ctx.stop_event = Event()

        self._thread = None
        self._finalizer = None

    @property
    def name(self) -> Optional[str]:
        """Gets the name of the timer."""
        return self._name

    def set_name(self, name: str) -> None:
        """Sets the name of the timer.

        The specified name will be assigned to the background thread and serves
        for debugging and troubleshooting purposes.

        Raises:
            RuntimeError: If the timer has already been started.
        """
        if self._thread:
            raise RuntimeError("The timer has already started.")

        self._name = name

    def start(self) -> None:
        """Start the timer.

        Raises:
            RuntimeError: If the timer has already been started.
        """
        if self._thread:
            raise RuntimeError("The timer has already started.")

        # Only the context, never the timer itself, is handed to the thread.
        self._thread = Thread(
            target=self._run, name=self._name or "PeriodicTimer", args=(self._ctx,), daemon=True
        )

        # We avoid using a regular finalizer (a.k.a. __del__) for stopping the
        # timer as joining a daemon thread during the interpreter shutdown can
        # cause deadlocks. The weakref.finalize is a superior alternative that
        # provides a consistent behavior regardless of the GC implementation.
        self._finalizer = weakref.finalize(
            self, self._stop_thread, self._thread, self._ctx.stop_event
        )

        # We do not attempt to stop our background thread during the interpreter
        # shutdown. At that point we do not even know whether it still exists.
        self._finalizer.atexit = False

        self._thread.start()

    def cancel(self) -> None:
        """Stop the timer at the next opportunity.

        Has no effect if the timer was never started or was already cancelled.
        It is safe to call this from within the timer's own function.
        """
        if self._finalizer:
            self._finalizer()

    @staticmethod
    def _run(ctx) -> None:
        """Runs ``ctx.function`` every ``ctx.interval`` seconds until stopped."""
        # `wait` returns True once the stop event is set, which ends the loop;
        # a timeout (False) means another interval has elapsed.
        while not ctx.stop_event.wait(ctx.interval):
            ctx.function(*ctx.args, **ctx.kwargs)

    @staticmethod
    def _stop_thread(thread, stop_event):
        """Signals the background thread to stop and waits for it to exit."""
        stop_event.set()

        try:
            thread.join()
        except RuntimeError:
            # `join` raises if it is called from the timer thread itself (e.g.
            # the periodic function cancels its own timer, or the finalizer
            # happens to run on that thread) or if the thread never started.
            # The stop event is already set, so there is nothing to wait for.
            pass
|