api.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import abc
  7. import logging
  8. import threading
  9. import time
  10. from contextlib import contextmanager
  11. from inspect import getframeinfo, stack
  12. from typing import Any, Dict, List, Optional, Set
  13. __all__ = ['TimerRequest', 'TimerClient', 'RequestQueue', 'TimerServer', 'configure', 'expires']
  14. log = logging.getLogger(__name__)
  15. class TimerRequest:
  16. """
  17. Data object representing a countdown timer acquisition and release
  18. that is used between the ``TimerClient`` and ``TimerServer``.
  19. A negative ``expiration_time`` should be interpreted as a "release"
  20. request.
  21. .. note:: the type of ``worker_id`` is implementation specific.
  22. It is whatever the TimerServer and TimerClient implementations
  23. have on to uniquely identify a worker.
  24. """
  25. __slots__ = ["worker_id", "scope_id", "expiration_time"]
  26. def __init__(self, worker_id: Any, scope_id: str, expiration_time: float):
  27. self.worker_id = worker_id
  28. self.scope_id = scope_id
  29. self.expiration_time = expiration_time
  30. def __eq__(self, other):
  31. if isinstance(other, TimerRequest):
  32. return (
  33. self.worker_id == other.worker_id
  34. and self.scope_id == other.scope_id
  35. and self.expiration_time == other.expiration_time
  36. )
  37. return False
  38. class TimerClient(abc.ABC):
  39. """
  40. Client library to acquire and release countdown timers by communicating
  41. with the TimerServer.
  42. """
  43. @abc.abstractmethod
  44. def acquire(self, scope_id: str, expiration_time: float) -> None:
  45. """
  46. Acquires a timer for the worker that holds this client object
  47. given the scope_id and expiration_time. Typically registers
  48. the timer with the TimerServer.
  49. """
  50. pass
  51. @abc.abstractmethod
  52. def release(self, scope_id: str):
  53. """
  54. Releases the timer for the ``scope_id`` on the worker this
  55. client represents. After this method is
  56. called, the countdown timer on the scope is no longer in effect.
  57. """
  58. pass
  59. class RequestQueue(abc.ABC):
  60. """
  61. Consumer queue holding timer acquisition/release requests
  62. """
  63. @abc.abstractmethod
  64. def size(self) -> int:
  65. """
  66. Returns the size of the queue at the time this method is called.
  67. Note that by the time ``get`` is called the size of the queue
  68. may have increased. The size of the queue should not decrease
  69. until the ``get`` method is called. That is, the following assertion
  70. should hold:
  71. size = q.size()
  72. res = q.get(size, timeout=0)
  73. assert size == len(res)
  74. -- or --
  75. size = q.size()
  76. res = q.get(size * 2, timeout=1)
  77. assert size <= len(res) <= size * 2
  78. """
  79. pass
  80. @abc.abstractmethod
  81. def get(self, size: int, timeout: float) -> List[TimerRequest]:
  82. """
  83. Gets up to ``size`` number of timer requests in a blocking fashion
  84. (no more than ``timeout`` seconds).
  85. """
  86. pass
  87. class TimerServer(abc.ABC):
  88. """
  89. Entity that monitors active timers and expires them
  90. in a timely fashion. This server is responsible for
  91. reaping workers that have expired timers.
  92. """
  93. def __init__(
  94. self, request_queue: RequestQueue, max_interval: float, daemon: bool = True
  95. ):
  96. """
  97. :param request_queue: Consumer ``RequestQueue``
  98. :param max_interval: max time (in seconds) to wait
  99. for an item in the request_queue
  100. :param daemon: whether to run the watchdog thread as a daemon
  101. """
  102. super().__init__()
  103. self._request_queue = request_queue
  104. self._max_interval = max_interval
  105. self._daemon = daemon
  106. self._watchdog_thread: Optional[threading.Thread] = None
  107. self._stop_signaled = False
  108. @abc.abstractmethod
  109. def register_timers(self, timer_requests: List[TimerRequest]) -> None:
  110. """
  111. Processes the incoming timer requests and registers them with the server.
  112. The timer request can either be a acquire-timer or release-timer request.
  113. Timer requests with a negative expiration_time should be interpreted
  114. as a release-timer request.
  115. """
  116. pass
  117. @abc.abstractmethod
  118. def clear_timers(self, worker_ids: Set[Any]) -> None:
  119. """
  120. Clears all timers for the given ``worker_ids``.
  121. """
  122. pass
  123. @abc.abstractmethod
  124. def get_expired_timers(self, deadline: float) -> Dict[str, List[TimerRequest]]:
  125. """
  126. Returns all expired timers for each worker_id. An expired timer
  127. is a timer for which the expiration_time is less than or equal to
  128. the provided deadline.
  129. """
  130. pass
  131. @abc.abstractmethod
  132. def _reap_worker(self, worker_id: Any) -> bool:
  133. """
  134. Reaps the given worker. Returns True if the worker has been
  135. successfully reaped, False otherwise. If any uncaught exception
  136. is thrown from this method, the worker is considered reaped
  137. and all associated timers will be removed.
  138. """
  139. def _reap_worker_no_throw(self, worker_id: Any) -> bool:
  140. """
  141. Wraps ``_reap_worker(worker_id)``, if an uncaught exception is
  142. thrown, then it considers the worker as reaped.
  143. """
  144. try:
  145. return self._reap_worker(worker_id)
  146. except Exception as e:
  147. log.error(
  148. "Uncaught exception thrown from _reap_worker(), "
  149. "check that the implementation correctly catches exceptions",
  150. exc_info=e,
  151. )
  152. return True
  153. def _watchdog_loop(self):
  154. while not self._stop_signaled:
  155. try:
  156. self._run_watchdog()
  157. except Exception as e:
  158. log.error("Error running watchdog", exc_info=e)
  159. def _run_watchdog(self):
  160. batch_size = max(1, self._request_queue.size())
  161. timer_requests = self._request_queue.get(batch_size, self._max_interval)
  162. self.register_timers(timer_requests)
  163. now = time.time()
  164. reaped_worker_ids = set()
  165. for worker_id, expired_timers in self.get_expired_timers(now).items():
  166. log.info(
  167. f"Reaping worker_id=[{worker_id}]."
  168. f" Expired timers: {self._get_scopes(expired_timers)}"
  169. )
  170. if self._reap_worker_no_throw(worker_id):
  171. log.info(f"Successfully reaped worker=[{worker_id}]")
  172. reaped_worker_ids.add(worker_id)
  173. else:
  174. log.error(
  175. f"Error reaping worker=[{worker_id}]. Will retry on next watchdog."
  176. )
  177. self.clear_timers(reaped_worker_ids)
  178. def _get_scopes(self, timer_requests):
  179. return [r.scope_id for r in timer_requests]
  180. def start(self) -> None:
  181. log.info(
  182. f"Starting {type(self).__name__}..."
  183. f" max_interval={self._max_interval},"
  184. f" daemon={self._daemon}"
  185. )
  186. self._watchdog_thread = threading.Thread(
  187. target=self._watchdog_loop, daemon=self._daemon
  188. )
  189. log.info("Starting watchdog thread...")
  190. self._watchdog_thread.start()
  191. def stop(self) -> None:
  192. log.info(f"Stopping {type(self).__name__}")
  193. self._stop_signaled = True
  194. if self._watchdog_thread:
  195. log.info("Stopping watchdog thread...")
  196. self._watchdog_thread.join(self._max_interval)
  197. self._watchdog_thread = None
  198. else:
  199. log.info("No watchdog thread running, doing nothing")
  200. _timer_client = None
  201. def configure(timer_client: TimerClient):
  202. """
  203. Configures a timer client. Must be called before using ``expires``.
  204. """
  205. global _timer_client
  206. _timer_client = timer_client
  207. log.info(f"Timer client configured to: {type(_timer_client).__name__}")
  208. @contextmanager
  209. def expires(
  210. after: float, scope: Optional[str] = None, client: Optional[TimerClient] = None
  211. ):
  212. """
  213. Acquires a countdown timer that expires in ``after`` seconds from now,
  214. unless the code-block that it wraps is finished within the timeframe.
  215. When the timer expires, this worker is eligible to be reaped. The
  216. exact meaning of "reaped" depends on the client implementation. In
  217. most cases, reaping means to terminate the worker process.
  218. Note that the worker is NOT guaranteed to be reaped at exactly
  219. ``time.now() + after``, but rather the worker is "eligible" for being
  220. reaped and the ``TimerServer`` that the client talks to will ultimately
  221. make the decision when and how to reap the workers with expired timers.
  222. Usage::
  223. torch.distributed.elastic.timer.configure(LocalTimerClient())
  224. with expires(after=10):
  225. torch.distributed.all_reduce(...)
  226. """
  227. if client is None:
  228. if _timer_client is None:
  229. raise RuntimeError("Configure timer client before using coundown timers.")
  230. client = _timer_client
  231. if scope is None:
  232. # grab the caller file + lineno
  233. caller = getframeinfo(stack()[1][0])
  234. scope = f"{caller.filename}#{caller.lineno}"
  235. expiration = time.time() + after
  236. client.acquire(scope, expiration)
  237. try:
  238. yield
  239. finally:
  240. client.release(scope)