engine.py

  1. """A diagnostic engine based on SARIF."""
  2. from __future__ import annotations
  3. import contextlib
  4. import dataclasses
  5. import gzip
  6. from typing import Callable, Generator, List, Mapping, Optional, Type, TypeVar
  7. from typing_extensions import Literal
  8. from torch.onnx._internal.diagnostics import infra
  9. from torch.onnx._internal.diagnostics.infra import formatter, sarif, utils
  10. from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
  11. class DiagnosticError(RuntimeError):
  12. pass
  13. # This is a workaround for mypy not supporting Self from typing_extensions.
  14. _Diagnostic = TypeVar("_Diagnostic", bound="Diagnostic")
  15. @dataclasses.dataclass
  16. class Diagnostic:
  17. rule: infra.Rule
  18. level: infra.Level
  19. message: Optional[str] = None
  20. locations: List[infra.Location] = dataclasses.field(default_factory=list)
  21. stacks: List[infra.Stack] = dataclasses.field(default_factory=list)
  22. graphs: List[infra.Graph] = dataclasses.field(default_factory=list)
  23. thread_flow_locations: List[infra.ThreadFlowLocation] = dataclasses.field(
  24. default_factory=list
  25. )
  26. additional_message: Optional[str] = None
  27. tags: List[infra.Tag] = dataclasses.field(default_factory=list)
  28. def sarif(self) -> sarif.Result:
  29. """Returns the SARIF Result representation of this diagnostic."""
  30. message = self.message or self.rule.message_default_template
  31. if self.additional_message:
  32. message_markdown = (
  33. f"{message}\n\n## Additional Message:\n\n{self.additional_message}"
  34. )
  35. else:
  36. message_markdown = message
  37. kind: Literal["informational", "fail"] = (
  38. "informational" if self.level == infra.Level.NONE else "fail"
  39. )
  40. sarif_result = sarif.Result(
  41. message=sarif.Message(text=message, markdown=message_markdown),
  42. level=self.level.name.lower(), # type: ignore[arg-type]
  43. rule_id=self.rule.id,
  44. kind=kind,
  45. )
  46. sarif_result.locations = [location.sarif() for location in self.locations]
  47. sarif_result.stacks = [stack.sarif() for stack in self.stacks]
  48. sarif_result.graphs = [graph.sarif() for graph in self.graphs]
  49. sarif_result.code_flows = [
  50. sarif.CodeFlow(
  51. thread_flows=[
  52. sarif.ThreadFlow(
  53. locations=[loc.sarif() for loc in self.thread_flow_locations]
  54. )
  55. ]
  56. )
  57. ]
  58. sarif_result.properties = sarif.PropertyBag(
  59. tags=[tag.value for tag in self.tags]
  60. )
  61. return sarif_result
  62. def with_location(self: _Diagnostic, location: infra.Location) -> _Diagnostic:
  63. """Adds a location to the diagnostic."""
  64. self.locations.append(location)
  65. return self
  66. def with_thread_flow_location(
  67. self: _Diagnostic, location: infra.ThreadFlowLocation
  68. ) -> _Diagnostic:
  69. """Adds a thread flow location to the diagnostic."""
  70. self.thread_flow_locations.append(location)
  71. return self
  72. def with_stack(self: _Diagnostic, stack: infra.Stack) -> _Diagnostic:
  73. """Adds a stack to the diagnostic."""
  74. self.stacks.append(stack)
  75. return self
  76. def with_graph(self: _Diagnostic, graph: infra.Graph) -> _Diagnostic:
  77. """Adds a graph to the diagnostic."""
  78. self.graphs.append(graph)
  79. return self
  80. def with_additional_message(self: _Diagnostic, message: str) -> _Diagnostic:
  81. """Adds an additional message to the diagnostic."""
  82. if self.additional_message is None:
  83. self.additional_message = message
  84. else:
  85. self.additional_message = f"{self.additional_message}\n{message}"
  86. return self
  87. def record_python_call_stack(self, frames_to_skip: int) -> infra.Stack:
  88. """Records the current Python call stack."""
  89. frames_to_skip += 1 # Skip this function.
  90. stack = utils.python_call_stack(frames_to_skip=frames_to_skip)
  91. self.with_stack(stack)
  92. if len(stack.frames) > 0:
  93. self.with_location(stack.frames[0].location)
  94. return stack
  95. def record_python_call(
  96. self,
  97. fn: Callable,
  98. state: Mapping[str, str],
  99. message: Optional[str] = None,
  100. frames_to_skip: int = 0,
  101. ) -> infra.ThreadFlowLocation:
  102. """Records a python call as one thread flow step."""
  103. frames_to_skip += 1 # Skip this function.
  104. stack = utils.python_call_stack(frames_to_skip=frames_to_skip, frames_to_log=5)
  105. location = utils.function_location(fn)
  106. location.message = message
  107. # Add function location to the top of the stack.
  108. stack.frames.insert(0, infra.StackFrame(location=location))
  109. thread_flow_location = infra.ThreadFlowLocation(
  110. location=location,
  111. state=state,
  112. index=len(self.thread_flow_locations),
  113. stack=stack,
  114. )
  115. self.with_thread_flow_location(thread_flow_location)
  116. return thread_flow_location
  117. def pretty_print(
  118. self, verbose: bool = False, log_level: infra.Level = infra.Level.ERROR
  119. ):
  120. """Prints the diagnostics in a human-readable format.
  121. Args:
  122. verbose: If True, prints all information. E.g. stack frames, graphs, etc.
  123. Otherwise, only prints compact information. E.g., rule name and display message.
  124. log_level: The minimum level of diagnostics to print.
  125. """
  126. if self.level.value < log_level.value:
  127. return
  128. formatter.pretty_print_item_title(f"{self.level.name}: {self.rule.name}")
  129. print(self.message)
  130. print(self.additional_message)
  131. if not verbose:
  132. print("<Set verbose=True to see more details>\n")
  133. return
  134. formatter.pretty_print_title("Locations", fill_char="-")
  135. for location in self.locations:
  136. location.pretty_print()
  137. for stack in self.stacks:
  138. stack.pretty_print()
  139. formatter.pretty_print_title("Thread Flow Locations", fill_char="-")
  140. for thread_flow_location in self.thread_flow_locations:
  141. thread_flow_location.pretty_print(verbose=verbose)
  142. for graph in self.graphs:
  143. graph.pretty_print(verbose=verbose)
  144. print()
  145. # TODO: print help url to rule at the end.
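
# Illustrative sketch (not part of the original module): building a ``Diagnostic``
# fluently and converting it to a SARIF ``Result``. ``some_rule`` is a hypothetical
# ``infra.Rule`` instance; the message strings are placeholders.
#
#     diagnostic = Diagnostic(
#         rule=some_rule, level=infra.Level.ERROR, message="Something went wrong."
#     ).with_additional_message("Extra context appended to the markdown message.")
#     diagnostic.record_python_call_stack(frames_to_skip=0)
#     sarif_result = diagnostic.sarif()  # ``sarif.Result`` ready to embed in a Run.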


@dataclasses.dataclass
class DiagnosticContext:
    name: str
    version: str
    options: infra.DiagnosticOptions = dataclasses.field(
        default_factory=infra.DiagnosticOptions
    )
    diagnostic_type: Type[Diagnostic] = dataclasses.field(default=Diagnostic)
    diagnostics: List[Diagnostic] = dataclasses.field(init=False, default_factory=list)
    # TODO(bowbao): Implement this.
    # _invocation: infra.Invocation = dataclasses.field(init=False)
    _inflight_diagnostics: List[Diagnostic] = dataclasses.field(
        init=False, default_factory=list
    )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return True

    def sarif(self) -> sarif.Run:
        """Returns the SARIF Run object."""
        unique_rules = {diagnostic.rule for diagnostic in self.diagnostics}
        return sarif.Run(
            tool=sarif.Tool(
                driver=sarif.ToolComponent(
                    name=self.name,
                    version=self.version,
                    rules=[rule.sarif() for rule in unique_rules],
                )
            ),
            results=[diagnostic.sarif() for diagnostic in self.diagnostics],
        )

    def add_diagnostic(self, diagnostic: Diagnostic) -> None:
        """Adds a diagnostic to the context.

        Use this method to add diagnostics that are not created by the context.

        Args:
            diagnostic: The diagnostic to add.
        """
        if not isinstance(diagnostic, Diagnostic):
            raise TypeError(
                f"Expected diagnostic of type {Diagnostic}, got {type(diagnostic)}"
            )
        self.diagnostics.append(diagnostic)

    @contextlib.contextmanager
    def add_inflight_diagnostic(
        self, diagnostic: Diagnostic
    ) -> Generator[Diagnostic, None, None]:
        """Adds a diagnostic to the inflight diagnostics stack for the duration of the block.

        Use this method to add diagnostics that are not created by the context.

        Args:
            diagnostic: The diagnostic to add.
        """
        self._inflight_diagnostics.append(diagnostic)
        try:
            yield diagnostic
        finally:
            self._inflight_diagnostics.pop()

    def diagnose(
        self,
        rule: infra.Rule,
        level: infra.Level,
        message: Optional[str] = None,
        **kwargs,
    ) -> Diagnostic:
        """Creates a diagnostic for the given arguments.

        Args:
            rule: The rule that triggered the diagnostic.
            level: The level of the diagnostic.
            message: The message of the diagnostic.
            **kwargs: Additional arguments to pass to the Diagnostic constructor.

        Returns:
            The created diagnostic.

        Raises:
            ValueError: If the rule is not supported by the tool.
        """
        diagnostic = self.diagnostic_type(rule, level, message, **kwargs)
        self.add_diagnostic(diagnostic)
        return diagnostic

    def push_inflight_diagnostic(self, diagnostic: Diagnostic) -> None:
        """Pushes a diagnostic to the inflight diagnostics stack.

        Args:
            diagnostic: The diagnostic to push.

        Raises:
            ValueError: If the rule is not supported by the tool.
        """
        self._inflight_diagnostics.append(diagnostic)

    def pop_inflight_diagnostic(self) -> Diagnostic:
        """Pops the last diagnostic from the inflight diagnostics stack.

        Returns:
            The popped diagnostic.
        """
        return self._inflight_diagnostics.pop()

    def inflight_diagnostic(self, rule: Optional[infra.Rule] = None) -> Diagnostic:
        if rule is None:
            # TODO(bowbao): Create builtin-rules and create diagnostic using that.
            if len(self._inflight_diagnostics) <= 0:
                raise DiagnosticError("No inflight diagnostics")
            return self._inflight_diagnostics[-1]
        else:
            # TODO(bowbao): Improve efficiency with Mapping[Rule, List[Diagnostic]]
            for diagnostic in reversed(self._inflight_diagnostics):
                if diagnostic.rule == rule:
                    return diagnostic
            raise DiagnosticError(f"No inflight diagnostic for rule {rule.name}")

    def pretty_print(
        self, verbose: Optional[bool] = None, log_level: Optional[infra.Level] = None
    ) -> None:
        """Prints the diagnostics in a human-readable format.

        Args:
            verbose: Whether to print the diagnostics in verbose mode. See Diagnostic.pretty_print.
                If not specified, uses the value of 'self.options.log_verbose'.
            log_level: The minimum level of diagnostics to print.
                If not specified, uses the value of 'self.options.log_level'.
        """
        if verbose is None:
            verbose = self.options.log_verbose
        if log_level is None:
            log_level = self.options.log_level

        formatter.pretty_print_title(
            f"Diagnostic Run {self.name} version {self.version}"
        )
        print(f"verbose: {verbose}, log level: {log_level}")
        diagnostic_stats = {level: 0 for level in infra.Level}
        for diagnostic in self.diagnostics:
            diagnostic_stats[diagnostic.level] += 1
        formatter.pretty_print_title(
            " ".join(f"{diagnostic_stats[level]} {level.name}" for level in infra.Level)
        )

        for diagnostic in self.diagnostics:
            diagnostic.pretty_print(verbose, log_level)

        unprinted_diagnostic_stats = [
            (level, count)
            for level, count in diagnostic_stats.items()
            if count > 0 and level.value < log_level.value
        ]
        if unprinted_diagnostic_stats:
            print(
                f"{' '.join(f'{count} {level.name}' for level, count in unprinted_diagnostic_stats)} "
                "were not printed due to the log level."
            )
        print()
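
# Illustrative sketch (not part of the original module): recording diagnostics through
# a ``DiagnosticContext``. ``some_rule`` is a hypothetical ``infra.Rule`` instance.
#
#     context = DiagnosticContext("torch.onnx.export", version="1.0")
#     with context:
#         diagnostic = context.diagnose(some_rule, infra.Level.ERROR)
#         with context.add_inflight_diagnostic(diagnostic):
#             # Nested code can look up the active diagnostic without passing it around.
#             current = context.inflight_diagnostic()
#     context.pretty_print(verbose=True, log_level=infra.Level.NONE)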


class DiagnosticEngine:
    """A generic diagnostic engine based on SARIF.

    This class is the main interface for diagnostics. It manages the creation of diagnostic contexts.
    A DiagnosticContext provides the entry point for recording Diagnostics.
    See infra.DiagnosticContext for more details.

    Examples:
        Step 1: Create a set of rules.
        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
        >>> rules = infra.RuleCollection.custom_collection_from_list(
        ...     "CustomRuleCollection",
        ...     [
        ...         infra.Rule(
        ...             id="r1",
        ...             name="rule-1",
        ...             message_default_template="Missing xxx",
        ...         ),
        ...     ],
        ... )

        Step 2: Create a diagnostic engine.
        >>> engine = DiagnosticEngine()

        Step 3: Start a new diagnostic context.
        >>> with engine.create_diagnostic_context("torch.onnx.export", version="1.0") as context:
        ...     ...

        Step 4: Add diagnostics in your code.
        ...     context.diagnose(rules.rule1, infra.Level.ERROR)

        Step 5: Afterwards, get the SARIF log.
        >>> sarif_log = engine.sarif_log()
    """

    contexts: List[DiagnosticContext]

    def __init__(self) -> None:
        self.contexts = []

    def sarif_log(self) -> sarif.SarifLog:
        return sarif.SarifLog(
            version=sarif_version.SARIF_VERSION,
            schema_uri=sarif_version.SARIF_SCHEMA_LINK,
            runs=[context.sarif() for context in self.contexts],
        )

    def __str__(self) -> str:
        # TODO: pretty print.
        return self.to_json()

    def __repr__(self) -> str:
        return self.to_json()

    def to_json(self) -> str:
        return formatter.sarif_to_json(self.sarif_log())

    def dump(self, file_path: str, compress: bool = False) -> None:
        """Dumps the SARIF log to a file."""
        if compress:
            with gzip.open(file_path, "wt") as f:
                f.write(self.to_json())
        else:
            with open(file_path, "w") as f:
                f.write(self.to_json())

    def clear(self) -> None:
        """Clears all diagnostic contexts."""
        self.contexts.clear()

    def create_diagnostic_context(
        self,
        name: str,
        version: str,
        options: Optional[infra.DiagnosticOptions] = None,
        diagnostic_type: Type[Diagnostic] = Diagnostic,
    ) -> DiagnosticContext:
        """Creates a new diagnostic context.

        Args:
            name: The subject name for the diagnostic context.
            version: The subject version for the diagnostic context.
            options: The options for the diagnostic context.

        Returns:
            A new diagnostic context.
        """
        if options is None:
            options = infra.DiagnosticOptions()
        context = DiagnosticContext(
            name, version, options, diagnostic_type=diagnostic_type
        )
        self.contexts.append(context)
        return context

    def pretty_print(
        self, verbose: bool = False, level: infra.Level = infra.Level.ERROR
    ) -> None:
        """Pretty prints all diagnostics in the diagnostic contexts.

        Args:
            verbose: Whether to print the diagnostics in verbose mode. See Diagnostic.pretty_print.
            level: The minimum level of diagnostics to print.
        """
        formatter.pretty_print_title(f"{len(self.contexts)} Diagnostic Run")
        for context in self.contexts:
            context.pretty_print(verbose, level)
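

# Illustrative sketch (not part of the original module): a minimal end-to-end run of the
# engine, assuming only the APIs defined above and the ``infra.Rule`` constructor shown in
# the ``DiagnosticEngine`` docstring. The rule id/name and the output path are placeholders.
if __name__ == "__main__":
    demo_rule = infra.Rule(
        id="r1",
        name="rule-1",
        message_default_template="Missing xxx",
    )
    engine = DiagnosticEngine()
    with engine.create_diagnostic_context("torch.onnx.export", version="1.0") as context:
        # Record one diagnostic against the demo rule.
        context.diagnose(demo_rule, infra.Level.ERROR, message="Example diagnostic.")
    # Serialize all contexts into a single SARIF log and write it out gzip-compressed.
    engine.dump("diagnostics.sarif.gz", compress=True)
    engine.pretty_print(verbose=True, level=infra.Level.NONE)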