import torch
import warnings
import weakref
from typing import Any, Iterable, List, Tuple

__all__ = [
    "checkpoint", "checkpoint_sequential", "CheckpointFunction",
    "check_backward_validity", "detach_variable", "get_device_states",
    "set_device_states",
]


def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got unsupported input type: "
            + type(inputs).__name__)


def check_backward_validity(inputs: Iterable[Any]) -> None:
    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")


# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases.  Compromise:  Stash the RNG state for
# the device of all Tensor args.
#
# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
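#
# Roughly, the two helpers below are used together as in the sketch here (an
# illustrative example only; the tensors `a` and `b` are placeholders and at
# least two visible CUDA devices are assumed):
#
#   a = torch.randn(8, device="cuda:0")
#   b = torch.randn(8, device="cuda:1")
#   devices, states = get_device_states(a, b)  # e.g. [0, 1] and their RNG state tensors
#   torch.rand(8, device="cuda:0")             # ...something that consumes CUDA randomness...
#   set_device_states(devices, states)         # restore each device's RNG state afterwards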
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_gpu_devices = list({arg.get_device() for arg in args
                            if isinstance(arg, torch.Tensor) and arg.is_cuda})

    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())

    return fwd_gpu_devices, fwd_gpu_states


def set_device_states(devices, states) -> None:
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)


def _get_autocast_kwargs():
    gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                           "dtype": torch.get_autocast_gpu_dtype(),
                           "cache_enabled": torch.is_autocast_cache_enabled()}

    cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
                           "dtype": torch.get_autocast_cpu_dtype(),
                           "cache_enabled": torch.is_autocast_cache_enabled()}

    return gpu_autocast_kwargs, cpu_autocast_kwargs


class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument.")
        # Copy the list to avoid modifying the original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with the appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(tuple(inputs))
            with torch.enable_grad(), \
                 torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
                 torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # Run backward() with only the tensors that require grad.
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of the outputs has requires_grad=True,"
                " so this checkpoint() is not necessary")
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)

        return (None, None) + grads


def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
    r"""Checkpoint a model or part of the model

    Checkpointing works by trading compute for memory. Rather than storing all
    intermediate activations of the entire computation graph for computing
    backward, the checkpointed part does **not** save intermediate activations,
    and instead recomputes them in the backward pass. It can be applied to any
    part of a model.

    Specifically, in the forward pass, :attr:`function` will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. Instead, the forward pass saves the inputs tuple and the
    :attr:`function` parameter. In the backwards pass, the saved inputs and
    :attr:`function` are retrieved, and the forward pass is computed on
    :attr:`function` again, now tracking the intermediate activations, and then
    the gradients are calculated using these activation values.

    The output of :attr:`function` can contain non-Tensor values and gradient
    recording is only performed for the Tensor values. Note that if the output
    consists of nested structures (ex: custom objects, lists, dicts etc.)
    consisting of Tensors, these Tensors nested in custom structures will not
    be considered as part of autograd.

    .. warning::
        If the :attr:`function` invocation during backward does anything different
        than the one during forward, e.g., due to some global variable, the
        checkpointed version won't be equivalent, and unfortunately this can't be
        detected.

    .. warning::
        If ``use_reentrant=True`` is specified, then if the checkpointed segment
        contains tensors detached from the computational graph by ``detach()`` or
        ``torch.no_grad()``, the backward pass will raise an error. This is
        because ``checkpoint`` makes all the outputs require gradients, which
        causes issues when a tensor is defined to have no gradient in the model.
        To circumvent this, detach the tensors outside of the ``checkpoint``
        function. Note that the checkpointed segment can contain tensors
        detached from the computational graph if ``use_reentrant=False`` is
        specified.

    .. warning::
        If ``use_reentrant=True`` is specified, at least one of the inputs needs
        to have :code:`requires_grad=True` if grads are needed for model inputs,
        otherwise the checkpointed part of the model won't have gradients. At
        least one of the outputs needs to have :code:`requires_grad=True` as
        well. Note that this does not apply if ``use_reentrant=False`` is
        specified.

    .. warning::
        If ``use_reentrant=True`` is specified, checkpointing currently only
        supports :func:`torch.autograd.backward` and only if its ``inputs``
        argument is not passed. :func:`torch.autograd.grad`
        is not supported. If ``use_reentrant=False`` is specified, checkpointing
        will work with :func:`torch.autograd.grad`.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if the user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): if ``False``, don't stash and restore
            the RNG state during each checkpoint.
            Default: ``True``
        use_reentrant(bool, optional): use a checkpointing
            implementation that requires re-entrant autograd.
            If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
            implementation that does not require re-entrant autograd. This
            allows ``checkpoint`` to support additional functionality, such as
            working as expected with ``torch.autograd.grad`` and support for
            keyword arguments input into the checkpointed function. Note that future
            versions of PyTorch will default to ``use_reentrant=False``.
            Default: ``True``
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
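
    Example:
        A minimal usage sketch (``run_fn`` and the input tensor here are
        illustrative placeholders, not part of this module):

        >>> # xdoctest: +SKIP("stub")
        >>> import torch
        >>> import torch.utils.checkpoint as cp
        >>> def run_fn(x):
        ...     return torch.nn.functional.relu(x) * 2
        >>> inp = torch.randn(4, 4, requires_grad=True)
        >>> out = cp.checkpoint(run_fn, inp)            # forward runs without saving activations
        >>> out.sum().backward()                        # run_fn is recomputed here
        >>> # The non-reentrant variant also works with torch.autograd.grad:
        >>> out = cp.checkpoint(run_fn, inp, use_reentrant=False)
        >>> grad, = torch.autograd.grad(out.sum(), inp)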
- """
- # Hack to mix *args with **kwargs in a python 2.7-compliant way
- preserve = kwargs.pop('preserve_rng_state', True)
- if kwargs and use_reentrant:
- raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
- if use_reentrant:
- return CheckpointFunction.apply(function, preserve, *args)
- else:
- return _checkpoint_without_reentrant(
- function,
- preserve,
- *args,
- **kwargs,
- )


def checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwargs):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model into various segments
    and checkpoint each segment. All segments except the last will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. The inputs of each checkpointed segment will be saved for
    re-running the segment in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its ``inputs`` argument is not passed. :func:`torch.autograd.grad`
        is not supported.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    .. warning::
        Since PyTorch 1.4, only one Tensor is allowed as the input and
        intermediate outputs, just like in :class:`torch.nn.Sequential`.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        input: A Tensor that is input to :attr:`functions`
        preserve_rng_state(bool, optional): if ``False``, don't stash and restore
            the RNG state during each checkpoint.
            Default: ``True``
        use_reentrant(bool, optional): use a checkpointing
            implementation that requires re-entrant autograd.
            If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
            implementation that does not require re-entrant autograd. This
            allows ``checkpoint`` to support additional functionality, such as
            working as expected with ``torch.autograd.grad`` and support for
            keyword arguments input into the checkpointed function.
            Default: ``True``

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`input`

    Example:
        >>> # xdoctest: +SKIP("stub")
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input)
            return input
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # the last chunk has to be non-volatile
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        input = checkpoint(
            run_function(start, end, functions),
            input,
            use_reentrant=use_reentrant,
            preserve_rng_state=preserve
        )
    return run_function(end + 1, len(functions) - 1, functions)(input)


def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args, **kwargs):
    """Checkpointing without re-entrant autograd

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if the user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): if ``False``, don't stash and restore
            the RNG state during each checkpoint.
            Default: ``True``
        *args: Arguments to pass in to the given ``function``.
        **kwargs: Keyword arguments to pass into the given ``function``.
    """
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_cuda_in_fwd = False
        if torch.cuda._initialized:
            had_cuda_in_fwd = True
            fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)

    # Custom class to be able to take weak references
    class Holder():
        pass

    # The Holder object for each of the saved objects is saved directly on the
    # SavedVariable and is cleared when reset_data() is called on it. We MUST make
    # sure that this is the only object having an owning reference, to ensure that
    # the Tensor stored in storage is deleted as soon as the corresponding SavedVariable
    # data is cleared.
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    weak_holder_list = []

    def pack(x):
        # TODO(varal7): Instead of returning an abstract object, we could return the tensor's
        # metadata (such as size, device, ...) to catch certain cases of non-deterministic
        # behavior of the forward
        res = Holder()
        weak_holder_list.append(weakref.ref(res))
        return res

    def unpack(x):
        unpack_counter = 0
        if len(storage) == 0:

            def inner_pack(inner):
                nonlocal unpack_counter
                unpack_counter += 1
                # If the holder went out of scope, the SavedVariable is dead and so
                # the value will never be read from the storage. Skip filling it.
                if weak_holder_list[unpack_counter - 1]() is None:
                    return

                # Use detach here to ensure we don't keep the temporary autograd
                # graph created during the second forward
                storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
                return

            def inner_unpack(packed):
                raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")

            # Stash the surrounding rng state, and mimic the state that was
            # present at this time during forward. Restore the surrounding state
            # when we're done.
            rng_devices = []
            if preserve_rng_state and had_cuda_in_fwd:
                rng_devices = fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
                if preserve_rng_state:
                    torch.set_rng_state(fwd_cpu_state)
                    if had_cuda_in_fwd:
                        set_device_states(fwd_gpu_devices, fwd_gpu_states)
                with torch.enable_grad(), \
                     torch.cuda.amp.autocast(**gpu_autocast_kwargs), \
                     torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
                    _unused = function(*args, **kwargs)

        if x not in storage:
            raise RuntimeError(
                "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                " recomputation being triggered in between, this is not currently supported. Please"
                " open an issue with details on your use case so that we can prioritize adding this."
            )

        return storage[x]

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = function(*args, **kwargs)
        if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:
            # Cuda was not initialized before running the forward, so we didn't
            # stash the CUDA state.
            raise RuntimeError(
                "PyTorch's CUDA state was initialized in the forward pass "
                "of a Checkpoint, which is not allowed. Please open an issue "
                "if you need this feature.")

    return output