# _onnx_supported_ops.py

import inspect
from typing import Dict, List, Union

from torch import _C
from torch.onnx import _constants
from torch.onnx._internal import registration
class _TorchSchema:
    def __init__(self, schema: Union[_C.FunctionSchema, str]) -> None:
        if isinstance(schema, _C.FunctionSchema):
            self.name: str = schema.name
            self.overload_name: str = schema.overload_name
            self.arguments: List[str] = [arg.name for arg in schema.arguments]
            self.optional_arguments: List[str] = []
            self.returns: List[str] = [ret.name for ret in schema.returns]
            self.opsets: List[int] = []
        else:
            self.name = schema
            self.overload_name = ""
            self.arguments = []
            self.optional_arguments = []
            self.returns = []
            self.opsets = []

    def __str__(self) -> str:
        s = (
            f"{self.name}.{self.overload_name}("
            + ", ".join(self.arguments)
            + ") -> ("
            + ", ".join(self.returns)
            + ")"
            + " in opsets "
            + ", ".join(str(opset) for opset in self.opsets)
        )
        return s

    def __hash__(self):
        # TODO(thiagocrepaldi): handle overload_name?
        return hash(self.name)

    def __eq__(self, other) -> bool:
        if not isinstance(other, _TorchSchema):
            return False
        # TODO(thiagocrepaldi): handle overload_name?
        return self.name == other.name

    def is_aten(self) -> bool:
        return self.name.startswith("aten::")

    def is_backward(self) -> bool:
        return "backward" in self.name
def _symbolic_argument_count(func):
    params = []
    signature = inspect.signature(func)
    optional_params = []
    for name, parameter in signature.parameters.items():
        if name in {"_outputs", "g"}:
            continue
        if parameter.default is parameter.empty:
            optional_params.append(parameter)
        else:
            params.append(str(parameter))
    return params
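
# For example, given a symbolic function with the (hypothetical) signature
#     def symbolic_fn(g, input, other, alpha=1.0): ...
# _symbolic_argument_count skips "g", collects "input" and "other" (which have
# no default) into optional_params, and returns only the defaulted
# "alpha=1.0" entry in params.
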
def all_forward_schemas() -> Dict[str, _TorchSchema]:
    """Returns schemas for all TorchScript forward ops."""
    torch_schemas = [_TorchSchema(s) for s in _C._jit_get_all_schemas()]
    return {
        schema.name: schema for schema in torch_schemas if not schema.is_backward()
    }
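
# Illustrative usage: the returned mapping is keyed by schema name, so a lookup
# such as all_forward_schemas().get("aten::add") yields that op's _TorchSchema.
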
def all_symbolics_schemas() -> Dict[str, _TorchSchema]:
    """Returns schemas for all ONNX-supported ops."""
    symbolics_schemas = {}

    for name in registration.registry.all_functions():
        func_group = registration.registry.get_function_group(name)
        assert func_group is not None
        symbolics_schema = _TorchSchema(name)
        func = func_group.get(_constants.ONNX_MAX_OPSET)
        if func is not None:
            symbolics_schema.arguments = _symbolic_argument_count(func)
            symbolics_schema.opsets = list(
                range(func_group.get_min_supported(), _constants.ONNX_MAX_OPSET + 1)
            )
        else:
            # The symbolic is only registered for opsets below 9.
            func = func_group.get(7)
            symbolics_schema.arguments = _symbolic_argument_count(func)
            symbolics_schema.opsets = list(range(7, _constants.ONNX_BASE_OPSET))
        symbolics_schemas[name] = symbolics_schema

    return symbolics_schemas
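

# A minimal usage sketch: cross-reference the TorchScript forward schemas with
# the registered symbolics to count which aten ops can be exported to ONNX at
# the current maximum opset. The __main__ block is only an example of how the
# two helpers above compose; it is not required by the module.
if __name__ == "__main__":
    forward_schemas = all_forward_schemas()
    symbolic_schemas = all_symbolics_schemas()
    supported_aten = [
        name
        for name, schema in forward_schemas.items()
        if schema.is_aten() and name in symbolic_schemas
    ]
    print(f"{len(supported_aten)} aten ops have an ONNX symbolic registered")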