_diagnostic.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. """Diagnostic components for PyTorch ONNX export."""
  2. from __future__ import annotations
  3. import contextlib
  4. from collections.abc import Generator
  5. from typing import Optional
  6. import torch
  7. from torch.onnx._internal.diagnostics import infra
  8. from torch.utils import cpp_backtrace
  9. def _cpp_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32) -> infra.Stack:
  10. """Returns the current C++ call stack.
  11. This function utilizes `torch.utils.cpp_backtrace` to get the current C++ call stack.
  12. The returned C++ call stack is a concatenated string of the C++ call stack frames.
  13. Each frame is separated by a newline character, in the same format of
  14. r"frame #[0-9]+: (?P<frame_info>.*)". More info at `c10/util/Backtrace.cpp`.
  15. """
  16. # NOTE: Cannot use `@_beartype.beartype`. It somehow erases the cpp stack frame info.
  17. frames = cpp_backtrace.get_cpp_backtrace(frames_to_skip, frames_to_log).split("\n")
  18. frame_messages = []
  19. for frame in frames:
  20. segments = frame.split(":", 1)
  21. if len(segments) == 2:
  22. frame_messages.append(segments[1].strip())
  23. else:
  24. frame_messages.append("<unknown frame>")
  25. return infra.Stack(
  26. frames=[
  27. infra.StackFrame(location=infra.Location(message=message))
  28. for message in frame_messages
  29. ]
  30. )
  31. class ExportDiagnostic(infra.Diagnostic):
  32. """Base class for all export diagnostics.
  33. This class is used to represent all export diagnostics. It is a subclass of
  34. infra.Diagnostic, and adds additional methods to add more information to the
  35. diagnostic.
  36. """
  37. python_call_stack: Optional[infra.Stack] = None
  38. cpp_call_stack: Optional[infra.Stack] = None
  39. def __init__(
  40. self,
  41. *args,
  42. frames_to_skip: int = 1,
  43. cpp_stack: bool = False,
  44. **kwargs,
  45. ) -> None:
  46. super().__init__(*args, **kwargs)
  47. self.python_call_stack = self.record_python_call_stack(
  48. frames_to_skip=frames_to_skip
  49. )
  50. if cpp_stack:
  51. self.cpp_call_stack = self.record_cpp_call_stack(
  52. frames_to_skip=frames_to_skip
  53. )
  54. def record_cpp_call_stack(self, frames_to_skip: int) -> infra.Stack:
  55. """Records the current C++ call stack in the diagnostic."""
  56. # NOTE: Cannot use `@_beartype.beartype`. It somehow erases the cpp stack frame info.
  57. # No need to skip this function because python frame is not recorded
  58. # in cpp call stack.
  59. stack = _cpp_call_stack(frames_to_skip=frames_to_skip)
  60. stack.message = "C++ call stack"
  61. self.with_stack(stack)
  62. return stack
  63. def record_fx_graphmodule(self, gm: torch.fx.GraphModule) -> None:
  64. self.with_graph(infra.Graph(gm.print_readable(False), gm.__class__.__name__))
  65. class ExportDiagnosticEngine(infra.DiagnosticEngine):
  66. """PyTorch ONNX Export diagnostic engine.
  67. The only purpose of creating this class instead of using the base class directly
  68. is to provide a background context for `diagnose` calls inside exporter.
  69. By design, one `torch.onnx.export` call should initialize one diagnostic context.
  70. All `diagnose` calls inside exporter should be made in the context of that export.
  71. However, since diagnostic context is currently being accessed via a global variable,
  72. there is no guarantee that the context is properly initialized. Therefore, we need
  73. to provide a default background context to fallback to, otherwise any invocation of
  74. exporter internals, e.g. unit tests, will fail due to missing diagnostic context.
  75. This can be removed once the pipeline for context to flow through the exporter is
  76. established.
  77. """
  78. _background_context: infra.DiagnosticContext
  79. def __init__(self) -> None:
  80. super().__init__()
  81. self._background_context = infra.DiagnosticContext(
  82. name="torch.onnx",
  83. version=torch.__version__,
  84. diagnostic_type=ExportDiagnostic,
  85. )
  86. @property
  87. def background_context(self) -> infra.DiagnosticContext:
  88. return self._background_context
  89. def clear(self):
  90. super().clear()
  91. self._background_context.diagnostics.clear()
  92. def sarif_log(self):
  93. log = super().sarif_log()
  94. log.runs.append(self._background_context.sarif())
  95. return log
  96. engine = ExportDiagnosticEngine()
  97. _context = engine.background_context
  98. @contextlib.contextmanager
  99. def create_export_diagnostic_context() -> Generator[
  100. infra.DiagnosticContext, None, None
  101. ]:
  102. """Create a diagnostic context for export.
  103. This is a workaround for code robustness since diagnostic context is accessed by
  104. export internals via global variable. See `ExportDiagnosticEngine` for more details.
  105. """
  106. global _context
  107. assert (
  108. _context == engine.background_context
  109. ), "Export context is already set. Nested export is not supported."
  110. _context = engine.create_diagnostic_context(
  111. "torch.onnx.export", torch.__version__, diagnostic_type=ExportDiagnostic
  112. )
  113. try:
  114. yield _context
  115. finally:
  116. _context.pretty_print(_context.options.log_verbose, _context.options.log_level)
  117. _context = engine.background_context
  118. def diagnose(
  119. rule: infra.Rule,
  120. level: infra.Level,
  121. message: Optional[str] = None,
  122. frames_to_skip: int = 2,
  123. **kwargs,
  124. ) -> ExportDiagnostic:
  125. """Creates a diagnostic and record it in the global diagnostic context.
  126. This is a wrapper around `context.record` that uses the global diagnostic context.
  127. """
  128. # NOTE: Cannot use `@_beartype.beartype`. It somehow erases the cpp stack frame info.
  129. diagnostic = ExportDiagnostic(
  130. rule, level, message, frames_to_skip=frames_to_skip, **kwargs
  131. )
  132. export_context().add_diagnostic(diagnostic)
  133. return diagnostic
  134. def export_context() -> infra.DiagnosticContext:
  135. global _context
  136. return _context