12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- import inspect
- from typing import Dict, List, Union
- from torch import _C
- from torch.onnx import _constants
- from torch.onnx._internal import registration
- class _TorchSchema:
- def __init__(self, schema: Union[_C.FunctionSchema, str]) -> None:
- if isinstance(schema, _C.FunctionSchema):
- self.name: str = schema.name
- self.overload_name: str = schema.overload_name
- self.arguments: List[str] = [arg.name for arg in schema.arguments]
- self.optional_arguments: List[str] = []
- self.returns: List[str] = [ret.name for ret in schema.returns]
- self.opsets: List[int] = []
- else:
- self.name = schema
- self.overload_name = ""
- self.arguments = []
- self.optional_arguments = []
- self.returns = []
- self.opsets = []
- def __str__(self) -> str:
- s = (
- f"{self.name}.{self.overload_name}("
- + ", ".join(self.arguments)
- + ") -> ("
- + ", ".join(self.returns)
- + ")"
- + " in opsets "
- + ", ".join(str(opset) for opset in self.opsets)
- )
- return s
- def __hash__(self):
- # TODO(thiagocrepaldi): handle overload_name?
- return hash(self.name)
- def __eq__(self, other) -> bool:
- if not isinstance(other, _TorchSchema):
- return False
- # TODO(thiagocrepaldi): handle overload_name?
- return self.name == other.name
- def is_aten(self) -> bool:
- return self.name.startswith("aten::")
- def is_backward(self) -> bool:
- return "backward" in self.name
- def _symbolic_argument_count(func):
- params = []
- signature = inspect.signature(func)
- optional_params = []
- for name, parameter in signature.parameters.items():
- if name in {"_outputs", "g"}:
- continue
- if parameter.default is parameter.empty:
- optional_params.append(parameter)
- else:
- params.append(str(parameter))
- return params
def all_forward_schemas() -> Dict[str, _TorchSchema]:
    """Returns schemas for all TorchScript forward ops.

    Every schema reported by ``_C._jit_get_all_schemas()`` is wrapped in a
    ``_TorchSchema``; those whose name contains ``backward`` are dropped.
    The result is keyed by schema name, so when several schemas share a name
    the last one encountered is the one kept.
    """
    forward_schemas = {}
    for raw_schema in _C._jit_get_all_schemas():
        wrapped = _TorchSchema(raw_schema)
        if wrapped.is_backward():
            continue
        forward_schemas[wrapped.name] = wrapped
    return forward_schemas
def all_symbolics_schemas() -> Dict[str, _TorchSchema]:
    """Returns schemas for all onnx supported ops.

    For each function name in the registration registry, a ``_TorchSchema``
    is created from the name and filled with the argument list of one of its
    registered functions plus the range of opsets it covers. The result maps
    each registered name to its schema.
    """
    registry = registration.registry
    schemas_by_name = {}
    for op_name in registry.all_functions():
        group = registry.get_function_group(op_name)
        assert group is not None
        entry = _TorchSchema(op_name)

        latest_func = group.get(_constants.ONNX_MAX_OPSET)
        if latest_func is None:
            # Only support opset < 9
            legacy_func = group.get(7)
            entry.arguments = _symbolic_argument_count(legacy_func)
            entry.opsets = list(range(7, _constants.ONNX_BASE_OPSET))
        else:
            # Covered from the group's minimum supported opset through the
            # maximum opset, inclusive.
            entry.arguments = _symbolic_argument_count(latest_func)
            first_opset = group.get_min_supported()
            entry.opsets = list(range(first_opset, _constants.ONNX_MAX_OPSET + 1))

        schemas_by_name[op_name] = entry
    return schemas_by_name
|