errors.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. """ONNX exporter exceptions."""
  2. from __future__ import annotations
  3. import textwrap
  4. from typing import Optional
  5. from torch import _C
  6. from torch.onnx import _constants
  7. from torch.onnx._internal import diagnostics
  8. __all__ = [
  9. "OnnxExporterError",
  10. "OnnxExporterWarning",
  11. "CallHintViolationWarning",
  12. "CheckerError",
  13. "UnsupportedOperatorError",
  14. "SymbolicValueError",
  15. ]
  16. class OnnxExporterWarning(UserWarning):
  17. """Base class for all warnings in the ONNX exporter."""
  18. pass
  19. class CallHintViolationWarning(OnnxExporterWarning):
  20. """Warning raised when a type hint is violated during a function call."""
  21. pass
  22. class OnnxExporterError(RuntimeError):
  23. """Errors raised by the ONNX exporter."""
  24. pass
  25. class CheckerError(OnnxExporterError):
  26. """Raised when ONNX checker detects an invalid model."""
  27. pass
  28. class UnsupportedOperatorError(OnnxExporterError):
  29. """Raised when an operator is unsupported by the exporter."""
  30. def __init__(self, name: str, version: int, supported_version: Optional[int]):
  31. if supported_version is not None:
  32. diagnostic_rule: diagnostics.infra.Rule = (
  33. diagnostics.rules.operator_supported_in_newer_opset_version
  34. )
  35. msg = diagnostic_rule.format_message(name, version, supported_version)
  36. diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg)
  37. else:
  38. if (
  39. name.startswith("aten::")
  40. or name.startswith("prim::")
  41. or name.startswith("quantized::")
  42. ):
  43. diagnostic_rule = diagnostics.rules.missing_standard_symbolic_function
  44. msg = diagnostic_rule.format_message(
  45. name, version, _constants.PYTORCH_GITHUB_ISSUES_URL
  46. )
  47. diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg)
  48. else:
  49. diagnostic_rule = diagnostics.rules.missing_custom_symbolic_function
  50. msg = diagnostic_rule.format_message(name)
  51. diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg)
  52. super().__init__(msg)
  53. class SymbolicValueError(OnnxExporterError):
  54. """Errors around TorchScript values and nodes."""
  55. def __init__(self, msg: str, value: _C.Value):
  56. message = (
  57. f"{msg} [Caused by the value '{value}' (type '{value.type()}') in the "
  58. f"TorchScript graph. The containing node has kind '{value.node().kind()}'.] "
  59. )
  60. code_location = value.node().sourceRange()
  61. if code_location:
  62. message += f"\n (node defined in {code_location})"
  63. try:
  64. # Add its input and output to the message.
  65. message += "\n\n"
  66. message += textwrap.indent(
  67. (
  68. "Inputs:\n"
  69. + (
  70. "\n".join(
  71. f" #{i}: {input_} (type '{input_.type()}')"
  72. for i, input_ in enumerate(value.node().inputs())
  73. )
  74. or " Empty"
  75. )
  76. + "\n"
  77. + "Outputs:\n"
  78. + (
  79. "\n".join(
  80. f" #{i}: {output} (type '{output.type()}')"
  81. for i, output in enumerate(value.node().outputs())
  82. )
  83. or " Empty"
  84. )
  85. ),
  86. " ",
  87. )
  88. except AttributeError:
  89. message += (
  90. " Failed to obtain its input and output for debugging. "
  91. "Please refer to the TorchScript graph for debugging information."
  92. )
  93. super().__init__(message)