local_timer.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import logging
  7. import multiprocessing as mp
  8. import os
  9. import signal
  10. import time
  11. from queue import Empty
  12. from typing import Any, Dict, List, Set, Tuple
  13. from .api import RequestQueue, TimerClient, TimerRequest, TimerServer
  __all__ = [
      "LocalTimerClient",
      "MultiprocessingRequestQueue",
      "LocalTimerServer",
  ]

  # Module-level logger, named after this module via ``__name__``.
  log = logging.getLogger(__name__)
  class LocalTimerClient(TimerClient):
      """
      Client half of ``LocalTimerServer``, meant to run on the same host as the
      server.

      Each worker is identified by its process id, which suits setups where one
      trainer subprocess is spawned per GPU on a multi-GPU host. The client does
      no bookkeeping of its own: every call just enqueues a ``TimerRequest`` on
      the shared ``mp_queue`` for the server to consume.
      """

      def __init__(self, mp_queue):
          super().__init__()
          self._mp_queue = mp_queue

      def acquire(self, scope_id, expiration_time):
          """Enqueue a request arming a timer for ``scope_id`` in this process."""
          self._mp_queue.put(TimerRequest(os.getpid(), scope_id, expiration_time))

      def release(self, scope_id):
          """Enqueue a release of the timer for ``scope_id`` in this process.

          The release is encoded as an expiration time of ``-1``; the server
          treats any negative expiration time as a release, not an acquire.
          """
          self._mp_queue.put(TimerRequest(os.getpid(), scope_id, -1))
  class MultiprocessingRequestQueue(RequestQueue):
      """
      A ``RequestQueue`` backed by python ``multiprocessing.Queue``.
      """

      def __init__(self, mp_queue: mp.Queue):
          super().__init__()
          self._mp_queue = mp_queue

      def size(self) -> int:
          """Return the number of queued requests.

          Delegates to ``multiprocessing.Queue.qsize``, whose result is only
          approximate and which may raise ``NotImplementedError`` on platforms
          lacking ``sem_getvalue()`` (e.g. macOS).
          """
          return self._mp_queue.qsize()

      def get(self, size, timeout: float) -> List[TimerRequest]:
          """Drain up to ``size`` requests within an overall ``timeout``.

          Args:
              size: maximum number of requests to return.
              timeout: total time budget in seconds, shared by all blocking
                  ``get`` calls (not a per-request timeout).

          Returns:
              The requests received, in queue order. The list may hold fewer
              than ``size`` requests, or none, if the queue runs dry or the
              budget runs out first.
          """
          requests: List[TimerRequest] = []
          # A monotonic clock keeps wall-clock adjustments (NTP, manual changes)
          # from stretching or shrinking the remaining budget.
          deadline = time.monotonic() + timeout
          wait = timeout
          for _ in range(size):
              try:
                  r = self._mp_queue.get(block=True, timeout=wait)
              except Empty:
                  break
              requests.append(r)
              wait = deadline - time.monotonic()
              if wait <= 0:
                  break
          return requests
  class LocalTimerServer(TimerServer):
      """
      Server that works with ``LocalTimerClient``. Clients are expected to be
      subprocesses to the parent process that is running this server. Each host
      in the job is expected to start its own timer server locally and each
      server instance manages timers for local workers (running on processes
      on the same host).
      """

      def __init__(
          self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True
      ):
          super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon)
          # Active timers keyed by (worker pid, scope_id). Each worker holds at
          # most one timer per scope, so a repeated acquire overwrites the
          # earlier one.
          self._timers: Dict[Tuple[Any, str], TimerRequest] = {}

      def register_timers(self, timer_requests: List[TimerRequest]) -> None:
          """Apply a batch of acquire/release requests to the active timer table.

          A request with a negative ``expiration_time`` is a release: the
          matching timer is dropped if present. Any other request arms, or
          re-arms, the timer for its ``(worker_id, scope_id)``.
          """
          for request in timer_requests:
              key = (request.worker_id, request.scope_id)
              # negative expiration is a proxy for a release call
              if request.expiration_time < 0:
                  self._timers.pop(key, None)
              else:
                  self._timers[key] = request

      def clear_timers(self, worker_ids: Set[int]) -> None:
          """Drop every active timer owned by any pid in ``worker_ids``."""
          # Collect the keys first; the dict must not change size while it is
          # being iterated.
          stale_keys = [key for key in self._timers if key[0] in worker_ids]
          for key in stale_keys:
              del self._timers[key]

      def get_expired_timers(self, deadline: float) -> Dict[Any, List[TimerRequest]]:
          """Return the timers whose expiration time is at or before ``deadline``.

          The result maps each worker pid to the list of that worker's expired
          requests. Workers with no expired timers are left out. This method
          does not remove any timers.
          """
          # pid -> [timer_requests...]
          expired_timers: Dict[Any, List[TimerRequest]] = {}
          for request in self._timers.values():
              if request.expiration_time <= deadline:
                  expired_timers.setdefault(request.worker_id, []).append(request)
          return expired_timers

      def _reap_worker(self, worker_id: int) -> bool:
          """Send SIGKILL to the worker process ``worker_id``.

          Returns:
              ``True`` if the signal was sent or the process no longer exists.
              ``False`` if the kill failed for any other reason; that error is
              logged.
          """
          try:
              os.kill(worker_id, signal.SIGKILL)
              return True
          except ProcessLookupError:
              # The process is already gone, which is the outcome we wanted.
              # Lazy %-args: the message is formatted only if the record is emitted.
              log.info("Process with pid=%s does not exist. Skipping", worker_id)
              return True
          except Exception as e:
              log.error("Error terminating pid=%s", worker_id, exc_info=e)
              return False