# -*- coding: utf-8 -*-
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""The pipeline parallelism of Pipe."""
from queue import Queue
from types import TracebackType
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple, Type, Union, cast

import torch
from torch import Tensor, nn
from torch.autograd.profiler import record_function

from .checkpoint import Checkpointing
from .copy import Copy, Wait
from .dependency import fork, join
from .microbatch import Batch
from .skip.layout import SkipLayout
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .stream import AbstractStream, current_stream, use_device
from .worker import Task, create_workers

__all__: List[str] = ["Pipeline"]


Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]

# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if TYPE_CHECKING:
    InQueue = Queue[Optional["Task"]]
    OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
else:
    InQueue = Queue
    OutQueue = Queue


def _depend(fork_from: Batch, join_to: Batch) -> None:
    fork_from_idx = fork_from.find_tensor_idx()
    join_to_idx = join_to.find_tensor_idx()

    fork_from[fork_from_idx], phony = fork(fork_from[fork_from_idx])
    join_to[join_to_idx] = join(join_to[join_to_idx], phony)
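
# A note on the mechanism (inferred from how ``fork``/``join`` are used above,
# not new behavior): ``fork`` returns the tensor plus an empty "phony" tensor,
# and ``join`` ties that phony into another tensor's autograd graph. The phony
# carries no data; it only adds an autograd edge so that, in backpropagation,
# ``fork_from``'s gradient is computed only after ``join_to``'s.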


def _copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
    batch[:] = Copy.apply(prev_stream, next_stream, *batch)
    # Gradients are only supported for float Tensors.
    batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch])


def _wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
    batch[:] = Wait.apply(prev_stream, next_stream, *batch)
    # Gradients are only supported for float Tensors.
    batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch])
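
# ``Copy`` and ``Wait`` are autograd functions (note the ``.apply`` calls):
# ``Copy`` moves the batch's tensors from ``prev_stream`` to ``next_stream``
# (a device-to-device transfer when the streams belong to different devices),
# while ``Wait`` moves nothing and only synchronizes the two streams. Being
# autograd functions, they enforce the same ordering during backpropagation.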


def _clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]:
    """Generates schedules for each clock cycle."""
    # m: number of micro-batches
    # n: number of partitions
    # i: index of micro-batch
    # j: index of partition
    # k: clock number
    #
    # k (i,j) (i,j) (i,j)
    # - ----- ----- -----
    # 0 (0,0)
    # 1 (1,0) (0,1)
    # 2 (2,0) (1,1) (0,2)
    # 3       (2,1) (1,2)
    # 4             (2,2)
    for k in range(m + n - 1):
        yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]
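
# For example, with m=3 micro-batches and n=3 partitions (the table above):
#
#   >>> list(_clock_cycles(3, 3))
#   [[(0, 0)],
#    [(1, 0), (0, 1)],
#    [(2, 0), (1, 1), (0, 2)],
#    [(2, 1), (1, 2)],
#    [(2, 2)]]
#
# Every (i, j) pair within one clock cycle can run concurrently, since each
# micro-batch i sits on a distinct partition j.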


class Pipeline:
    """The pipeline parallelism for Pipe."""

    def __init__(
        self,
        partitions: List[nn.Sequential],
        devices: List[torch.device],
        copy_streams: List[List[AbstractStream]],
        skip_layout: SkipLayout,
        checkpoint_stop: int,
    ) -> None:
        self.partitions = partitions
        self.devices = devices
        self.copy_streams = copy_streams
        self.skip_layout = skip_layout
        self.checkpoint_stop = checkpoint_stop
        (self.in_queues, self.out_queues) = create_workers(devices)

    def run(self, batches: List[Batch]) -> None:
        """Runs pipeline parallelism.

        It modifies the given batches in place.
        """
        partitions = self.partitions
        devices = self.devices
        skip_layout = self.skip_layout

        m = len(batches)
        n = len(partitions)
        skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches]

        for schedule in _clock_cycles(m, n):
            self.fence(batches, schedule, skip_trackers)
            self.compute(batches, schedule, skip_trackers)
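
    # For instance, with two micro-batches on two partitions, ``run`` walks
    # through three clock cycles: [(0, 0)], then [(1, 0), (0, 1)], then
    # [(1, 1)]. Each cycle first copies inputs across devices (``fence``) and
    # then launches the partition computations (``compute``).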

    def fence(
        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
    ) -> None:
        """Copies micro-batches to the devices of their next partitions,
        after the previous partitions have finished computing them.
        """
        copy_streams = self.copy_streams
        skip_layout = self.skip_layout

        for i, j in schedule:
            # Ensure that batches[i-1] is executed after batches[i] in
            # backpropagation by an explicit dependency.
            if i != 0 and j != 0:
                _depend(batches[i - 1], batches[i])

            next_stream = copy_streams[j][i]

            for prev_j, ns, name in skip_layout.copy_policy(j):
                prev_stream = copy_streams[prev_j][i]
                skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name)

            if j != 0:
                prev_stream = copy_streams[j - 1][i]
                _copy(batches[i], prev_stream, next_stream)

    def compute(
        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
    ) -> None:
        """Runs tasks with synchronization to copy streams."""
        partitions = self.partitions
        devices = self.devices
        copy_streams = self.copy_streams
        checkpoint_stop = self.checkpoint_stop

        # Disable checkpointing if in eval mode.
        if not self.partitions[0].training:
            checkpoint_stop = 0

        n = len(partitions)
        streams = [current_stream(d) for d in devices]
        exc_info: Optional[ExcInfo] = None

        # With checkpointing, the autograd graph looks like this diagram:
        # ┌─────┸──────┐
        # │    Copy    │
        # └─────┰──────┘   (fence)
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        #       ┃          (compute)
        # ┌─────┸──────┐
        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │ Checkpoint │ [2] Compute a partition within checkpointing.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
        # └─────┰──────┘
        #       ┠ ─ ─ ─ ┐
        #       ┃ ┌─────┴─────┐
        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
        #       ┃ └─────┬─────┘
        #       ┠ ─ ─ ─ ┘
        #       ┃
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        # ┌─────┸──────┐   (fence)
        # │    Copy    │
        # └─────┰──────┘
        for i, j in schedule:
            batch = batches[i]
            partition = partitions[j]

            # Synchronize with the copied input. ([1] in the diagram)
            if j != 0:
                _wait(batch, copy_streams[j][i], streams[j])

            # Determine whether to checkpoint this micro-batch.
            checkpoint = i < checkpoint_stop
            if checkpoint:
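                # The keyword-only defaults below (and on ``compute`` in the
                # ``else`` branch) deliberately bind the current loop values
                # of ``partition``, ``i``, and ``j``; a plain closure would
                # observe the variables' final values by the time the worker
                # thread actually runs the task.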

                def function(
                    *inputs,
                    partition: nn.Module = partition,
                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                    chunk_id: int = i,
                    part_id: int = j,
                ) -> TensorOrTensors:
                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                        return partition(*inputs)

                chk = Checkpointing(function, batch)  # type: ignore[arg-type]
                task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
                del function, chk

            else:

                def compute(
                    batch: Batch = batch,
                    partition: nn.Module = partition,
                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                    chunk_id: int = i,
                    part_id: int = j,
                ) -> Batch:
                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                        return batch.call(partition)

                task = Task(streams[j], compute=compute, finalize=None)
                del compute

            # Compute tasks in parallel. ([2] in the diagram)
            self.in_queues[j].put(task)

        for i, j in schedule:
            ok, payload = self.out_queues[j].get()

            # Hold the first exception.
            if exc_info is not None:
                continue
            elif not ok:
                exc_info = cast(ExcInfo, payload)
                continue

            task, batch = cast(Tuple[Task, Batch], payload)

            # The copy stream synchronizes to copy the output. ([3] in the
            # diagram)
            if j != n - 1:
                _wait(batch, streams[j], copy_streams[j][i])

            # Finalize tasks. If checkpointing is enabled, here the
            # recomputation is scheduled at backpropagation. ([4] in the
            # diagram)
            with use_device(devices[j]):
                task.finalize(batch)

            batches[i] = batch

        # Fail at the first exception.
        if exc_info is not None:
            raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
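

# A minimal driver sketch (illustration only, not part of this module's API;
# ``new_stream`` comes from the sibling ``.stream`` module and
# ``inspect_skip_layout`` from ``.skip.layout``, while the construction of
# ``batches`` from a mini-batch is elided because the ``.microbatch`` helpers
# vary across versions of this code):
#
#   partitions = [segment.to(device) for segment, device in zip(segments, devices)]
#   copy_streams = [[new_stream(d) for _ in range(num_microbatches)]
#                   for d in devices]  # indexed as copy_streams[partition][chunk]
#   skip_layout = inspect_skip_layout(partitions)
#   pipeline = Pipeline(partitions, devices, copy_streams, skip_layout,
#                       checkpoint_stop=num_microbatches)  # checkpoint every chunk
#   pipeline.run(batches)  # batches: List[Batch], modified in place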
|