123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448 |
- """A diagnostic engine based on SARIF."""
- from __future__ import annotations
- import contextlib
- import dataclasses
- import gzip
- from typing import Callable, Generator, List, Mapping, Optional, Type, TypeVar
- from typing_extensions import Literal
- from torch.onnx._internal.diagnostics import infra
- from torch.onnx._internal.diagnostics.infra import formatter, sarif, utils
- from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
class DiagnosticError(RuntimeError):
    """Error raised by the diagnostic engine itself.

    Used, for example, when an inflight diagnostic is requested but none is
    currently registered.
    """
# This is a workaround for mypy not supporting Self from typing_extensions.
# Annotating ``self`` with this TypeVar lets the fluent ``with_*`` methods
# return the caller's own (sub)class type.
_Diagnostic = TypeVar("_Diagnostic", bound="Diagnostic")


@dataclasses.dataclass
class Diagnostic:
    """A single diagnostic: one rule reported at one level, plus its context.

    Attributes:
        rule: The rule this diagnostic reports against.
        level: The severity level of the diagnostic.
        message: The display message. When empty, the rule's
            ``message_default_template`` is used instead.
        locations: Locations attached to the diagnostic.
        stacks: Call stacks attached to the diagnostic.
        graphs: Graphs attached to the diagnostic.
        thread_flow_locations: Ordered steps of the (single) thread flow.
        additional_message: Extra text appended to the markdown message.
        tags: Tags exported in the SARIF property bag.
    """

    rule: infra.Rule
    level: infra.Level
    message: Optional[str] = None
    locations: List[infra.Location] = dataclasses.field(default_factory=list)
    stacks: List[infra.Stack] = dataclasses.field(default_factory=list)
    graphs: List[infra.Graph] = dataclasses.field(default_factory=list)
    thread_flow_locations: List[infra.ThreadFlowLocation] = dataclasses.field(
        default_factory=list
    )
    additional_message: Optional[str] = None
    tags: List[infra.Tag] = dataclasses.field(default_factory=list)

    def sarif(self) -> sarif.Result:
        """Returns the SARIF Result representation of this diagnostic."""
        message = self.message or self.rule.message_default_template
        if self.additional_message:
            message_markdown = (
                f"{message}\n\n## Additional Message:\n\n{self.additional_message}"
            )
        else:
            message_markdown = message

        # Level NONE is reported as purely informational; everything else is
        # reported as a failure.
        kind: Literal["informational", "fail"] = (
            "informational" if self.level == infra.Level.NONE else "fail"
        )

        sarif_result = sarif.Result(
            message=sarif.Message(text=message, markdown=message_markdown),
            level=self.level.name.lower(),  # type: ignore[arg-type]
            rule_id=self.rule.id,
            kind=kind,
        )
        sarif_result.locations = [location.sarif() for location in self.locations]
        sarif_result.stacks = [stack.sarif() for stack in self.stacks]
        sarif_result.graphs = [graph.sarif() for graph in self.graphs]
        # NOTE(review): a code flow is emitted even when no thread flow
        # locations were recorded, yielding a thread flow with an empty
        # `locations` list. SARIF 2.1.0 appears to require at least one entry
        # there -- confirm that consumers tolerate the empty form.
        sarif_result.code_flows = [
            sarif.CodeFlow(
                thread_flows=[
                    sarif.ThreadFlow(
                        locations=[loc.sarif() for loc in self.thread_flow_locations]
                    )
                ]
            )
        ]
        sarif_result.properties = sarif.PropertyBag(
            tags=[tag.value for tag in self.tags]
        )
        return sarif_result

    def with_location(self: _Diagnostic, location: infra.Location) -> _Diagnostic:
        """Adds a location to the diagnostic and returns the diagnostic."""
        self.locations.append(location)
        return self

    def with_thread_flow_location(
        self: _Diagnostic, location: infra.ThreadFlowLocation
    ) -> _Diagnostic:
        """Adds a thread flow location to the diagnostic and returns the diagnostic."""
        self.thread_flow_locations.append(location)
        return self

    def with_stack(self: _Diagnostic, stack: infra.Stack) -> _Diagnostic:
        """Adds a stack to the diagnostic and returns the diagnostic."""
        self.stacks.append(stack)
        return self

    def with_graph(self: _Diagnostic, graph: infra.Graph) -> _Diagnostic:
        """Adds a graph to the diagnostic and returns the diagnostic."""
        self.graphs.append(graph)
        return self

    def with_additional_message(self: _Diagnostic, message: str) -> _Diagnostic:
        """Adds an additional message to the diagnostic.

        The first call sets the additional message; later calls append to it
        on a new line.
        """
        if self.additional_message is None:
            self.additional_message = message
        else:
            self.additional_message = f"{self.additional_message}\n{message}"
        return self

    def record_python_call_stack(self, frames_to_skip: int) -> infra.Stack:
        """Records the current Python call stack.

        The stack is attached to the diagnostic, and the location of its first
        frame (if any) is attached as a location.

        Args:
            frames_to_skip: Number of caller frames to omit; this method's own
                frame is skipped in addition.

        Returns:
            The recorded stack.
        """
        frames_to_skip += 1  # Skip this function.
        stack = utils.python_call_stack(frames_to_skip=frames_to_skip)
        self.with_stack(stack)
        if stack.frames:
            self.with_location(stack.frames[0].location)
        return stack

    def record_python_call(
        self,
        fn: Callable,
        state: Mapping[str, str],
        message: Optional[str] = None,
        frames_to_skip: int = 0,
    ) -> infra.ThreadFlowLocation:
        """Records a python call as one thread flow step.

        Args:
            fn: The function being called; its location becomes the step's
                location.
            state: State mapping stored on the thread flow location.
            message: Optional message stored on the function location.
            frames_to_skip: Number of caller frames to omit; this method's own
                frame is skipped in addition.

        Returns:
            The thread flow location that was appended to this diagnostic.
        """
        frames_to_skip += 1  # Skip this function.
        stack = utils.python_call_stack(frames_to_skip=frames_to_skip, frames_to_log=5)
        location = utils.function_location(fn)
        location.message = message
        # Add function location to the top of the stack.
        stack.frames.insert(0, infra.StackFrame(location=location))
        thread_flow_location = infra.ThreadFlowLocation(
            location=location,
            state=state,
            # The step index is its position in the thread flow.
            index=len(self.thread_flow_locations),
            stack=stack,
        )
        self.with_thread_flow_location(thread_flow_location)
        return thread_flow_location

    def pretty_print(
        self, verbose: bool = False, log_level: infra.Level = infra.Level.ERROR
    ):
        """Prints the diagnostics in a human-readable format.

        Args:
            verbose: If True, prints all information. E.g. stack frames, graphs, etc.
                Otherwise, only prints compact information. E.g., rule name and display message.
            log_level: The minimum level of diagnostics to print.
        """
        if self.level.value < log_level.value:
            return
        formatter.pretty_print_item_title(f"{self.level.name}: {self.rule.name}")
        # Use the same fallback as `sarif()` so an unset message shows the
        # rule's default text instead of the literal string "None".
        print(self.message or self.rule.message_default_template)
        if self.additional_message:
            print(self.additional_message)

        if not verbose:
            print("<Set verbose=True to see more details>\n")
            return

        formatter.pretty_print_title("Locations", fill_char="-")
        for location in self.locations:
            location.pretty_print()
        for stack in self.stacks:
            stack.pretty_print()
        formatter.pretty_print_title("Thread Flow Locations", fill_char="-")
        for thread_flow_location in self.thread_flow_locations:
            thread_flow_location.pretty_print(verbose=verbose)
        for graph in self.graphs:
            graph.pretty_print(verbose=verbose)

        print()

        # TODO: print help url to rule at the end.
@dataclasses.dataclass
class DiagnosticContext:
    """A named, versioned collection of diagnostics; exported as one SARIF run.

    Attributes:
        name: The subject name, exported as the SARIF tool driver name.
        version: The subject version, exported as the SARIF tool driver version.
        options: Options supplying the default verbosity and log level for
            `pretty_print`.
        diagnostic_type: The class instantiated by `diagnose`.
        diagnostics: The diagnostics recorded so far.
    """

    name: str
    version: str
    options: infra.DiagnosticOptions = dataclasses.field(
        default_factory=infra.DiagnosticOptions
    )
    diagnostic_type: Type[Diagnostic] = dataclasses.field(default=Diagnostic)
    diagnostics: List[Diagnostic] = dataclasses.field(init=False, default_factory=list)
    # TODO(bowbao): Implement this.
    # _invocation: infra.Invocation = dataclasses.field(init=False)
    # Stack of diagnostics currently "in flight"; the last element is the most
    # recently pushed one.
    _inflight_diagnostics: List[Diagnostic] = dataclasses.field(
        init=False, default_factory=list
    )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Never suppress: returning a truthy value here would silently swallow
        # every exception raised inside the `with` body (including
        # KeyboardInterrupt and SystemExit) and hide real failures.
        return None

    def sarif(self) -> sarif.Run:
        """Returns the SARIF Run object.

        Each rule is listed once, in the order it was first reported, so the
        output is deterministic.
        """
        # dict.fromkeys de-duplicates like a set but keeps insertion order.
        unique_rules = dict.fromkeys(
            diagnostic.rule for diagnostic in self.diagnostics
        )
        return sarif.Run(
            tool=sarif.Tool(
                driver=sarif.ToolComponent(
                    name=self.name,
                    version=self.version,
                    rules=[rule.sarif() for rule in unique_rules],
                )
            ),
            results=[diagnostic.sarif() for diagnostic in self.diagnostics],
        )

    def add_diagnostic(self, diagnostic: Diagnostic) -> None:
        """Adds a diagnostic to the context.

        Use this method to add diagnostics that are not created by the context.

        Args:
            diagnostic: The diagnostic to add.

        Raises:
            TypeError: If `diagnostic` is not a `Diagnostic` instance.
        """
        if not isinstance(diagnostic, Diagnostic):
            raise TypeError(
                f"Expected diagnostic of type {Diagnostic}, got {type(diagnostic)}"
            )
        self.diagnostics.append(diagnostic)

    @contextlib.contextmanager
    def add_inflight_diagnostic(
        self, diagnostic: Diagnostic
    ) -> Generator[Diagnostic, None, None]:
        """Marks a diagnostic as inflight for the duration of a `with` block.

        The diagnostic is pushed onto the inflight stack on entry, and the top
        of the inflight stack is popped on exit, even if the body raises. The
        diagnostic is not added to `self.diagnostics`.

        Args:
            diagnostic: The diagnostic to mark as inflight.

        Yields:
            The same diagnostic.
        """
        self._inflight_diagnostics.append(diagnostic)
        try:
            yield diagnostic
        finally:
            self._inflight_diagnostics.pop()

    def diagnose(
        self,
        rule: infra.Rule,
        level: infra.Level,
        message: Optional[str] = None,
        **kwargs,
    ) -> Diagnostic:
        """Creates a diagnostic for the given arguments and adds it to the context.

        Args:
            rule: The rule that triggered the diagnostic.
            level: The level of the diagnostic.
            message: The message of the diagnostic.
            **kwargs: Additional arguments to pass to the Diagnostic constructor.

        Returns:
            The created diagnostic.

        Raises:
            TypeError: If `self.diagnostic_type` does not produce a
                `Diagnostic` instance.
        """
        diagnostic = self.diagnostic_type(rule, level, message, **kwargs)
        self.add_diagnostic(diagnostic)
        return diagnostic

    def push_inflight_diagnostic(self, diagnostic: Diagnostic) -> None:
        """Pushes a diagnostic to the inflight diagnostics stack.

        Args:
            diagnostic: The diagnostic to push.
        """
        self._inflight_diagnostics.append(diagnostic)

    def pop_inflight_diagnostic(self) -> Diagnostic:
        """Pops the last diagnostic from the inflight diagnostics stack.

        Returns:
            The popped diagnostic.

        Raises:
            IndexError: If the inflight diagnostics stack is empty.
        """
        return self._inflight_diagnostics.pop()

    def inflight_diagnostic(self, rule: Optional[infra.Rule] = None) -> Diagnostic:
        """Returns the most recently pushed inflight diagnostic.

        Args:
            rule: If given, only inflight diagnostics reported against this
                rule are considered.

        Returns:
            The most recent inflight diagnostic (matching `rule`, if given).

        Raises:
            DiagnosticError: If there is no (matching) inflight diagnostic.
        """
        if rule is None:
            # TODO(bowbao): Create builtin-rules and create diagnostic using that.
            if not self._inflight_diagnostics:
                raise DiagnosticError("No inflight diagnostics")

            return self._inflight_diagnostics[-1]
        else:
            # TODO(bowbao): Improve efficiency with Mapping[Rule, List[Diagnostic]]
            # Search from the top of the stack so the newest match wins.
            for diagnostic in reversed(self._inflight_diagnostics):
                if diagnostic.rule == rule:
                    return diagnostic
            raise DiagnosticError(f"No inflight diagnostic for rule {rule.name}")

    def pretty_print(
        self, verbose: Optional[bool] = None, log_level: Optional[infra.Level] = None
    ) -> None:
        """Prints the diagnostics in a human-readable format.

        Args:
            verbose: Whether to print the diagnostics in verbose mode. See Diagnostic.pretty_print.
                If not specified, uses the value of 'self.options.log_verbose'.
            log_level: The minimum level of diagnostics to print.
                If not specified, uses the value of 'self.options.log_level'.
        """
        if verbose is None:
            verbose = self.options.log_verbose
        if log_level is None:
            log_level = self.options.log_level

        formatter.pretty_print_title(
            f"Diagnostic Run {self.name} version {self.version}"
        )
        print(f"verbose: {verbose}, log level: {log_level}")

        # Count diagnostics per level; every level is listed, even at zero.
        diagnostic_stats = {level: 0 for level in infra.Level}
        for diagnostic in self.diagnostics:
            diagnostic_stats[diagnostic.level] += 1
        formatter.pretty_print_title(
            " ".join(f"{diagnostic_stats[level]} {level.name}" for level in infra.Level)
        )

        for diagnostic in self.diagnostics:
            diagnostic.pretty_print(verbose, log_level)

        # Tell the reader about diagnostics filtered out by the log level.
        unprinted_diagnostic_stats = [
            (level, count)
            for level, count in diagnostic_stats.items()
            if count > 0 and level.value < log_level.value
        ]
        if unprinted_diagnostic_stats:
            print(
                f"{' '.join(f'{count} {level.name}' for level, count in unprinted_diagnostic_stats)} "
                "were not printed due to the log level."
            )
        print()
class DiagnosticEngine:
    """A generic diagnostic engine based on SARIF.

    This class is the main interface for diagnostics. It manages the creation of diagnostic contexts.
    A DiagnosticContext provides the entry point for recording Diagnostics.
    See infra.DiagnosticContext for more details.

    Examples:
        Step 1: Create a set of rules.
        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
        >>> rules = infra.RuleCollection.custom_collection_from_list(
        ...     "CustomRuleCollection",
        ...     [
        ...         infra.Rule(
        ...             id="r1",
        ...             name="rule-1",
        ...             message_default_template="Missing xxx",
        ...         ),
        ...     ],
        ... )

        Step 2: Create a diagnostic engine.
        >>> engine = DiagnosticEngine()

        Step 3: Start a new diagnostic context.
        >>> with engine.create_diagnostic_context("torch.onnx.export", version="1.0") as context:
        ...     ...

        Step 4: Add diagnostics in your code.
        ...     context.diagnose(rules.rule1, infra.Level.ERROR)

        Step 5: Afterwards, get the SARIF log.
        >>> sarif_log = engine.sarif_log()
    """

    # All contexts created by this engine, in creation order.
    contexts: List[DiagnosticContext]

    def __init__(self) -> None:
        self.contexts = []

    def sarif_log(self) -> sarif.SarifLog:
        """Returns the SARIF log, with one run per diagnostic context."""
        return sarif.SarifLog(
            version=sarif_version.SARIF_VERSION,
            schema_uri=sarif_version.SARIF_SCHEMA_LINK,
            runs=[context.sarif() for context in self.contexts],
        )

    def __str__(self) -> str:
        # TODO: pretty print.
        return self.to_json()

    def __repr__(self) -> str:
        return self.to_json()

    def to_json(self) -> str:
        """Returns the SARIF log serialized as a JSON string."""
        return formatter.sarif_to_json(self.sarif_log())

    def dump(self, file_path: str, compress: bool = False) -> None:
        """Dumps the SARIF log to a file.

        The file is always written as UTF-8 rather than the locale's default
        encoding, so the output does not depend on the platform.

        Args:
            file_path: Path of the file to write.
            compress: If True, the file is written gzip-compressed.
        """
        if compress:
            with gzip.open(file_path, "wt", encoding="utf-8") as f:
                f.write(self.to_json())
        else:
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(self.to_json())

    def clear(self) -> None:
        """Clears all diagnostic contexts."""
        self.contexts.clear()

    def create_diagnostic_context(
        self,
        name: str,
        version: str,
        options: Optional[infra.DiagnosticOptions] = None,
        diagnostic_type: Type[Diagnostic] = Diagnostic,
    ) -> DiagnosticContext:
        """Creates a new diagnostic context and registers it with the engine.

        Args:
            name: The subject name for the diagnostic context.
            version: The subject version for the diagnostic context.
            options: The options for the diagnostic context. Defaults to a
                fresh `infra.DiagnosticOptions()` when not given.
            diagnostic_type: The diagnostic class the context instantiates.

        Returns:
            A new diagnostic context.
        """
        if options is None:
            options = infra.DiagnosticOptions()
        context = DiagnosticContext(
            name, version, options, diagnostic_type=diagnostic_type
        )
        self.contexts.append(context)
        return context

    def pretty_print(
        self, verbose: bool = False, level: infra.Level = infra.Level.ERROR
    ) -> None:
        """Pretty prints all diagnostics in the diagnostic contexts.

        Args:
            verbose: Whether to print the diagnostics in verbose mode. See Diagnostic.pretty_print.
            level: The minimum level of diagnostics to print.
        """
        formatter.pretty_print_title(f"{len(self.contexts)} Diagnostic Run")
        for context in self.contexts:
            context.pretty_print(verbose, level)
|