# operator_schemas.py
import torch
import inspect
import numbers
import types
import typing
import enum
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING

from torch._jit_internal import boolean_dispatched
from ._compatibility import compatibility
from torch._ops import OpOverloadPacket, OpOverload

if TYPE_CHECKING:
    from .node import Argument

__all__ = ["ArgsKwargsPair", "check_for_mutable_operation", "get_signature_for_torch_op", "create_type_hint",
           "type_matches", "normalize_function", "normalize_module"]

@compatibility(is_backward_compatible=False)
class ArgsKwargsPair(NamedTuple):
    """
    Simple named tuple for wrapping args/kwargs pairs.
    """
    args: Tuple[Any, ...]
    kwargs: Dict[str, Any]

_manual_overrides : Dict[Callable, List[inspect.Signature]] = {}

def _nonzero_schemas():
    signatures = []

    def nonzero(self):
        pass
    signatures.append(inspect.signature(nonzero))

    def nonzero(self, *, as_tuple : bool):  # type: ignore[no-redef]
        pass
    signatures.append(inspect.signature(nonzero))

    return signatures

_manual_overrides[torch.nonzero] = _nonzero_schemas()

class _FakeGlobalNamespace:
    def __getattr__(self, name):
        if name == 'torch':
            return torch
        raise RuntimeError('Expected a torch namespace lookup')

_type_eval_globals = {'Tensor' : torch.Tensor, 'Device' : torch.device, 'Layout' : torch.layout,
                      'number' : numbers.Number, 'Future' : torch.jit.Future,
                      'AnyEnumType' : enum.Enum, 'QScheme' : torch.qscheme,
                      '__torch__': _FakeGlobalNamespace(), 'NoneType': type(None),
                      't': typing.TypeVar('t')}
for k in dir(typing):
    _type_eval_globals[k] = getattr(typing, k)

def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any:
    """
    Convert a TorchScript type to a Python type (including subtypes) via
    eval'ing the annotation_str. _type_eval_globals sets up expressions
    like "List" and "Future" to map to actual types (typing.List and jit.Future)
    """
    return eval(ts_type.annotation_str, _type_eval_globals)
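
# Illustrative sketch (not part of the original module): `torch._C.parse_schema`
# can build a FunctionSchema whose argument types exercise this helper. The
# exact repr shown is an assumption about the typing module's formatting.
#
#     >>> s = torch._C.parse_schema("foo(Tensor self, int[] sizes) -> Tensor")
#     >>> _torchscript_type_to_python_type(s.arguments[1].type)
#     typing.List[int]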

def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Signature:
    from inspect import Parameter
    parameters : List[Parameter] = []
    for arg in ts_schema.arguments:
        arg_type = _torchscript_type_to_python_type(arg.type)
        default = arg.default_value if arg.has_default_value() else Parameter.empty
        # TODO: Figure out if this is safe. It seems like when generating the type signatures for
        # PythonArgParser, we emit signatures with `input` instead of `self` as the first tensor
        # argument name. Downstream, if someone converts that positional argument to a keyword
        # argument, the name mismatch will break things, so here we're going to normalize the
        # name to `input`
        name = arg.name if arg.name != 'self' else 'input'
        kind = Parameter.KEYWORD_ONLY if arg.kwarg_only else Parameter.POSITIONAL_OR_KEYWORD
        # "from" is a Python keyword, so it must be a POSITIONAL_ONLY argument
        if name == "from":
            assert kind == Parameter.POSITIONAL_OR_KEYWORD
            # The ParameterKind type is an internal implementation detail of the inspect
            # package, which makes it hard to do a precise type annotation here
            kind = Parameter.POSITIONAL_ONLY  # type: ignore[assignment]
            # This makes all previous arguments positional-only as well
            for idx, p in enumerate(parameters):
                assert p.kind == Parameter.POSITIONAL_OR_KEYWORD
                parameters[idx] = Parameter(name=p.name, kind=Parameter.POSITIONAL_ONLY, default=p.default, annotation=p.annotation)
        parameters.append(Parameter(name=name, kind=kind, default=default, annotation=arg_type))
    return_types = [_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns]
    if len(return_types) == 0:
        return_type = None
    elif len(return_types) == 1:
        return_type = return_types[0]
    else:
        return_type = tuple(return_types)

    return inspect.Signature(parameters, return_annotation=return_type)
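
# Illustrative sketch (not part of the original module): note the `self` -> `input`
# rename and the `*` marking keyword-only arguments. The rendered output below is
# approximate; inspect's exact repr for the annotations may differ.
#
#     >>> s = torch._C.parse_schema("foo(Tensor self, int dim=0, *, bool keepdim=False) -> Tensor")
#     >>> _torchscript_schema_to_signature(s)
#     <Signature (input: Tensor, dim: int = 0, *, keepdim: bool = False) -> Tensor>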

@compatibility(is_backward_compatible=False)
def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']):
    signatures, schemas = get_signature_for_torch_op(target, return_schemas=True)

    if signatures and schemas:
        matched_schemas = []

        # Iterate through all of the schemas, collecting every
        # signature/schema pair that the given args/kwargs bind to
        for candidate_signature, schema in zip(signatures, schemas):
            try:
                candidate_signature.bind(*args, **kwargs)
                matched_schemas.append((candidate_signature, schema))
            except TypeError:
                continue

        def throw_if_mutable(schema):
            if schema.is_mutable:
                raise RuntimeError(f'Tried to trace mutable operation {schema}. FX only supports functional '
                                   f'code, so operations that mutate operands in-place (e.g. via `out` arguments) '
                                   f'are not supported')

        if len(matched_schemas) == 0:
            # Did not match any schema. Cannot check for mutation
            pass
        elif len(matched_schemas) == 1:
            # Matched exactly one schema, unambiguous
            _, schema_to_check = matched_schemas[0]
            throw_if_mutable(schema_to_check)
        else:
            # Ambiguous schema match. Since mutability checking is best effort,
            # do nothing
            pass
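
# Illustrative sketch (not part of the original module): passing an `out=`
# argument to torch.add should bind only the mutable out-variant overload,
# making the check raise. This assumes `aten::add` exposes a single schema
# accepting `out`, which may vary across PyTorch versions.
#
#     >>> t, out = torch.randn(3), torch.empty(3)
#     >>> check_for_mutable_operation(torch.add, (t, t), {})            # functional: no error
#     >>> check_for_mutable_operation(torch.add, (t, t), {'out': out})  # mutable: RuntimeError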

@compatibility(is_backward_compatible=False)
def get_signature_for_torch_op(op : Callable, return_schemas : bool = False):
    """
    Given an operator on the `torch` namespace, return a list of `inspect.Signature`
    objects corresponding to the overloads of that op. May return `None` if a signature
    could not be retrieved.

    Args:
        op (Callable): An operator on the `torch` namespace to look up a signature for

    Returns:
        Optional[List[inspect.Signature]]: A list of signatures for the overloads of this
            operator, or None if the operator signatures could not be retrieved. If
            return_schemas=True, returns a tuple containing the optional Python signatures
            and the optional TorchScript `FunctionSchema`s
    """
    if isinstance(op, OpOverload):
        schemas = [op._schema]
    elif isinstance(op, OpOverloadPacket):
        schemas = [getattr(op, overload)._schema for overload in op.overloads()]
    else:
        override = _manual_overrides.get(op)
        if override:
            return (override, None) if return_schemas else None

        aten_fn = torch.jit._builtins._find_builtin(op)

        if aten_fn is None:
            return (None, None) if return_schemas else None
        schemas = torch._C._jit_get_schemas_for_operator(aten_fn)

    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]

    return (signatures, schemas) if return_schemas else signatures
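
# Illustrative sketch (not part of the original module): `torch.add` resolves
# to the `aten::add` builtin, so each overload's schema is converted to an
# `inspect.Signature`. The overload count varies across PyTorch versions.
#
#     >>> sigs = get_signature_for_torch_op(torch.add)
#     >>> all(isinstance(s, inspect.Signature) for s in sigs)
#     True
#     >>> sigs, schemas = get_signature_for_torch_op(torch.add, return_schemas=True)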

@compatibility(is_backward_compatible=False)
def create_type_hint(x):
    """
    Given a list or tuple whose elements are types, produce a typing hint
    (List[T] or Tuple[T, ...]) where T is the most general of those types,
    falling back to Any when the element types share no subclass
    relationship. Returns `x` unchanged for other inputs or on failure.
    """
    try:
        if isinstance(x, (list, tuple)):
            # todo(chilli): Figure out the right way for mypy to handle this
            if isinstance(x, list):
                def ret_type(x):
                    return List[x]  # type: ignore[valid-type]
            else:
                def ret_type(x):
                    return Tuple[x, ...]
            if len(x) == 0:
                return ret_type(Any)
            base_type = x[0]
            for t in x:
                if issubclass(t, base_type):
                    continue
                elif issubclass(base_type, t):
                    base_type = t
                else:
                    return ret_type(Any)
            return ret_type(base_type)
    except Exception:
        # We tried to create a type hint for a list but failed.
        warnings.warn(f"We were not able to successfully create type hint from the type {x}")
    return x
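
# Illustrative sketch (not part of the original module): `x` holds *types*,
# and the hint widens to the most general one (here `bool` is a subclass of
# `int`), or to `Any` when the element types are unrelated.
#
#     >>> create_type_hint([bool, int])
#     typing.List[int]
#     >>> create_type_hint((int, str))
#     typing.Tuple[typing.Any, ...]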

@compatibility(is_backward_compatible=False)
def type_matches(signature_type : Any, argument_type : Any):
    sig_origin_type = getattr(signature_type, '__origin__', signature_type)

    if signature_type is argument_type:
        return True

    # Union types in signature. Given type needs to match one of the
    # contained types in the Union
    if sig_origin_type is typing.Union and signature_type != argument_type:
        sig_contained = signature_type.__args__
        return any(type_matches(c, argument_type) for c in sig_contained)

    if signature_type is List[int] and argument_type is int:
        # int can be promoted to List[int]
        return True

    if getattr(signature_type, '__origin__', None) in {list, List}:
        sig_el_type = signature_type.__args__[0]
        if not inspect.isclass(sig_el_type):
            warnings.warn(
                f"Does not support nested parametric types, got {signature_type}. Please file a bug.")
            return False
        if getattr(argument_type, '__origin__', None) in {list, List}:
            return issubclass(argument_type.__args__[0], sig_el_type)

        def is_homogeneous_tuple(t):
            if getattr(t, '__origin__', None) not in {tuple, Tuple}:
                return False
            contained = t.__args__
            if t.__args__ == ((),):  # Tuple[()].__args__ == ((),) for some reason
                return True
            return all((c is Ellipsis) or issubclass(c, sig_el_type) for c in contained)

        # Tuple[T] is accepted for List[T] parameters
        return is_homogeneous_tuple(argument_type)

    # Dtype is an int in schemas
    if signature_type is int and argument_type is torch.dtype:
        return True

    if signature_type is numbers.Number and argument_type in {int, float}:
        return True
    if inspect.isclass(argument_type) and inspect.isclass(signature_type):
        return issubclass(argument_type, signature_type)

    return False
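
# Illustrative sketch (not part of the original module) of the promotion
# rules above:
#
#     >>> type_matches(List[int], int)              # int promotes to List[int]
#     True
#     >>> type_matches(int, torch.dtype)            # dtypes are ints in schemas
#     True
#     >>> type_matches(numbers.Number, float)
#     True
#     >>> type_matches(List[int], Tuple[int, ...])  # homogeneous tuple for List[T]
#     True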

@compatibility(is_backward_compatible=False)
def normalize_function(
        target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None,
        kwarg_types : Optional[Dict[str, Any]] = None,
        normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
    """
    Returns normalized arguments to PyTorch functions. This means that
    `args/kwargs` will be matched up to the function's signature, returning
    exclusively kwargs in positional order if `normalize_to_only_use_kwargs`
    is True.
    Also populates default values. Does not support positional-only
    parameters or varargs parameters (*args, **kwargs). Does not support modules.

    May require `arg_types` and `kwarg_types` in order to disambiguate overloads.

    Args:
        target (Callable): Function that we are normalizing
        args (Tuple[Any]): Tuple of args to the function
        kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
        arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
        kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:
        Returns normalized_args_and_kwargs, or `None` if not successful.
    """
    if kwargs is None:
        kwargs = {}
    new_args_and_kwargs = None
    if not isinstance(target, types.BuiltinFunctionType) and not (
        isinstance(target, (OpOverloadPacket, OpOverload))
    ):
        target_for_analysis = target
        if target in boolean_dispatched:
            # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
            # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
            # branches of the dispatch have exactly the same signature. If they do, use the `true`
            # branch signature for analysis. Otherwise, leave this un-normalized
            assert not isinstance(target, str)
            dispatched = boolean_dispatched[target]
            if_true, if_false = dispatched['if_true'], dispatched['if_false']
            if inspect.signature(if_true).parameters != inspect.signature(if_false).parameters:
                return None
            target_for_analysis = if_true

        assert callable(target_for_analysis)
        sig = inspect.signature(inspect.unwrap(target_for_analysis))
        new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs)
    else:
        assert callable(target)
        torch_op_schemas = get_signature_for_torch_op(target)
        matched_schemas = []
        if torch_op_schemas:
            # Iterate through all of the schemas until we find one that matches.
            # If one matches, populate `new_args_and_kwargs` with the new args/kwargs
            # values. If none matches, `new_args_and_kwargs` will be None
            for candidate_signature in torch_op_schemas:
                try:
                    candidate_signature.bind(*args, **kwargs)
                    matched_schemas.append(candidate_signature)
                except TypeError:
                    continue

            if len(matched_schemas) == 0:
                # Did not match any schema. Cannot normalize
                pass
            elif len(matched_schemas) == 1:
                # Matched exactly one schema, unambiguous
                new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(matched_schemas[0], args, kwargs,
                                                                             normalize_to_only_use_kwargs)
            else:
                if arg_types is not None or kwarg_types is not None:
                    arg_types = arg_types if arg_types else cast(Tuple[Any], ())
                    kwarg_types = kwarg_types if kwarg_types else {}
                    for candidate_signature in torch_op_schemas:
                        sig_matches = True
                        try:
                            bound_types = candidate_signature.bind(*arg_types, **kwarg_types)
                            for arg_name, arg_type in bound_types.arguments.items():
                                param = candidate_signature.parameters[arg_name]
                                sig_matches = sig_matches and type_matches(param.annotation, arg_type)
                        except TypeError:
                            sig_matches = False
                        if sig_matches:
                            new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs,
                                                                                         normalize_to_only_use_kwargs)
                            break
                else:
                    # Matched more than one schema. In this situation, the caller must provide the types of
                    # the arguments of the overload they expect.
                    schema_printouts = '\n'.join(str(schema) for schema in matched_schemas)
                    raise RuntimeError(f'Tried to normalize arguments to {torch.typename(target)} but '
                                       f'the schema match was ambiguous! Please provide argument types to '
                                       f'the normalize_arguments() call. Available schemas:\n{schema_printouts}')

    return new_args_and_kwargs
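
# Illustrative sketch (not part of the original module): `torch.add` has
# several overloads that bind `(t, t)`, so `arg_types` is needed to
# disambiguate. The exact kwargs (including the defaulted `alpha`) are an
# assumption based on the `add.Tensor` schema.
#
#     >>> t = torch.randn(3)
#     >>> normalize_function(torch.add, (t, t), arg_types=(torch.Tensor, torch.Tensor),
#     ...                    normalize_to_only_use_kwargs=True)
#     ArgsKwargsPair(args=(), kwargs={'input': ..., 'other': ..., 'alpha': 1})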

@compatibility(is_backward_compatible=False)
def normalize_module(
        root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None,
        normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
    """
    Returns normalized arguments to PyTorch modules. This means that
    `args/kwargs` will be matched up to the module's `forward` signature,
    returning exclusively kwargs in positional order if
    `normalize_to_only_use_kwargs` is True.
    Also populates default values. Does not support positional-only
    parameters or varargs parameters (*args, **kwargs).

    Args:
        root (nn.Module): Root module upon which we query submodules
        target (str): Qualified name of the submodule whose call we are normalizing
        args (Tuple[Any]): Tuple of args to the module
        kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the module
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:
        Returns normalized_args_and_kwargs, or `None` if not successful.
    """
    try:
        submod = root.get_submodule(target)
    except AttributeError as e:
        raise RuntimeError(f"Tried to normalize node with target {target} but root did not "
                           f"have that target!") from e
    if hasattr(submod.__class__, '__name__'):
        classname = submod.__class__.__name__
        if getattr(torch.nn, classname, None) == submod.__class__:
            sig = inspect.signature(inspect.unwrap(submod.forward))
            if kwargs is None:
                kwargs = {}
            new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs,
                                                                         normalize_to_only_use_kwargs)
            return new_args_and_kwargs
    return None
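
# Illustrative sketch (not part of the original module): normalization only
# applies to submodules whose class lives on `torch.nn`, such as this Conv2d.
#
#     >>> class M(torch.nn.Module):
#     ...     def __init__(self):
#     ...         super().__init__()
#     ...         self.conv = torch.nn.Conv2d(1, 1, 1)
#     >>> x = torch.randn(1, 1, 4, 4)
#     >>> normalize_module(M(), 'conv', (x,), normalize_to_only_use_kwargs=True)
#     ArgsKwargsPair(args=(), kwargs={'input': ...})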

def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple[Any, ...],
                                           kwargs : Dict[str, Any],
                                           normalize_to_only_use_kwargs : bool) -> Optional[ArgsKwargsPair]:
    """
    Given a call target, args, and kwargs, return the arguments normalized into
    an ArgsKwargsPair, or None if the type signature is not supported by
    this normalization.

    Args:
        sig (inspect.Signature): Signature object for the target
        args (Tuple): Arguments that appear at the callsite for `target`
        kwargs (Dict): Keyword arguments that appear at the callsite for `target`
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:
        Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if
            this target is not supported.
    """
    # Don't currently support positional-only
    # or varargs (*args, **kwargs) signatures
    supported_parameter_types = {
        inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY}
    if any(p.kind not in supported_parameter_types for p in sig.parameters.values()):
        # Add an exception for one signature, which is common for random/uniform, i.e.:
        # Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None
        # `from` is a Python keyword, and as such functions with that signature should have
        # positional-only args, but at the same time they could be dispatched as kwargs
        if list(sig.parameters.keys()) != ['input', 'from', 'to', 'generator']:
            return None

    bound_args = sig.bind(*args, **kwargs)
    bound_args.apply_defaults()

    new_kwargs : Dict[str, Any] = {}
    new_args : List[Any] = []
    for i, param in enumerate(sig.parameters):
        if not normalize_to_only_use_kwargs and i < len(args):
            new_args.append(bound_args.arguments[param])
        else:
            new_kwargs[param] = bound_args.arguments[param]

    return ArgsKwargsPair(tuple(new_args), new_kwargs)
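
# Illustrative sketch (not part of the original module): defaults are
# materialized, and positionally-passed arguments stay positional unless
# `normalize_to_only_use_kwargs` is set.
#
#     >>> def f(a, b=2, *, c=3):
#     ...     pass
#     >>> _args_kwargs_to_normalized_args_kwargs(inspect.signature(f), (1,), {'c': 10}, False)
#     ArgsKwargsPair(args=(1,), kwargs={'b': 2, 'c': 10})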