from contextlib import contextmanager
from functools import partial
from typing import Any, List, Optional, Tuple
from weakref import ref, ReferenceType, WeakKeyDictionary

import torch
import torch.nn as nn
from torch.utils.checkpoint import detach_variable, get_device_states, set_device_states

from .contract import contract


@contextmanager
def _no_hook(module: nn.Module):
    r"""
    Disable hooks installed by checkpoint to avoid unintentional recursion
    during backward recomputation.
    """
    orig_enable_hook = checkpoint.state(module).enable_hook
    checkpoint.state(module).enable_hook = False
    try:
        yield
    except Exception:
        raise
    finally:
        checkpoint.state(module).enable_hook = orig_enable_hook
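

# NOTE: the hooks must stay disabled while the checkpointed module is re-run:
# both _ModuleHookCheckpointFunction.backward and _unpack below call the
# module's forward directly during recomputation, and without _no_hook that
# call would trigger forward_pre_hook/forward_hook again and re-enter the
# checkpointing logic.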


class _ModuleHookCheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, module: nn.Module, output: Any, *inputs: Any) -> Any:  # type: ignore[override]
        ctx.module = module
        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, inp in enumerate(inputs):
            if torch.is_tensor(inp):
                tensor_inputs.append(inp)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(inp)

        ctx.save_for_backward(*tensor_inputs)
        return output
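
    # Note: this Function is applied from forward_hook below, after the module
    # has already run with grad disabled (see forward_pre_hook), so ``output``
    # carries no autograd history. Returning it here attaches it to the graph
    # through this Function, which is what routes upstream gradients into the
    # recomputation in backward().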

    @staticmethod
    def backward(ctx, output_grads: Tuple[Optional[torch.Tensor]]) -> Any:  # type: ignore[override]
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an "
                "`inputs` parameter is passed to .backward(). Please use "
                ".backward() and do not pass its `inputs` argument."
            )

        # Copy the list to avoid modifying the original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with the appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        if checkpoint.state(ctx.module).had_cuda_in_fwd:
            rng_devices = checkpoint.state(ctx.module).fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(checkpoint.state(ctx.module).fwd_cpu_state)
            if checkpoint.state(ctx.module).had_cuda_in_fwd:
                set_device_states(
                    checkpoint.state(ctx.module).fwd_gpu_devices,
                    checkpoint.state(ctx.module).fwd_gpu_states,
                )
            detached_inputs = detach_variable(tuple(inputs))
            with torch.enable_grad(), _no_hook(ctx.module):
                outputs = ctx.module(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        if isinstance(output_grads, torch.Tensor):
            output_grads = (output_grads,)

        # Run backward() with only the tensors that require grad.
        outputs_requires_grad: List[torch.Tensor] = []
        output_grad_tensors: List[torch.Tensor] = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_requires_grad.append(outputs[i])
                assert (
                    output_grads[i] is not None
                ), f"expecting grad for output at index {i}, but got None."
                output_grad_tensors.append(output_grads[i])  # type: ignore[arg-type]
        if len(outputs_requires_grad) == 0:
            raise RuntimeError(
                "none of the outputs has requires_grad=True, "
                "so this checkpoint() is not necessary"
            )

        torch.autograd.backward(outputs_requires_grad, output_grad_tensors)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in detached_inputs
        )
        # The two Nones are for the forward arguments `module` and `output`.
        return (None, None) + grads
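

# Everything below up to ``checkpoint`` itself implements the non-reentrant
# path (``use_reentrant=False``): instead of replaying the forward inside an
# ``autograd.Function``, tensors saved by autograd are swapped for lightweight
# ``_Holder`` placeholders via ``saved_tensors_hooks`` and the real values are
# recomputed (once) when the backward pass first asks for them.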


class _Holder:
    pass


def _pack(
    x: torch.Tensor,
    *,
    weak_holder_list: List[ReferenceType],
) -> _Holder:
    res = _Holder()
    weak_holder_list.append(ref(res))
    return res


def _unpack(
    holder: _Holder,
    *,
    storage: WeakKeyDictionary,
    weak_holder_list: List[ReferenceType],
    module: nn.Module,
    inputs: Tuple[Any, ...],
) -> torch.Tensor:
    holder_index = 0
    if len(storage) == 0:

        def inner_pack(inner: torch.Tensor):
            nonlocal holder_index
            if weak_holder_list[holder_index]() is None:
                # If the holder went out of scope, the SavedVariable is dead
                # and so the value will never be read from the storage. Skip
                # filling it.
                pass
            else:
                # Use detach here to ensure we don't keep the temporary
                # autograd graph created during the second forward.
                storage[weak_holder_list[holder_index]()] = inner.detach()
            holder_index += 1
            return

        def inner_unpack(holder: _Holder):
            raise RuntimeError(
                "You are calling backward on a tensor that is never exposed. "
                "Please open an issue."
            )

        with _no_hook(
            module
        ), torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(
            inner_pack, inner_unpack
        ):
            _unused = module(*inputs)

    if holder not in storage:
        raise RuntimeError(
            "Attempt to retrieve a tensor saved by autograd multiple times "
            "without checkpoint recomputation being triggered in between. This "
            "is not currently supported; please open an issue with details on "
            "your use case so that we can prioritize adding support."
        )

    return storage[holder]
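

# In short: ``_pack`` hands autograd a ``_Holder`` placeholder instead of the
# real tensor, and ``_unpack`` re-runs the module's forward once (under
# ``_no_hook`` so the checkpoint hooks do not re-enter themselves) to fill
# ``storage`` with detached recomputed tensors; later unpacks are plain
# dictionary lookups keyed by the holders.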


@contract()
def checkpoint(module: nn.Module, *, use_reentrant: bool = True) -> nn.Module:
    r"""
    This is a composable activation checkpointing API. Unlike functional
    activation checkpointing APIs, this one does not require changing model
    source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs,
    this one does not modify model structure or fully-qualified names either.
    Under the hood, it registers activation checkpointing logic as pre- and
    post-forward hooks. Hence, this API can be easily applied to any model or
    sub-modules in the model.

    Args:
        module (nn.Module): the target model or sub-module to apply activation
            checkpointing to.
        use_reentrant (bool): If ``True`` (default), recompute activations by
            replaying the forward inside an ``autograd.Function`` during the
            backward pass. If ``False``, recompute them lazily through
            ``saved_tensors_hooks``.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> model = MyModel()
        >>> checkpoint(model.l1)  # apply activation checkpointing only to l1
        >>> model(torch.zeros(2, 10)).sum().backward()
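        >>>
        >>> # A minimal sketch of the non-reentrant mode (same model class as
        >>> # above); activations are recomputed lazily via saved_tensors_hooks.
        >>> model2 = MyModel()
        >>> checkpoint(model2.l2, use_reentrant=False)
        >>> model2(torch.zeros(2, 10)).sum().backward()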
- """

    def forward_pre_hook(module: nn.Module, inputs: Tuple[Any, ...]) -> None:
        if checkpoint.state(module).enable_hook:
            checkpoint.state(module).orig_grad_enabled = torch.is_grad_enabled()
            if checkpoint.state(module).use_reentrant:
                torch.set_grad_enabled(False)
                checkpoint.state(module).fwd_cpu_state = torch.get_rng_state()
                # Don't eagerly initialize the cuda context by accident.
                # (If the user intends that the context is initialized later, within
                # their run_function, we SHOULD actually stash the cuda state here.
                # Unfortunately, we have no way to anticipate this will happen before
                # we run the function.)
                checkpoint.state(module).had_cuda_in_fwd = False
                if torch.cuda._initialized:
                    checkpoint.state(module).had_cuda_in_fwd = True
                    (
                        checkpoint.state(module).fwd_gpu_devices,
                        checkpoint.state(module).fwd_gpu_states,
                    ) = get_device_states(*inputs)
            else:
                # The _Holder object for each saved tensor is stored directly on
                # the SavedVariable and is cleared when reset_data() is called on
                # it. We MUST make sure that this is the only object with an owning
                # reference, so that the tensor stored in ``storage`` is deleted as
                # soon as the corresponding SavedVariable data is cleared.
                storage: WeakKeyDictionary = WeakKeyDictionary()
                weak_holder_list: List[ReferenceType] = []

                saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks(
                    partial(_pack, weak_holder_list=weak_holder_list),
                    partial(
                        _unpack,
                        storage=storage,
                        weak_holder_list=weak_holder_list,
                        module=module,
                        inputs=inputs,
                    ),
                )
                saved_tensor_hooks.__enter__()
                checkpoint.state(module).saved_tensor_hooks = saved_tensor_hooks

    def forward_hook(module: nn.Module, inputs: Tuple[Any, ...], output: Any) -> Any:
        if checkpoint.state(module).enable_hook:
            torch.set_grad_enabled(checkpoint.state(module).orig_grad_enabled)
            if checkpoint.state(module).use_reentrant:
                return _ModuleHookCheckpointFunction.apply(module, output, *inputs)
            else:
                checkpoint.state(module).saved_tensor_hooks.__exit__()
                checkpoint.state(module).saved_tensor_hooks = None

        return output

    # These hooks do the following (for the reentrant path):
    # 1. detach outputs from the autograd graph to discard activations
    # 2. insert an autograd.Function after the forward pass to recompute
    #    activations during the backward pass.
    checkpoint.state(module).enable_hook = True
    checkpoint.state(module).use_reentrant = use_reentrant
    module.register_forward_pre_hook(forward_pre_hook)
    # Use prepend to make sure we restore the original grad-enabled state right
    # after the module's forward invocation.
    module.register_forward_hook(forward_hook, prepend=True)
    return module
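

# A minimal usage sketch (hypothetical attribute names), assuming a model that
# exposes an iterable of blocks; ``checkpoint`` returns the module it was
# applied to, so it can simply be called on each sub-module in place:
#
#     for block in model.blocks:  # hypothetical attribute
#         checkpoint(block, use_reentrant=False)
#     model(batch).sum().backward()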