diagnostics.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import functools
  2. from typing import Any
  3. import onnxscript # type: ignore[import]
  4. from onnxscript.function_libs.torch_aten import graph_building # type: ignore[import]
  5. import torch
  6. from torch.onnx._internal import diagnostics
  7. from torch.onnx._internal.diagnostics import infra
  8. from torch.onnx._internal.diagnostics.infra import decorator, formatter, utils
  9. _LENGTH_LIMIT: int = 80
  10. # NOTE(bowbao): This is a shim over `torch.onnx._internal.diagnostics`, which is
  11. # used in `torch.onnx`, and loaded with `torch`. Hence anything related to `onnxscript`
  12. # cannot be put there.
  13. @functools.singledispatch
  14. def _format_argument(obj: Any) -> str:
  15. return formatter.format_argument(obj)
  16. def format_argument(obj: Any) -> str:
  17. formatter = _format_argument.dispatch(type(obj))
  18. result_str = formatter(obj)
  19. if len(result_str) > _LENGTH_LIMIT:
  20. # TODO(bowbao): group diagnostics.
  21. # Related fields of sarif.Result: occurance_count, fingerprints.
  22. # Do a final process to group results before outputing sarif log.
  23. diag = infra.Diagnostic(
  24. *diagnostics.rules.arg_format_too_verbose.format(
  25. level=infra.levels.WARNING,
  26. length=len(result_str),
  27. length_limit=_LENGTH_LIMIT,
  28. argument_type=type(obj),
  29. formatter_type=type(format_argument),
  30. )
  31. )
  32. diag.with_location(utils.function_location(formatter))
  33. diagnostics.export_context().add_diagnostic(diag)
  34. return result_str
  35. @_format_argument.register
  36. def _torch_nn_module(obj: torch.nn.Module) -> str:
  37. return f"{obj.__class__.__name__}"
  38. @_format_argument.register
  39. def _torch_fx_graph_module(obj: torch.fx.GraphModule) -> str:
  40. return f"{obj.print_readable(print_output=False)}"
  41. @_format_argument.register
  42. def _torch_tensor(obj: torch.Tensor) -> str:
  43. return f"Tensor(shape={obj.shape}, dtype={obj.dtype})"
  44. @_format_argument.register
  45. def _torch_nn_parameter(obj: torch.nn.Parameter) -> str:
  46. return f"Parameter({format_argument(obj.data)})"
  47. @_format_argument.register
  48. def _onnxscript_torch_script_tensor(obj: graph_building.TorchScriptTensor) -> str:
  49. # TODO(bowbao) obj.dtype throws error.
  50. return f"`TorchScriptTensor({obj.name}, {obj.onnx_dtype}, {obj.shape}, {obj.symbolic_value()})`"
  51. @_format_argument.register
  52. def _onnxscript_onnx_function(obj: onnxscript.values.OnnxFunction) -> str:
  53. return f"`OnnxFunction({obj.name})`"
  54. diagnose_call = functools.partial(
  55. decorator.diagnose_call,
  56. diagnostics.export_context,
  57. diagnostic_type=diagnostics.ExportDiagnostic,
  58. format_argument=format_argument,
  59. )
  60. diagnose_step = functools.partial(
  61. decorator.diagnose_step,
  62. diagnostics.export_context,
  63. format_argument=format_argument,
  64. )
  65. rules = diagnostics.rules
  66. export_context = diagnostics.export_context
  67. levels = diagnostics.levels