  1. import os
  2. import textwrap
  3. from traceback import extract_stack, format_exc, format_list, FrameSummary
  4. from typing import cast, List
  5. from . import config
  6. from .utils import counters, format_bytecode
  7. class TorchDynamoException(RuntimeError):
  8. pass
  9. class InternalTorchDynamoError(TorchDynamoException):
  10. pass
  11. class RestartAnalysis(TorchDynamoException):
  12. pass
  13. class SkipFrame(TorchDynamoException):
  14. pass
  15. class TorchRuntimeError(TorchDynamoException):
  16. pass
  17. class ResetRequired(TorchDynamoException):
  18. def __init__(self):
  19. super().__init__(
  20. textwrap.dedent(
  21. """
  22. Must call `torch._dynamo.reset()` before changing backends. Detected two calls to
  23. `torch.compile()` with a different backend compiler arguments.
  24. """
  25. )
  26. )
  27. class BackendCompilerFailed(TorchDynamoException):
  28. def __init__(self, backend_fn, inner_exception):
  29. self.backend_name = getattr(backend_fn, "__name__", "?")
  30. self.inner_exception = inner_exception
  31. msg = f"{self.backend_name} raised {type(inner_exception).__name__}: {inner_exception}"
  32. super().__init__(msg)
  33. class Unsupported(TorchDynamoException):
  34. def __init__(self, msg):
  35. super().__init__(msg)
  36. self.real_stack = []
  37. self.msg = msg
  38. self.category = None
  39. self.add_to_stats()
  40. def remove_from_stats(self):
  41. counters[self.category][self.msg] -= 1
  42. if counters[self.category][self.msg] <= 0:
  43. del counters[self.category][self.msg]
  44. def add_to_stats(self, category="unimplemented"):
  45. self.category = category
  46. counters[category][self.msg] += 1
  47. def unimplemented(msg: str):
  48. assert msg != os.environ.get("BREAK", False)
  49. raise Unsupported(msg)
  50. def warning(msg: str):
  51. counters["warnings"][msg] += 1
  52. assert msg != os.environ.get("BREAK", False)
  53. # KeyError has special handling for its args
  54. # see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details
  55. class KeyErrorMsg:
  56. def __init__(self, value):
  57. self.value = value
  58. def __str__(self):
  59. return str(self.value)
  60. def __repr__(self) -> str:
  61. return self.__str__()
def augment_exc_message(exc, msg="\n"):
    # Append Dynamo debugging hints to *exc*'s message in place by rewriting
    # exc.args, so the extra guidance shows up wherever the exception prints.
    import traceback

    # Show the two innermost user frames, unless both verbose and
    # suppress_errors are set (that mode reports the stack elsewhere).
    if (
        hasattr(exc, "real_stack")
        and len(exc.real_stack) > 0
        and not (config.verbose and config.suppress_errors)
    ):
        msg += f"\nfrom user code:\n {''.join(traceback.format_list(list(reversed(get_real_stack(exc)[0:2]))))}"

    # Point at the replay-record file when frame replay capture is enabled.
    if config.replay_record_enabled and hasattr(exc, "record_filename"):
        msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\
 torch._dynamo.replay('{exc.record_filename}').\n"

    if not config.verbose:
        msg += "\nSet torch._dynamo.config.verbose=True for more information\n"

    # Backend compiler failures may carry a minifier repro-script path.
    if hasattr(exc, "inner_exception") and hasattr(
        exc.inner_exception, "minifier_path"
    ):
        msg += (
            f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
            "this script to find the smallest traced graph which reproduces this error.\n"
        )

    if not config.suppress_errors:
        msg += (
            "\n\n"
            "You can suppress this exception and fall back to eager by setting:\n"
            " torch._dynamo.config.suppress_errors = True\n"
        )

    old_msg = "" if len(exc.args) == 0 else exc.args[0]

    if isinstance(exc, KeyError):
        # KeyError repr()s its sole argument; wrap so the combined message
        # prints verbatim instead of quoted (see KeyErrorMsg above).
        exc.args = (KeyErrorMsg(old_msg + msg),) + exc.args[1:]
    else:
        new_msg = old_msg + msg
        exc.args = (new_msg,) + exc.args[1:]
  94. def get_real_stack(exc) -> List[FrameSummary]:
  95. assert hasattr(exc, "real_stack")
  96. return cast(List[FrameSummary], exc.real_stack)
  97. # filter out all frames after entering dynamo
  98. def filter_stack(stack):
  99. user_stack = []
  100. for frame in stack:
  101. if "convert_frame" in frame.filename:
  102. break
  103. if "eval_frame" in frame.filename or "torch._dynamo.optimize(" in frame.line:
  104. continue
  105. user_stack.append(frame)
  106. return user_stack
def format_error_msg(exc, code, record_filename=None, frame=None):
    # Build the "WON'T CONVERT" report shown when Dynamo gives up on a frame.
    # NOTE(review): `record_filename` is accepted but unused here — presumably
    # kept for call-site compatibility; confirm against callers.
    msg = os.linesep * 2

    if config.verbose:
        # Verbose mode: dump the frame's bytecode, the full Dynamo traceback,
        # and (when available) the user stack that led into Dynamo.
        msg = format_bytecode(
            "WON'T CONVERT", code.co_name, code.co_filename, code.co_firstlineno, code
        )
        msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n"
        msg += format_exc()
        if hasattr(exc, "real_stack"):
            msg += (
                "\n"
                + "=" * 10
                + " The above exception occurred while processing the following code "
                + "=" * 10
                + "\n\n"
            )
            stack_above_dynamo = []
            if frame is not None:
                # Frames leading up to (but not inside) Dynamo's machinery.
                stack_above_dynamo = filter_stack(extract_stack(frame))

            # Oldest-first user frames above Dynamo, then the recorded
            # real_stack reversed so it reads outermost-to-innermost.
            msg += "".join(
                format_list(stack_above_dynamo + list(reversed(get_real_stack(exc))))
            )
            msg += "\n"
            msg += "=" * 10
    else:
        # Terse mode: one line naming the frame plus the innermost traceback
        # entry (limit=-1 keeps only the last frame of the traceback).
        msg = f"WON'T CONVERT {code.co_name} {code.co_filename}\
 line {code.co_firstlineno} \ndue to: \n{format_exc(limit=-1)}"

    return msg