"""Diagnostics shim adding ``onnxscript``-aware argument formatting.

Re-exports ``rules``, ``export_context`` and ``levels`` from
``torch.onnx._internal.diagnostics``. It also provides ``diagnose_call`` /
``diagnose_step`` decorators pre-bound to this module's ``format_argument``.
"""
import functools
from typing import Any

import onnxscript  # type: ignore[import]
from onnxscript.function_libs.torch_aten import graph_building  # type: ignore[import]

import torch
from torch.onnx._internal import diagnostics
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import decorator, formatter, utils

# Formatted argument strings longer than this many characters cause
# ``format_argument`` to record an ``arg_format_too_verbose`` WARNING.
_LENGTH_LIMIT: int = 80

# NOTE(bowbao): This is a shim over `torch.onnx._internal.diagnostics`, which is
# used in `torch.onnx`, and loaded with `torch`. Hence anything related to `onnxscript`
# cannot be put there.


@functools.singledispatch
def _format_argument(obj: Any) -> str:
    """Fallback formatter: delegate to ``infra.formatter.format_argument``.

    Type-specific overloads are registered with ``@_format_argument.register``.
    """
    return formatter.format_argument(obj)


def format_argument(obj: Any) -> str:
    """Format ``obj`` as a string for display in diagnostics.

    Picks the most specific formatter registered on ``_format_argument`` for
    ``type(obj)``. If the result is longer than ``_LENGTH_LIMIT`` characters,
    a WARNING diagnostic for the ``arg_format_too_verbose`` rule is added to
    the current export context. The full, untruncated string is returned
    either way.

    Args:
        obj: The argument to format.

    Returns:
        The formatted string.
    """
    # Named ``format_fn`` (not ``formatter``) to avoid shadowing the
    # module-level ``formatter`` import used by ``_format_argument``.
    format_fn = _format_argument.dispatch(type(obj))
    result_str = format_fn(obj)
    result_len = len(result_str)

    if result_len > _LENGTH_LIMIT:
        # TODO(bowbao): group diagnostics.
        #   Related fields of sarif.Result: occurrence_count, fingerprints.
        #   Do a final process to group results before outputting sarif log.
        diag = infra.Diagnostic(
            *diagnostics.rules.arg_format_too_verbose.format(
                level=infra.levels.WARNING,
                length=result_len,
                length_limit=_LENGTH_LIMIT,
                argument_type=type(obj),
                # Identify the formatter that produced the verbose output.
                # The previous ``type(format_argument)`` was always the builtin
                # ``function`` class and told the reader nothing.
                formatter_type=getattr(format_fn, "__qualname__", repr(format_fn)),
            )
        )
        diag.with_location(utils.function_location(format_fn))
        diagnostics.export_context().add_diagnostic(diag)

    return result_str


@_format_argument.register
def _torch_nn_module(obj: torch.nn.Module) -> str:
    """Show only the module's class name, not its submodules or parameters."""
    return f"{obj.__class__.__name__}"


@_format_argument.register
def _torch_fx_graph_module(obj: torch.fx.GraphModule) -> str:
    """Show the readable source returned by ``print_readable`` (not printed)."""
    return f"{obj.print_readable(print_output=False)}"


@_format_argument.register
def _torch_tensor(obj: torch.Tensor) -> str:
    """Show shape and dtype only; tensor values are not included."""
    return f"Tensor(shape={obj.shape}, dtype={obj.dtype})"


@_format_argument.register
def _torch_nn_parameter(obj: torch.nn.Parameter) -> str:
    """Wrap the formatting of the parameter's underlying ``.data`` tensor."""
    # Recurses through the public entry point, so the inner tensor string is
    # also subject to the length check.
    return f"Parameter({format_argument(obj.data)})"


@_format_argument.register
def _onnxscript_torch_script_tensor(obj: graph_building.TorchScriptTensor) -> str:
    """Show name, ONNX dtype, shape and symbolic value of an onnxscript tensor."""
    # TODO(bowbao) obj.dtype throws error.
    return f"`TorchScriptTensor({obj.name}, {obj.onnx_dtype}, {obj.shape}, {obj.symbolic_value()})`"


@_format_argument.register
def _onnxscript_onnx_function(obj: onnxscript.values.OnnxFunction) -> str:
    """Show only the onnxscript function's name."""
    return f"`OnnxFunction({obj.name})`"


# ``decorator.diagnose_call`` with the export-context accessor,
# ``ExportDiagnostic`` as the diagnostic type, and this module's formatter
# already bound.
diagnose_call = functools.partial(
    decorator.diagnose_call,
    diagnostics.export_context,
    diagnostic_type=diagnostics.ExportDiagnostic,
    format_argument=format_argument,
)

# ``decorator.diagnose_step`` with the export-context accessor and this
# module's formatter already bound.
diagnose_step = functools.partial(
    decorator.diagnose_step,
    diagnostics.export_context,
    format_argument=format_argument,
)

# Re-exported so callers can import everything from this shim.
rules = diagnostics.rules
export_context = diagnostics.export_context
levels = diagnostics.levels