- # Copyright 2019 Kakao Brain
- #
- # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
- #
- # This source code is licensed under the BSD license found in the
- # LICENSE file in the root directory of this source tree.
- """Checkpointing with preceding recomputation.
- PyTorch already provides the official checkpointing utilities in
- :mod:`torch.utils.checkpoint`. The official checkpointing combines
- recomputation and recursive backpropagation into one autograd function named
- ``CheckpointFunction``. Hence, recomputation can start only when the
- gradients arrive at the function. In Pipe, recomputation needs to precede
- the gradient arrival to minimize GPU idle time.
- We solve this problem by introducing separate autograd functions named
- :class:`Recompute` and :class:`Checkpoint`, which represent recomputation and
- recursive backpropagation, respectively. With this pair of functions we can
- manipulate the control flow with respect to both the autograd engine and CUDA.
- Specifically, we place CUDA stream synchronization between :class:`Recompute`
- and :class:`Checkpoint` to delay only :class:`Checkpoint` until the gradient is
- copied entirely.
- """
- from collections import deque
- from contextlib import contextmanager
- import threading
- from typing import (
- Any,
- Deque,
- Generator,
- List,
- Optional,
- Protocol,
- Union,
- Sequence,
- Tuple
- )
- import torch
- from torch import Tensor
- import torch.autograd
- from .dependency import fork, join
- from .microbatch import Batch
- from .phony import get_phony
- __all__ = ["Function", "checkpoint", "Checkpointing", "ThreadLocal", "enable_checkpointing",
- "enable_recomputing", "is_checkpointing", "is_recomputing", "Context", "save_rng_states",
- "restore_rng_states", "Checkpoint", "Recompute"]
- Tensors = Sequence[Tensor]
- TensorOrTensors = Union[Tensor, Tensors]
- # Types for shared memory between Checkpoint and Recompute.
- Recomputed = Tuple[TensorOrTensors, Tensors] # (output, input_leaf)
- RNGStates = Tuple[Tensor, Optional[Tensor]] # (cpu_rng_state, gpu_rng_state)
- # Protocol with __call__ instead of Callable can be used as an attribute type.
- # See: https://github.com/python/mypy/issues/708#issuecomment-561735949
- class Function(Protocol):
- def __call__(self, input: TensorOrTensors) -> TensorOrTensors:
- ...
- def checkpoint(function: Function, input):
- """Makes a checkpoint with a simple interface like
- :func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
- :class:`Checkpoint` and :class:`Recompute` without boilerplate.
- """
- batch = Batch(input)
- chk = Checkpointing(function, batch)
- batch = chk.checkpoint()
- chk.recompute(batch)
- return batch.values
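- # Example usage of the debugging helper above (a minimal sketch; the Linear
- # module and the tensor shapes are arbitrary):
- #
- #   >>> layer = torch.nn.Linear(4, 4)
- #   >>> x = torch.randn(2, 4, requires_grad=True)
- #   >>> y = checkpoint(layer, x)
- #   >>> y.sum().backward()
- #
- # During the backward pass, Recompute.backward re-runs `layer` first and
- # Checkpoint.backward then backpropagates through the recomputed graph.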
- class Checkpointing:
- """Generates a pair of :class:`Checkpoint` and :class:`Recompute`."""
- def __init__(self, function: Function, batch: Batch) -> None:
- self.function = function
- self.batch = batch
- # Shared memory between Checkpoint and Recompute. 1-length deque is
- # used for mutability and length limitation.
- self.recomputed: Deque[Recomputed] = deque(maxlen=1)
- self.rng_states: Deque[RNGStates] = deque(maxlen=1)
- def checkpoint(self) -> Batch:
- """Returns a batch applied by :class:`Checkpoint`."""
- input_atomic = self.batch.atomic
- inputs = tuple(self.batch)
- # Use a phony which requires grad to ensure that Checkpoint can be
- # tracked by the autograd engine even when none of the input tensors
- # require grad.
- phony = get_phony(self.batch.get_device(), requires_grad=True)
- output = Checkpoint.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *inputs)
- # Gradients are only supported for float Tensors.
- if isinstance(output, tuple):
- output = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in output])
- return Batch(output)
- def recompute(self, batch: Batch) -> None:
- """Applies :class:`Recompute` to the batch in place."""
- input_atomic = self.batch.atomic
- inputs = tuple(self.batch)
- # Use a tensor in the batch to tie the fork-join pair together.
- tensor_idx = batch.find_tensor_idx()
- # batch[tensor_idx] always requires grad because it has passed through
- # Checkpoint, which received a phony requiring grad.
- batch[tensor_idx], phony = fork(batch[tensor_idx])
- phony = Recompute.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *inputs)
- batch[tensor_idx] = join(batch[tensor_idx], phony)
- class ThreadLocal(threading.local):
- def __init__(self) -> None:
- self.is_checkpointing = False
- self.is_recomputing = False
- thread_local = ThreadLocal()
- @contextmanager
- def enable_checkpointing() -> Generator[None, None, None]:
- """Makes :func:`is_checkpointing` return :data:`True` within a context."""
- orig = thread_local.is_checkpointing
- thread_local.is_checkpointing = True
- try:
- yield
- finally:
- thread_local.is_checkpointing = orig
- @contextmanager
- def enable_recomputing() -> Generator[None, None, None]:
- """Makes :func:`is_recomputing` return :data:`True` within a context."""
- orig = thread_local.is_recomputing
- thread_local.is_recomputing = True
- try:
- yield
- finally:
- thread_local.is_recomputing = orig
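- # A minimal sketch: the two context managers above simply toggle the
- # thread-local flags queried by is_checkpointing() and is_recomputing() below.
- #
- #   >>> is_recomputing()
- #   False
- #   >>> with enable_recomputing():
- #   ...     is_recomputing()
- #   True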
- def is_checkpointing() -> bool:
- """Whether the current forward propagation is under checkpointing.
- Returns:
- bool: :data:`True` if it's under checkpointing.
- """
- return thread_local.is_checkpointing
- def is_recomputing() -> bool:
- """Whether the current forward propagation is under checkpoint
- recomputation. Use this to prevent duplicated side effects during forward
- propagation::
- class Counter(nn.Module):
- def __init__(self):
- super().__init__()
- self.counter = 0
- def forward(self, input):
- if not is_recomputing():
- self.counter += 1
- return input
- Returns:
- bool: :data:`True` if it's under checkpoint recomputation.
- .. seealso:: :ref:`Detecting Recomputation`
- """
- return thread_local.is_recomputing
- class Context:
- """The common interface between the :class:`Checkpoint` and
- :class:`Recompute` context.
- """
- recomputed: Deque[Recomputed]
- rng_states: Deque[RNGStates]
- function: Function
- input_atomic: bool
- inputs: Sequence[Any]
- saved_tensors: Tuple[Tensor, ...]
- def save_for_backward(self, *tensors: Tensor) -> None: # pragma: no cover
- pass
- def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None:
- """:meth:`Checkpoint.forward` captures the current PyTorch's random number
- generator states at CPU and GPU to reuse in :meth:`Recompute.backward`.
- .. seealso:: :ref:`Referential Transparency`
- """
- cpu_rng_state = torch.get_rng_state()
- gpu_rng_state: Optional[Tensor]
- if device.type == "cuda":
- gpu_rng_state = torch.cuda.get_rng_state(device)
- else:
- gpu_rng_state = None
- rng_states.append((cpu_rng_state, gpu_rng_state))
- @contextmanager
- def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]:
- """:meth:`Recompute.backward` restores the random number generator states
- captured by :func:`save_rng_states` within its context.
- .. seealso:: :ref:`Referential Transparency`
- """
- cpu_rng_state, gpu_rng_state = rng_states.pop()
- gpu_devices: List[torch.device] = []
- if device.type == "cuda":
- gpu_devices.append(device)
- with torch.random.fork_rng(gpu_devices):
- torch.set_rng_state(cpu_rng_state)
- if gpu_rng_state is not None:
- torch.cuda.set_rng_state(gpu_rng_state, device)
- yield
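- # A minimal sketch of how the two helpers above pair up (mirroring how
- # Checkpoint.forward and Recompute.backward use them below):
- #
- #   >>> rng_states: Deque[RNGStates] = deque(maxlen=1)
- #   >>> device = torch.device("cpu")
- #   >>> save_rng_states(device, rng_states)    # capture the current RNG state
- #   >>> a = torch.rand(1)
- #   >>> with restore_rng_states(device, rng_states):
- #   ...     b = torch.rand(1)                  # replays the captured RNG state
- #   >>> torch.equal(a, b)
- #   True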
- class Checkpoint(torch.autograd.Function):
- @staticmethod
- # type: ignore[override]
- def forward(
- ctx: Context,
- phony: Tensor,
- recomputed: Deque[Recomputed],
- rng_states: Deque[RNGStates],
- function: Function,
- input_atomic: bool,
- *inputs,
- ):
- ctx.recomputed = recomputed
- ctx.rng_states = rng_states
- save_rng_states(phony.device, ctx.rng_states)
- ctx.function = function
- ctx.input_atomic = input_atomic
- if input_atomic:
- tensors = [inputs[0]]
- else:
- tensors = []
- for input in inputs:
- if torch.is_tensor(input):
- tensors.append(input)
- ctx.save_for_backward(*tensors)
- with torch.no_grad(), enable_checkpointing():
- if input_atomic:
- assert len(inputs) == 1
- output = function(inputs[0])
- else:
- output = function(*inputs)
- return output
- @staticmethod
- def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: # pragma: no cover
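- # ctx.recomputed was filled by Recompute.backward, which the fork/join
- # dependency (see Checkpointing.recompute) forces to run before this point.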
- output, input_leaf = ctx.recomputed.pop()
- if isinstance(output, tuple):
- outputs = output
- else:
- outputs = (output,)
- if any(torch.is_tensor(y) and y.requires_grad for y in outputs):
- tensors = tuple([x for x in outputs if torch.is_tensor(x) and x.requires_grad])
- torch.autograd.backward(tensors, grad_output)
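- # The five leading Nones are the gradients for the non-input forward
- # arguments: phony, recomputed, rng_states, function and input_atomic.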
- grad_input: List[Optional[Tensor]] = [None, None, None, None, None]
- grad_input.extend(x.grad if torch.is_tensor(x) else None for x in input_leaf)
- return tuple(grad_input)
- class Recompute(torch.autograd.Function):
- @staticmethod
- # type: ignore[override]
- def forward(
- ctx: Context,
- phony: Tensor,
- recomputed: Deque[Recomputed],
- rng_states: Deque[RNGStates],
- function: Function,
- input_atomic: bool,
- *inputs,
- ) -> Tensor:
- ctx.recomputed = recomputed
- ctx.rng_states = rng_states
- ctx.function = function
- ctx.input_atomic = input_atomic
- ctx.inputs = inputs
- if input_atomic:
- tensors = [inputs[0]]
- else:
- tensors = []
- for input in inputs:
- if torch.is_tensor(input):
- tensors.append(input)
- ctx.save_for_backward(*tensors)
- return phony
- @staticmethod
- def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]: # pragma: no cover
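- # This backward only re-runs the forward computation and stashes the result;
- # the actual gradients are produced later by Checkpoint.backward.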
- inputs = ctx.inputs
- inputs_leaf = tuple(x.detach().requires_grad_(x.requires_grad) if torch.is_tensor(x) else x for x in inputs)
- # Determine the device of the inputs from the first tensor found.
- device = None
- for input in inputs:
- if torch.is_tensor(input):
- device = input.device
- break
- if device is None:
- raise RuntimeError(f'No tensors found in {inputs}')
- with restore_rng_states(device, ctx.rng_states):
- with torch.enable_grad(), enable_recomputing():
- if ctx.input_atomic:
- assert len(inputs_leaf) == 1
- output = ctx.function(inputs_leaf[0])
- else:
- output = ctx.function(*inputs_leaf)
- ctx.recomputed.append((output, inputs_leaf))
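- # No gradients flow through Recompute itself: return None for the five
- # leading arguments and for every original input.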
- grad_input: List[None] = [None, None, None, None, None]
- grad_input.extend(None for _ in ctx.inputs)
- return tuple(grad_input)
|