pipeline.py

# -*- coding: utf-8 -*-
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""The pipeline parallelism of Pipe."""
from queue import Queue
from types import TracebackType
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast, Sequence

import torch
from torch import Tensor, nn
from torch.autograd.profiler import record_function

from .checkpoint import Checkpointing
from .copy import Copy, Wait
from .dependency import fork, join
from .microbatch import Batch
from .skip.layout import SkipLayout
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .stream import AbstractStream, current_stream, use_device
from .worker import Task, create_workers

__all__: List[str] = ["Pipeline"]


Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]

# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if TYPE_CHECKING:
    InQueue = Queue[Optional["Task"]]
    OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
else:
    InQueue = Queue
    OutQueue = Queue


def _depend(fork_from: Batch, join_to: Batch) -> None:
    fork_from_idx = fork_from.find_tensor_idx()
    join_to_idx = join_to.find_tensor_idx()

    fork_from[fork_from_idx], phony = fork(fork_from[fork_from_idx])
    join_to[join_to_idx] = join(join_to[join_to_idx], phony)
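

# An illustrative note (added; not part of the original module): `fork` returns
# the tensor plus an empty "phony" tensor attached to its autograd graph, and
# `join` splices that phony into another tensor's graph. A minimal sketch of
# the resulting ordering, assuming two float tensors `a` and `b` that both
# require grad:
#
#     a, phony = fork(a)   # a's graph now also produces `phony`
#     b = join(b, phony)   # b's graph now consumes `phony`
#
# During backpropagation, gradients must flow through `b` before they reach
# `a`, which is how _depend() orders the backward passes of consecutive
# micro-batches.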


def _copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
    batch[:] = Copy.apply(prev_stream, next_stream, *batch)
    # Gradients are only supported for float Tensors.
    batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch])


def _wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
    batch[:] = Wait.apply(prev_stream, next_stream, *batch)
    # Gradients are only supported for float Tensors.
    batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch])


def _clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]:
    """Generates schedules for each clock cycle."""
    # m: number of micro-batches
    # n: number of partitions
    # i: index of micro-batch
    # j: index of partition
    # k: clock number
    #
    # k (i,j) (i,j) (i,j)
    # - ----- ----- -----
    # 0 (0,0)
    # 1 (1,0) (0,1)
    # 2 (2,0) (1,1) (0,2)
    # 3       (2,1) (1,2)
    # 4             (2,2)
    for k in range(m + n - 1):
        yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]
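

# A quick illustration (added; not part of the original module): for m=3
# micro-batches over n=3 partitions, the generator yields one list of
# (micro-batch, partition) pairs per clock tick, matching the table above:
#
#     >>> list(_clock_cycles(3, 3))
#     [[(0, 0)],
#      [(1, 0), (0, 1)],
#      [(2, 0), (1, 1), (0, 2)],
#      [(2, 1), (1, 2)],
#      [(2, 2)]]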


class Pipeline:
    """The pipeline parallelism for Pipe."""

    def __init__(
        self,
        partitions: List[nn.Sequential],
        devices: List[torch.device],
        copy_streams: List[List[AbstractStream]],
        skip_layout: SkipLayout,
        checkpoint_stop: int,
    ) -> None:
        self.partitions = partitions
        self.devices = devices
        self.copy_streams = copy_streams
        self.skip_layout = skip_layout
        self.checkpoint_stop = checkpoint_stop
        (self.in_queues, self.out_queues) = create_workers(devices)
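
    # Shape notes (comment added for orientation; not in the original):
    # partitions[j] runs on devices[j]; copy_streams[j][i] is the stream used
    # to move micro-batch i onto device j; checkpoint_stop is the number of
    # leading micro-batches to checkpoint; create_workers spawns one worker
    # thread per device, reached through the paired (in_queues, out_queues).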

    def run(self, batches: List[Batch]) -> None:
        """Runs pipeline parallelism.

        It modifies the given batches in place.
        """
        partitions = self.partitions
        devices = self.devices
        skip_layout = self.skip_layout

        m = len(batches)
        n = len(partitions)

        skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches]

        for schedule in _clock_cycles(m, n):
            self.fence(batches, schedule, skip_trackers)
            self.compute(batches, schedule, skip_trackers)

    def fence(
        self,
        batches: List[Batch],
        schedule: List[Tuple[int, int]],
        skip_trackers: List[SkipTrackerThroughPotals],
    ) -> None:
        """Copies micro-batches after computation of the previous
        micro-batches.
        """
        copy_streams = self.copy_streams
        skip_layout = self.skip_layout

        for i, j in schedule:
            # Ensure that batches[i-1] is executed after batches[i] in
            # backpropagation by an explicit dependency.
            if i != 0 and j != 0:
                _depend(batches[i - 1], batches[i])

            next_stream = copy_streams[j][i]

            for prev_j, ns, name in skip_layout.copy_policy(j):
                prev_stream = copy_streams[prev_j][i]
                skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name)

            if j != 0:
                prev_stream = copy_streams[j - 1][i]
                _copy(batches[i], prev_stream, next_stream)
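
    # Note (comment added; not in the original): skip tensors stashed on an
    # earlier partition prev_j are copied to device j via
    # skip_layout.copy_policy(j) before the main batch copy; both run on the
    # per-device copy streams rather than the compute streams, so these
    # transfers can overlap with computation.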

    def compute(
        self,
        batches: List[Batch],
        schedule: List[Tuple[int, int]],
        skip_trackers: List[SkipTrackerThroughPotals],
    ) -> None:
        """Runs tasks with synchronization to copy streams."""
        partitions = self.partitions
        devices = self.devices
        copy_streams = self.copy_streams
        checkpoint_stop = self.checkpoint_stop

        # Disable checkpointing if in eval mode.
        if not self.partitions[0].training:
            checkpoint_stop = 0

        n = len(partitions)
        streams = [current_stream(d) for d in devices]
        exc_info: Optional[ExcInfo] = None

        # With checkpointing, the autograd graph looks like this diagram:
        # ┌─────┸──────┐
        # │    Copy    │
        # └─────┰──────┘   (fence)
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        #       ┃          (compute)
        # ┌─────┸──────┐
        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │ Checkpoint │ [2] Compute a partition within checkpointing.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
        # └─────┰──────┘
        #       ┠ ─ ─ ─ ┐
        #       ┃       ┌─────┴─────┐
        #       ┃       │ Recompute │ [4] Schedule the recomputation at backpropagation.
        #       ┃       └─────┬─────┘
        #       ┠ ─ ─ ─ ┘
        #       ┃
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        #       ┃          (fence)
        # ┌─────┸──────┐
        # │    Copy    │
        # └─────┰──────┘

        for i, j in schedule:
            batch = batches[i]
            partition = partitions[j]

            # Synchronize with the copied input. ([1] in the diagram)
            if j != 0:
                _wait(batch, copy_streams[j][i], streams[j])

            # Determine whether to checkpoint this micro-batch.
            checkpoint = i < checkpoint_stop
            if checkpoint:

                def function(
                    *inputs,
                    partition: nn.Module = partition,
                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                    chunk_id: int = i,
                    part_id: int = j,
                ) -> TensorOrTensors:
                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                        return partition(*inputs)

                chk = Checkpointing(function, batch)  # type: ignore[arg-type]
                task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
                del function, chk

            else:

                def compute(
                    batch: Batch = batch,
                    partition: nn.Module = partition,
                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                    chunk_id: int = i,
                    part_id: int = j,
                ) -> Batch:
                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                        return batch.call(partition)

                task = Task(streams[j], compute=compute, finalize=None)
                del compute

            # Compute tasks in parallel. ([2] in the diagram)
            self.in_queues[j].put(task)

        for i, j in schedule:
            ok, payload = self.out_queues[j].get()

            # Hold the first exception.
            if exc_info is not None:
                continue
            elif not ok:
                exc_info = cast(ExcInfo, payload)
                continue

            task, batch = cast(Tuple[Task, Batch], payload)

            # The copy stream synchronizes to copy the output. ([3] in the
            # diagram)
            if j != n - 1:
                _wait(batch, streams[j], copy_streams[j][i])

            # Finalize tasks. If checkpointing is enabled, here the
            # recomputation is scheduled at backpropagation. ([4] in the
            # diagram)
            with use_device(devices[j]):
                task.finalize(batch)

            batches[i] = batch

        # Fail at the first exception.
        if exc_info is not None:
            raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
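

# End-to-end usage sketch (added; illustrative, not part of the original
# module). `Pipe` is the intended entry point; driving Pipeline directly looks
# roughly like this, where `microbatch.scatter`/`gather` and
# `inspect_skip_layout` come from sibling modules and the chunk count of 4 is
# an arbitrary assumption:
#
#     from . import microbatch
#     from .skip.layout import inspect_skip_layout
#
#     batches = microbatch.scatter(input, chunks=4)   # List[Batch]
#     pipeline = Pipeline(partitions, devices, copy_streams,
#                         inspect_skip_layout(partitions), checkpoint_stop=0)
#     pipeline.run(batches)                           # modifies batches in place
#     output = microbatch.gather(batches)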