import inspect from typing import Dict, List, Union from torch import _C from torch.onnx import _constants from torch.onnx._internal import registration class _TorchSchema: def __init__(self, schema: Union[_C.FunctionSchema, str]) -> None: if isinstance(schema, _C.FunctionSchema): self.name: str = schema.name self.overload_name: str = schema.overload_name self.arguments: List[str] = [arg.name for arg in schema.arguments] self.optional_arguments: List[str] = [] self.returns: List[str] = [ret.name for ret in schema.returns] self.opsets: List[int] = [] else: self.name = schema self.overload_name = "" self.arguments = [] self.optional_arguments = [] self.returns = [] self.opsets = [] def __str__(self) -> str: s = ( f"{self.name}.{self.overload_name}(" + ", ".join(self.arguments) + ") -> (" + ", ".join(self.returns) + ")" + " in opsets " + ", ".join(str(opset) for opset in self.opsets) ) return s def __hash__(self): # TODO(thiagocrepaldi): handle overload_name? return hash(self.name) def __eq__(self, other) -> bool: if not isinstance(other, _TorchSchema): return False # TODO(thiagocrepaldi): handle overload_name? return self.name == other.name def is_aten(self) -> bool: return self.name.startswith("aten::") def is_backward(self) -> bool: return "backward" in self.name def _symbolic_argument_count(func): params = [] signature = inspect.signature(func) optional_params = [] for name, parameter in signature.parameters.items(): if name in {"_outputs", "g"}: continue if parameter.default is parameter.empty: optional_params.append(parameter) else: params.append(str(parameter)) return params def all_forward_schemas() -> Dict[str, _TorchSchema]: """Returns schemas for all TorchScript forward ops.""" torch_schemas = [_TorchSchema(s) for s in _C._jit_get_all_schemas()] return {schema.name: schema for schema in torch_schemas if not schema.is_backward()} def all_symbolics_schemas() -> Dict[str, _TorchSchema]: """Returns schemas for all onnx supported ops.""" symbolics_schemas = {} for name in registration.registry.all_functions(): func_group = registration.registry.get_function_group(name) assert func_group is not None symbolics_schema = _TorchSchema(name) func = func_group.get(_constants.ONNX_MAX_OPSET) if func is not None: symbolics_schema.arguments = _symbolic_argument_count(func) symbolics_schema.opsets = list( range(func_group.get_min_supported(), _constants.ONNX_MAX_OPSET + 1) ) else: # Only support opset < 9 func = func_group.get(7) symbolics_schema.arguments = _symbolic_argument_count(func) symbolics_schema.opsets = list(range(7, _constants.ONNX_BASE_OPSET)) symbolics_schemas[name] = symbolics_schema return symbolics_schemas