123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
- import torch
- import warnings
- from typing import Any
- __all__ = ["detect_anomaly", "set_detect_anomaly"]
- class detect_anomaly:
- r"""Context-manager that enable anomaly detection for the autograd engine.
- This does two things:
- - Running the forward pass with detection enabled will allow the backward
- pass to print the traceback of the forward operation that created the failing
- backward function.
- - If ``check_nan`` is ``True``, any backward computation that generate "nan"
- value will raise an error. Default ``True``.
- .. warning::
- This mode should be enabled only for debugging as the different tests
- will slow down your program execution.
- Example:
- >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ANOMOLY)
- >>> import torch
- >>> from torch import autograd
- >>> class MyFunc(autograd.Function):
- ... @staticmethod
- ... def forward(ctx, inp):
- ... return inp.clone()
- ... @staticmethod
- ... def backward(ctx, gO):
- ... # Error during the backward pass
- ... raise RuntimeError("Some error in backward")
- ... return gO.clone()
- >>> def run_fn(a):
- ... out = MyFunc.apply(a)
- ... return out.sum()
- >>> inp = torch.rand(10, 10, requires_grad=True)
- >>> out = run_fn(inp)
- >>> out.backward()
- Traceback (most recent call last):
- File "<stdin>", line 1, in <module>
- File "/your/pytorch/install/torch/_tensor.py", line 93, in backward
- torch.autograd.backward(self, gradient, retain_graph, create_graph)
- File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
- allow_unreachable=True) # allow_unreachable flag
- File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
- return self._forward_cls.backward(self, *args)
- File "<stdin>", line 8, in backward
- RuntimeError: Some error in backward
- >>> with autograd.detect_anomaly():
- ... inp = torch.rand(10, 10, requires_grad=True)
- ... out = run_fn(inp)
- ... out.backward()
- Traceback of forward call that caused the error:
- File "tmp.py", line 53, in <module>
- out = run_fn(inp)
- File "tmp.py", line 44, in run_fn
- out = MyFunc.apply(a)
- Traceback (most recent call last):
- File "<stdin>", line 4, in <module>
- File "/your/pytorch/install/torch/_tensor.py", line 93, in backward
- torch.autograd.backward(self, gradient, retain_graph, create_graph)
- File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
- allow_unreachable=True) # allow_unreachable flag
- File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
- return self._forward_cls.backward(self, *args)
- File "<stdin>", line 8, in backward
- RuntimeError: Some error in backward
- """
- def __init__(self, check_nan=True) -> None:
- self.prev = torch.is_anomaly_enabled()
- self.check_nan = check_nan
- self.prev_check_nan = torch.is_anomaly_check_nan_enabled()
- warnings.warn('Anomaly Detection has been enabled. '
- 'This mode will increase the runtime '
- 'and should only be enabled for debugging.', stacklevel=2)
- def __enter__(self) -> None:
- torch.set_anomaly_enabled(True, self.check_nan)
- def __exit__(self, *args: Any) -> None:
- torch.set_anomaly_enabled(self.prev, self.prev_check_nan)
- class set_detect_anomaly:
- r"""Context-manager that sets the anomaly detection for the autograd engine on or off.
- ``set_detect_anomaly`` will enable or disable the autograd anomaly detection
- based on its argument :attr:`mode`.
- It can be used as a context-manager or as a function.
- See ``detect_anomaly`` above for details of the anomaly detection behaviour.
- Args:
- mode (bool): Flag whether to enable anomaly detection (``True``),
- or disable (``False``).
- check_nan (bool): Flag whether to raise an error when the backward
- generate "nan"
- """
- def __init__(self, mode: bool, check_nan: bool = True) -> None:
- self.prev = torch.is_anomaly_enabled()
- self.prev_check_nan = torch.is_anomaly_check_nan_enabled()
- torch.set_anomaly_enabled(mode, check_nan)
- def __enter__(self) -> None:
- pass
- def __exit__(self, *args: Any) -> None:
- torch.set_anomaly_enabled(self.prev, self.prev_check_nan)
|