123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
- """ONNX exporter exceptions."""
- from __future__ import annotations
- import textwrap
- from typing import Optional
- from torch import _C
- from torch.onnx import _constants
- from torch.onnx._internal import diagnostics
# Public names exported by ``from <module> import *``: the warning classes
# and the error classes defined below.
__all__ = [
    "OnnxExporterError",
    "OnnxExporterWarning",
    "CallHintViolationWarning",
    "CheckerError",
    "UnsupportedOperatorError",
    "SymbolicValueError",
]
class OnnxExporterWarning(UserWarning):
    """Root of the ONNX exporter's warning hierarchy.

    Subclasses ``UserWarning`` and adds no behavior of its own.
    """
class CallHintViolationWarning(OnnxExporterWarning):
    """Warning category for a function call that violates a type hint.

    Adds no behavior beyond ``OnnxExporterWarning``.
    """
class OnnxExporterError(RuntimeError):
    """Base error type raised by the ONNX exporter.

    Subclasses ``RuntimeError`` and adds no behavior of its own.
    """
class CheckerError(OnnxExporterError):
    """Error signalling that the ONNX checker found the model to be invalid.

    Adds no behavior beyond ``OnnxExporterError``.
    """
class UnsupportedOperatorError(OnnxExporterError):
    """Error for an operator the exporter cannot handle.

    Constructing the error also reports an ERROR-level diagnostic through
    ``diagnostics.diagnose``; the formatted rule message doubles as the
    exception message.
    """

    def __init__(self, name: str, version: int, supported_version: Optional[int]):
        """Pick the matching diagnostic rule, report it, and set the message.

        Args:
            name: Operator name. Names starting with ``aten::``, ``prim::`` or
                ``quantized::`` select the "missing standard symbolic
                function" rule; any other name selects the "missing custom
                symbolic function" rule.
            version: Version value forwarded to the rule's message formatter
                (presumably the opset version in use; confirm with callers).
            supported_version: When not ``None``, the "operator supported in
                newer opset version" rule is used and this value is included
                in its message; the ``name`` prefix is then not inspected.
        """
        rules = diagnostics.rules
        if supported_version is not None:
            rule = rules.operator_supported_in_newer_opset_version
            message = rule.format_message(name, version, supported_version)
        elif name.startswith(("aten::", "prim::", "quantized::")):
            rule = rules.missing_standard_symbolic_function
            message = rule.format_message(
                name, version, _constants.PYTORCH_GITHUB_ISSUES_URL
            )
        else:
            rule = rules.missing_custom_symbolic_function
            message = rule.format_message(name)

        # Every branch reports the same way, so the diagnostic is emitted once
        # here, before the exception itself is initialized.
        diagnostics.diagnose(rule, diagnostics.levels.ERROR, message)
        super().__init__(message)
class SymbolicValueError(OnnxExporterError):
    """Error about a TorchScript value and the node that contains it."""

    def __init__(self, msg: str, value: _C.Value):
        """Compose an error message enriched with context about ``value``.

        The message starts with ``msg`` followed by the value, its type and
        the kind of its containing node. If the node reports a source range,
        that is appended, and then an indented listing of the node's inputs
        and outputs is added.

        Args:
            msg: Caller-supplied description of the problem.
            value: The TorchScript value the error is about.
        """

        def listing(values) -> str:
            # One line per value; " Empty" stands in when there are none.
            rows = [
                f" #{idx}: {item} (type '{item.type()}')"
                for idx, item in enumerate(values)
            ]
            return "\n".join(rows) or " Empty"

        fragments = [
            f"{msg} [Caused by the value '{value}' (type '{value.type()}') in the "
            f"TorchScript graph. The containing node has kind '{value.node().kind()}'.] "
        ]

        source_range = value.node().sourceRange()
        if source_range:
            fragments.append(f"\n (node defined in {source_range})")

        try:
            # The separator is added before the listing is attempted, so it is
            # present even when the fallback text below is used instead.
            fragments.append("\n\n")
            inputs_text = listing(value.node().inputs())
            outputs_text = listing(value.node().outputs())
            body = "Inputs:\n" + inputs_text + "\n" + "Outputs:\n" + outputs_text
            fragments.append(textwrap.indent(body, " "))
        except AttributeError:
            # Best effort: if the node or its values lack the expected
            # accessors, keep the error usable rather than failing here.
            fragments.append(
                " Failed to obtain its input and output for debugging. "
                "Please refer to the TorchScript graph for debugging information."
            )

        super().__init__("".join(fragments))
|