# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""Multithreading in pipeline parallelism."""
from contextlib import contextmanager
from queue import Queue
import sys
from threading import Thread
from types import TracebackType
from typing import TYPE_CHECKING, Callable, Dict, Generator, List, Optional, Tuple, Type, Union, cast

import torch

from .microbatch import Batch
from .stream import AbstractStream, use_device, use_stream

__all__: List[str] = ["Task", "worker", "create_workers", "spawn_workers"]


ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]

# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if TYPE_CHECKING:
    InQueue = Queue[Optional["Task"]]
    OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
else:
    InQueue = Queue
    OutQueue = Queue


class Task:
    """A task represents how to compute a micro-batch on a partition.

    It consists of two parts: :meth:`compute` and :meth:`finalize`.
    :meth:`compute` should be executed in worker threads concurrently;
    :meth:`finalize` should be executed after all worker threads have
    finished executing :meth:`compute`.

    :meth:`compute` can be sped up by worker threads because user code in it
    issues several CUDA API calls, and in PyTorch concurrent CUDA API calls
    are not serialized by the GIL, so more than one call can be in flight at
    the same time.
    """

    def __init__(
        self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]],
    ) -> None:
        self.stream = stream
        self._compute = compute
        self._finalize = finalize
        # Capture the caller's grad mode so it can be restored in the worker thread.
        self._grad_enabled = torch.is_grad_enabled()

    def compute(self) -> Batch:
        with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled):
            return self._compute()

    def finalize(self, batch: Batch) -> None:
        if self._finalize is None:
            return
        with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled):
            self._finalize(batch)
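
# A minimal usage sketch (illustrative only; ``my_stream`` and ``my_module``
# are hypothetical stand-ins for a real stream and partition):
#
#     task = Task(my_stream, compute=lambda: Batch(my_module(input)), finalize=None)
#     batch = task.compute()  # runs on ``my_stream`` with the captured grad mode
#     task.finalize(batch)    # no-op here, since ``finalize`` is None
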
def worker(in_queue: InQueue, out_queue: OutQueue, device: torch.device) -> None:
    """The main loop of a worker thread."""
    with use_device(device):
        while True:
            task = in_queue.get()

            # ``None`` is the sentinel that tells the worker to shut down.
            if task is None:
                break

            try:
                batch = task.compute()
            except Exception:
                # Report the failure to the consumer instead of dying silently.
                exc_info = cast(ExcInfo, sys.exc_info())
                out_queue.put((False, exc_info))
                continue

            out_queue.put((True, (task, batch)))

    # (False, None) signals that this worker has terminated.
    done = (False, None)
    out_queue.put(done)
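
# The consumer side of ``out_queue`` can unpack messages like this (a sketch;
# it assumes a task was already submitted through the matching ``in_queue``):
#
#     ok, payload = out_queue.get()
#     if ok:
#         task, batch = payload  # successful computation
#         task.finalize(batch)
#     elif payload is not None:
#         exc_type, exc_value, tb = payload
#         raise exc_value.with_traceback(tb)  # re-raise the worker's exception
#     else:
#         ...  # (False, None): the worker has terminated
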
def create_workers(devices: List[torch.device],) -> Tuple[List[InQueue], List[OutQueue]]:
    """Spawns worker threads. A worker thread is bound to a device."""
    in_queues: List[InQueue] = []
    out_queues: List[OutQueue] = []

    # Spawn workers.
    workers: Dict[torch.device, Tuple[InQueue, OutQueue]] = {}
    def normalize_device(device: torch.device) -> torch.device:
        if device.type == "cuda" and device.index is None:
            return torch.device("cuda", index=torch.cuda.current_device())

        if device.type == "cpu" and device.index is not None:
            return torch.device("cpu")

        return device
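
    # For example (assuming the current CUDA device is 0):
    #   torch.device("cuda")  -> torch.device("cuda", 0)
    #   torch.device("cpu:0") -> torch.device("cpu")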
    for device in devices:
        device = normalize_device(device)

        try:
            # Reuse the worker already bound to this device, if any.
            in_queue, out_queue = workers[device]
        except KeyError:
            in_queue = Queue()
            out_queue = Queue()
            workers[device] = (in_queue, out_queue)

            t = Thread(target=worker, args=(in_queue, out_queue, device), daemon=True,)
            t.start()

        in_queues.append(in_queue)
        out_queues.append(out_queue)

    return (in_queues, out_queues)
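
# Note that duplicate devices share a single worker; the returned queue lists
# are aligned with ``devices``, so (a sketch):
#
#     in_queues, out_queues = create_workers([torch.device("cpu"), torch.device("cpu")])
#     assert in_queues[0] is in_queues[1]  # both entries refer to the same worker
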
@contextmanager
def spawn_workers(devices: List[torch.device],) -> Generator[Tuple[List[InQueue], List[OutQueue]], None, None]:
    try:
        (in_queues, out_queues) = create_workers(devices)
        yield (in_queues, out_queues)
    finally:
        # Worker threads are daemonic, so no explicit shutdown is required here.
        pass
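
# An end-to-end sketch (illustrative only; ``stream`` and ``make_batch`` are
# hypothetical, standing in for a real stream and a callable returning a Batch):
#
#     with spawn_workers([torch.device("cuda:0")]) as (in_queues, out_queues):
#         in_queues[0].put(Task(stream, compute=make_batch, finalize=None))
#         ok, payload = out_queues[0].get()
#         if ok:
#             task, batch = payload
#             task.finalize(batch)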