123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- import functools
- from typing import Any
- import onnxscript # type: ignore[import]
- from onnxscript.function_libs.torch_aten import graph_building # type: ignore[import]
- import torch
- from torch.onnx._internal import diagnostics
- from torch.onnx._internal.diagnostics import infra
- from torch.onnx._internal.diagnostics.infra import decorator, formatter, utils
# Maximum length, in characters, of a formatted argument string. Longer output
# from `format_argument` is still returned unchanged, but it also triggers an
# `arg_format_too_verbose` WARNING diagnostic in the export context.
_LENGTH_LIMIT: int = 80
# NOTE(bowbao): This module is a shim over `torch.onnx._internal.diagnostics`.
# That package is used by `torch.onnx` and is loaded together with `torch`, so
# it cannot depend on `onnxscript`. Anything `onnxscript`-related (such as the
# `TorchScriptTensor` / `OnnxFunction` formatters below) has to live here.
@functools.singledispatch
def _format_argument(obj: Any) -> str:
    """Fallback formatter used for any type without a registered overload.

    Type-specific overloads are attached via ``@_format_argument.register``;
    everything else is delegated to the generic diagnostics formatter.
    """
    fallback = formatter.format_argument
    return fallback(obj)
def format_argument(obj: Any) -> str:
    """Format ``obj`` as a string for display in export diagnostics.

    The formatter registered on ``_format_argument`` for ``type(obj)`` is used
    (falling back to the generic ``formatter.format_argument`` for unregistered
    types). If the resulting string is longer than ``_LENGTH_LIMIT`` characters,
    a WARNING diagnostic for the ``arg_format_too_verbose`` rule is added to the
    current export context; the full string is returned either way.

    Args:
        obj: The argument to format.

    Returns:
        The formatted string, never truncated.
    """
    # Deliberately not called ``formatter``: that name is the module imported
    # from ``diagnostics.infra`` and a local of the same name would shadow it
    # for the whole function body.
    dispatched_formatter = _format_argument.dispatch(type(obj))
    result_str = dispatched_formatter(obj)

    if len(result_str) <= _LENGTH_LIMIT:
        return result_str

    # TODO(bowbao): group diagnostics.
    #   Related fields of sarif.Result: occurrence_count, fingerprints.
    #   Do a final pass to group results before outputting the sarif log.
    diag = infra.Diagnostic(
        *diagnostics.rules.arg_format_too_verbose.format(
            level=infra.levels.WARNING,
            length=len(result_str),
            length_limit=_LENGTH_LIMIT,
            argument_type=type(obj),
            # NOTE(review): ``type(format_argument)`` is always the builtin
            # ``function`` type, so this field says nothing about which
            # formatter produced the output. The dispatched formatter may have
            # been intended; confirm against the rule's message template.
            formatter_type=type(format_argument),
        )
    )
    # Attribute the diagnostic to the formatter that produced the long string.
    diag.with_location(utils.function_location(dispatched_formatter))
    diagnostics.export_context().add_diagnostic(diag)
    return result_str
@_format_argument.register
def _torch_nn_module(obj: torch.nn.Module) -> str:
    """Render a module compactly as just its class name."""
    class_name = obj.__class__.__name__
    return f"{class_name}"
@_format_argument.register
def _torch_fx_graph_module(obj: torch.fx.GraphModule) -> str:
    """Render a GraphModule as its readable source, without printing it."""
    readable_source = obj.print_readable(print_output=False)
    return f"{readable_source}"
@_format_argument.register
def _torch_tensor(obj: torch.Tensor) -> str:
    """Summarize a tensor by its shape and dtype only (never its values)."""
    shape = obj.shape
    dtype = obj.dtype
    return f"Tensor(shape={shape}, dtype={dtype})"
@_format_argument.register
def _torch_nn_parameter(obj: torch.nn.Parameter) -> str:
    """Render a parameter as ``Parameter(<formatted data>)``.

    ``obj.data`` is formatted through ``format_argument``, so the inner text
    comes from whichever formatter is registered for that value's type.
    """
    inner = format_argument(obj.data)
    return f"Parameter({inner})"
@_format_argument.register
def _onnxscript_torch_script_tensor(obj: graph_building.TorchScriptTensor) -> str:
    """Render an onnxscript TorchScriptTensor with name, ONNX dtype, shape and symbolic value."""
    # TODO(bowbao): `obj.dtype` raises an error, so `obj.onnx_dtype` is used instead.
    rendered = f"`TorchScriptTensor({obj.name}, {obj.onnx_dtype}, {obj.shape}, {obj.symbolic_value()})`"
    return rendered
@_format_argument.register
def _onnxscript_onnx_function(obj: onnxscript.values.OnnxFunction) -> str:
    """Render an onnxscript OnnxFunction by its name only."""
    function_name = obj.name
    return f"`OnnxFunction({function_name})`"
# `decorator.diagnose_call` with three arguments pre-bound:
#   - the export-context getter `diagnostics.export_context`, passed
#     positionally and not called here;
#   - `diagnostics.ExportDiagnostic` as `diagnostic_type`;
#   - this module's `format_argument`, so torch/onnxscript-aware argument
#     formatting is used.
# NOTE(review): presumably applied as a decorator factory on exporter
# functions; confirm the remaining parameters against `infra.decorator.diagnose_call`.
diagnose_call = functools.partial(
    decorator.diagnose_call,
    diagnostics.export_context,
    diagnostic_type=diagnostics.ExportDiagnostic,
    format_argument=format_argument,
)
# `decorator.diagnose_step` with two arguments pre-bound:
#   - the export-context getter `diagnostics.export_context`, passed
#     positionally and not called here;
#   - this module's `format_argument`.
# Unlike `diagnose_call`, no `diagnostic_type` is bound here.
# NOTE(review): confirm the remaining parameters and semantics against
# `infra.decorator.diagnose_step`.
diagnose_step = functools.partial(
    decorator.diagnose_step,
    diagnostics.export_context,
    format_argument=format_argument,
)
# Re-export these members of `torch.onnx._internal.diagnostics` from this shim.
# Presumably this lets callers import everything diagnostics-related from one
# place; confirm against call sites.
rules, export_context, levels = (
    diagnostics.rules,
    diagnostics.export_context,
    diagnostics.levels,
)
|