_limiter_utils.py 1.1 KB

123456789101112131415161718192021222324252627282930313233
  1. import collections
  2. from typing import Deque, Optional
  3. import torch
  4. class _FreeEventQueue:
  5. """
  6. This tracks all pending frees corresponding to inflight all-gathers. The
  7. queueing pattern is iterative enqueues with a single dequeue per iteration
  8. once the limit ``_max_num_inflight_all_gathers`` is reached.
  9. """
  10. def __init__(self) -> None:
  11. self._queue: Deque[torch.cuda.Event] = collections.deque()
  12. self._max_num_inflight_all_gathers = 2 # empirically chosen
  13. def enqueue(self, free_event: torch.cuda.Event) -> None:
  14. """Enqueues a free event."""
  15. self._queue.append(free_event)
  16. def dequeue_if_needed(self) -> Optional[torch.cuda.Event]:
  17. """Dequeues a single event if the limit is reached."""
  18. if len(self._queue) >= self._max_num_inflight_all_gathers:
  19. return self._dequeue()
  20. return None
  21. def _dequeue(self) -> Optional[torch.cuda.Event]:
  22. """Dequeues a free event if possible."""
  23. if self._queue:
  24. event = self._queue.popleft()
  25. return event
  26. return None