import torch
from collections import OrderedDict
import weakref
import warnings
from typing import Any

__all__ = ["RemovableHandle", "unserializable_hook", "warn_if_has_hooks", "BackwardHook"]

class RemovableHandle:
    r"""
    A handle which provides the capability to remove a hook.

    Args:
        hooks_dict (dict): A dictionary of hooks, indexed by hook ``id``.
        extra_dict (dict): An additional dictionary whose keys will be deleted
            when the same keys are removed from ``hooks_dict``.
    """

    id: int
    next_id: int = 0

    def __init__(self, hooks_dict: Any, *, extra_dict: Any = None) -> None:
        self.hooks_dict_ref = weakref.ref(hooks_dict)
        self.id = RemovableHandle.next_id
        RemovableHandle.next_id += 1

        self.extra_dict_ref = (
            weakref.ref(extra_dict)
            if extra_dict is not None
            else None
        )
    def remove(self) -> None:
        hooks_dict = self.hooks_dict_ref()
        if hooks_dict is not None and self.id in hooks_dict:
            del hooks_dict[self.id]

        if self.extra_dict_ref is not None:
            extra_dict = self.extra_dict_ref()
            if extra_dict is not None and self.id in extra_dict:
                del extra_dict[self.id]
    def __getstate__(self):
        return (
            (self.hooks_dict_ref(), self.id)
            if self.extra_dict_ref is None
            else (self.hooks_dict_ref(), self.id, self.extra_dict_ref())
        )

    def __setstate__(self, state) -> None:
        if state[0] is None:
            # create a dead reference
            self.hooks_dict_ref = weakref.ref(OrderedDict())
        else:
            self.hooks_dict_ref = weakref.ref(state[0])
        self.id = state[1]
        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)

        self.extra_dict_ref = (
            None
            if len(state) < 3
            else weakref.ref(OrderedDict() if state[2] is None else state[2])
        )

    def __enter__(self) -> "RemovableHandle":
        return self

    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
        self.remove()
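
# Usage sketch (illustrative, not part of the original module): handles of
# this type are what the various ``register_*_hook`` methods return, and since
# the class defines __enter__/__exit__, a handle can also scope a hook to a
# ``with`` block:
#
#     m = torch.nn.Linear(2, 2)
#     with m.register_forward_hook(lambda mod, inp, out: print("forward!")):
#         m(torch.randn(1, 2))  # hook fires here
#     m(torch.randn(1, 2))      # hook already removed by __exit__ -> remove()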

def unserializable_hook(f):
    """
    Decorator which marks a function as an unserializable hook.
    This suppresses warnings that would otherwise arise if you attempt
    to serialize a tensor that has a hook.
    """
    f.__torch_unserializable__ = True
    return f
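
# Example (illustrative): marking a tensor hook as unserializable so that the
# serialization warning below stays silent for it:
#
#     @unserializable_hook
#     def log_grad(grad):
#         print(grad.norm())
#
#     t = torch.randn(3, requires_grad=True)
#     t.register_hook(log_grad)  # warn_if_has_hooks(t) will not warn about this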

def warn_if_has_hooks(tensor):
    if tensor._backward_hooks:
        for k in tensor._backward_hooks:
            hook = tensor._backward_hooks[k]
            # Check the hook function itself (the dict key ``k`` is just an
            # integer id) so that @unserializable_hook actually suppresses
            # the warning.
            if not hasattr(hook, "__torch_unserializable__"):
                warnings.warn("backward hook {} on tensor will not be "
                              "serialized. If this is expected, you can "
                              "decorate the function with @torch.utils.hooks.unserializable_hook "
                              "to suppress this warning".format(repr(hook)))

class BackwardHook:
    """
    A wrapper class to implement nn.Module backward hooks.

    It handles:
      - Ignoring non-Tensor inputs and replacing them by None before calling the user hook
      - Generating the proper Node to capture a set of Tensors' gradients
      - Linking the gradients captured for the outputs with the gradients captured for the inputs
      - Calling the user hook once both output and input gradients are available
    """
    def __init__(self, module, user_hooks, user_pre_hooks):
        self.user_hooks = user_hooks
        self.user_pre_hooks = user_pre_hooks
        self.module = module

        self.grad_outputs = None
        self.n_outputs = -1
        self.output_tensors_index = None
        self.n_inputs = -1
        self.input_tensors_index = None
    def _pack_with_none(self, indices, values, size):
        res = [None] * size
        for idx, val in zip(indices, values):
            res[idx] = val
        return tuple(res)

    def _unpack_none(self, indices, values):
        res = []
        for idx in indices:
            res.append(values[idx])
        return tuple(res)
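
    # For example (illustrative), with Tensor positions [0, 2] in a size-3
    # argument list:
    #
    #     self._pack_with_none([0, 2], (gx, gz), 3)  -> (gx, None, gz)
    #     self._unpack_none([0, 2], (gx, None, gz)) -> (gx, gz)
    #
    # i.e. _pack_with_none re-inserts None placeholders for the non-Tensor
    # arguments, and _unpack_none drops them again before handing the
    # gradients back to autograd.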
    def _set_user_hook(self, grad_fn):
        def hook(grad_input, _):
            if self.grad_outputs is None:
                # This happens because the gradient in your nn.Module flows to
                # the Module's input without passing through the Module's
                # output, e.g. when you're doing double backward.
                return
            res = self._pack_with_none(self.input_tensors_index, grad_input, self.n_inputs)

            for hook in self.user_hooks:
                out = hook(self.module, res, self.grad_outputs)

                if out is None:
                    continue

                if len(out) != len(res):
                    raise RuntimeError("Backward hook returned an invalid number of grad_input, "
                                       "got {}, but expected {}".format(len(out), len(res)))

                res = out

            self.grad_outputs = None

            return self._unpack_none(self.input_tensors_index, res)

        grad_fn.register_hook(hook)
    def _apply_on_tensors(self, fn, args):
        # Apply the given function to the grad_fn shared by the Tensors
        # contained in args (after wrapping them through BackwardHookFunction).
        # Returns the updated args and the indices of those Tensors.
        tensors_idx = []
        tensors = []

        requires_grad = False
        for i, arg in enumerate(args):
            if isinstance(arg, torch.Tensor):
                tensors_idx.append(i)
                tensors.append(arg)
                requires_grad |= arg.requires_grad

        if not (requires_grad and torch.is_grad_enabled()):
            return args, None

        new_tensors = torch.nn.modules._functions.BackwardHookFunction.apply(*tensors)
        if len(new_tensors) == 0:
            raise RuntimeError("Cannot set Module backward hook for a Module with no input Tensors.")

        grad_fns = [t.grad_fn for t in new_tensors if t.grad_fn is not None
                    and t.grad_fn.name() == "BackwardHookFunctionBackward"]
        if len(grad_fns) == 0:
            raise RuntimeError("Error while setting up backward hooks. Please open "
                               "an issue with a code sample to reproduce this.")

        fn(grad_fns[0])

        arg_list = list(args)
        for idx, val in zip(tensors_idx, new_tensors):
            arg_list[idx] = val

        return tuple(arg_list), tensors_idx
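
    # For example (illustrative), given args = (x, "mode", y) where x and y
    # are Tensors requiring grad, this returns (wrapped_x, "mode", wrapped_y)
    # together with tensors_idx = [0, 2]; the wrapped Tensors share a single
    # BackwardHookFunctionBackward node, on which ``fn`` installs the actual
    # autograd hook.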
    def setup_input_hook(self, args):
        def fn(grad_fn):
            self._set_user_hook(grad_fn)

        res, input_idx = self._apply_on_tensors(fn, args)
        self.n_inputs = len(args)
        self.input_tensors_index = input_idx
        return res
    def setup_output_hook(self, args):
        def fn(grad_fn):
            def hook(_, grad_output):
                self.grad_outputs = self._pack_with_none(self.output_tensors_index,
                                                         grad_output,
                                                         self.n_outputs)

                if self.user_pre_hooks:
                    expected_len = len(self.grad_outputs)
                    for user_pre_hook in self.user_pre_hooks:
                        hook_grad_outputs = user_pre_hook(self.module, self.grad_outputs)
                        if hook_grad_outputs is None:
                            continue

                        actual_len = len(hook_grad_outputs)
                        if actual_len != expected_len:
                            raise RuntimeError("Backward pre hook returned an invalid number of grad_output, "
                                               "got {}, but expected {}".format(actual_len, expected_len))
                        self.grad_outputs = hook_grad_outputs

                # Special case: if no input requires gradient, this hook calls
                # the user hooks directly.
                if self.input_tensors_index is None:
                    grad_inputs = self._pack_with_none([], [], self.n_inputs)
                    for user_hook in self.user_hooks:
                        res = user_hook(self.module, grad_inputs, self.grad_outputs)
                        if res is not None and not (isinstance(res, tuple) and all(el is None for el in res)):
                            raise RuntimeError("Backward hook for Modules where no input requires "
                                               "gradient should always return None or None for all gradients.")
                    self.grad_outputs = None

            grad_fn.register_hook(hook)

        is_tuple = True
        if not isinstance(args, tuple):
            args = (args,)
            is_tuple = False

        res, output_idx = self._apply_on_tensors(fn, args)
        self.n_outputs = len(args)
        self.output_tensors_index = output_idx

        if not is_tuple:
            res = res[0]
        return res
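
# End-to-end sketch (illustrative): ``nn.Module.register_full_backward_hook``
# builds a BackwardHook under the hood, calling setup_input_hook on the
# forward inputs and setup_output_hook on the forward outputs so that
# grad_input and grad_output line up with the Module's signature:
#
#     def report(module, grad_input, grad_output):
#         print([g.shape if g is not None else None for g in grad_input])
#
#     m = torch.nn.Linear(4, 2)
#     handle = m.register_full_backward_hook(report)
#     m(torch.randn(1, 4, requires_grad=True)).sum().backward()  # hook fires
#     handle.remove()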