anomaly_mode.py

import torch
import warnings
from typing import Any

__all__ = ["detect_anomaly", "set_detect_anomaly"]


class detect_anomaly:
    r"""Context-manager that enables anomaly detection for the autograd engine.

    This does two things:

    - Running the forward pass with detection enabled allows the backward
      pass to print the traceback of the forward operation that created the
      failing backward function.
    - If ``check_nan`` is ``True``, any backward computation that generates a
      "nan" value will raise an error. Default: ``True``.

    .. warning::
        This mode should be enabled only for debugging as the different tests
        will slow down your program execution.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ANOMALY)
        >>> import torch
        >>> from torch import autograd
        >>> class MyFunc(autograd.Function):
        ...     @staticmethod
        ...     def forward(ctx, inp):
        ...         return inp.clone()
        ...     @staticmethod
        ...     def backward(ctx, gO):
        ...         # Error during the backward pass
        ...         raise RuntimeError("Some error in backward")
        ...         return gO.clone()
        >>> def run_fn(a):
        ...     out = MyFunc.apply(a)
        ...     return out.sum()
        >>> inp = torch.rand(10, 10, requires_grad=True)
        >>> out = run_fn(inp)
        >>> out.backward()
        Traceback (most recent call last):
          File "<stdin>", line 1, in <module>
          File "/your/pytorch/install/torch/_tensor.py", line 93, in backward
            torch.autograd.backward(self, gradient, retain_graph, create_graph)
          File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
            allow_unreachable=True)  # allow_unreachable flag
          File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
            return self._forward_cls.backward(self, *args)
          File "<stdin>", line 8, in backward
        RuntimeError: Some error in backward
        >>> with autograd.detect_anomaly():
        ...     inp = torch.rand(10, 10, requires_grad=True)
        ...     out = run_fn(inp)
        ...     out.backward()
        Traceback of forward call that caused the error:
          File "tmp.py", line 53, in <module>
            out = run_fn(inp)
          File "tmp.py", line 44, in run_fn
            out = MyFunc.apply(a)
        Traceback (most recent call last):
          File "<stdin>", line 4, in <module>
          File "/your/pytorch/install/torch/_tensor.py", line 93, in backward
            torch.autograd.backward(self, gradient, retain_graph, create_graph)
          File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
            allow_unreachable=True)  # allow_unreachable flag
          File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
            return self._forward_cls.backward(self, *args)
          File "<stdin>", line 8, in backward
        RuntimeError: Some error in backward
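
    A minimal sketch of the ``check_nan`` behaviour (guarded with ``+SKIP``;
    the exact exception text and backward-node name, e.g. ``SqrtBackward0``,
    may differ between versions). The gradient of ``sqrt`` at zero is
    infinite, and a zero upstream gradient turns the backward computation
    into ``0 / 0``, i.e. NaN:

        >>> # xdoctest: +SKIP
        >>> x = torch.zeros(3, requires_grad=True)
        >>> with autograd.detect_anomaly():
        ...     y = x.sqrt()  # dy/dx = 1 / (2 * sqrt(x)) is inf at x = 0
        ...     out = (y * 0).sum()  # upstream gradient 0 makes backward compute 0 / 0
        ...     out.backward()
        Traceback (most recent call last):
          ...
        RuntimeError: Function 'SqrtBackward0' returned nan values in its 0th output.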
  64. """

    def __init__(self, check_nan: bool = True) -> None:
        # Snapshot the current global state so __exit__ can restore it.
        self.prev = torch.is_anomaly_enabled()
        self.check_nan = check_nan
        self.prev_check_nan = torch.is_anomaly_check_nan_enabled()
        warnings.warn(
            "Anomaly Detection has been enabled. "
            "This mode will increase the runtime "
            "and should only be enabled for debugging.",
            stacklevel=2,
        )

    def __enter__(self) -> None:
        torch.set_anomaly_enabled(True, self.check_nan)

    def __exit__(self, *args: Any) -> None:
        torch.set_anomaly_enabled(self.prev, self.prev_check_nan)


class set_detect_anomaly:
    r"""Context-manager that sets the anomaly detection for the autograd engine on or off.

    ``set_detect_anomaly`` will enable or disable the autograd anomaly detection
    based on its argument :attr:`mode`. It can be used as a context-manager or
    as a function.

    See ``detect_anomaly`` above for details of the anomaly detection behaviour.

    Args:
        mode (bool): Flag whether to enable anomaly detection (``True``),
            or disable (``False``).
        check_nan (bool): Flag whether to raise an error when the backward
            pass generates a "nan" value. Default: ``True``.
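
    A brief sketch of both usages (guarded with ``+SKIP``; illustrative only).
    As a plain call the mode persists; as a context-manager the previous mode
    is restored on exit:

        >>> # xdoctest: +SKIP
        >>> _ = torch.autograd.set_detect_anomaly(True)  # plain call: persists
        >>> torch.is_anomaly_enabled()
        True
        >>> _ = torch.autograd.set_detect_anomaly(False)
        >>> with torch.autograd.set_detect_anomaly(True):  # restored on exit
        ...     print(torch.is_anomaly_enabled())
        True
        >>> torch.is_anomaly_enabled()
        False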
  87. """

    def __init__(self, mode: bool, check_nan: bool = True) -> None:
        self.prev = torch.is_anomaly_enabled()
        self.prev_check_nan = torch.is_anomaly_check_nan_enabled()
        # The mode is applied eagerly here, which is what makes plain
        # function-call usage (without ``with``) work.
        torch.set_anomaly_enabled(mode, check_nan)

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> None:
        torch.set_anomaly_enabled(self.prev, self.prev_check_nan)