annotations.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. import ast
  2. import dis
  3. import enum
  4. import inspect
  5. import re
  6. import builtins
  7. import torch
  8. import warnings
  9. from .._jit_internal import List, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
  10. is_optional, _qualified_name, Any, Future, is_future, _Await, is_await, is_ignored_fn, Union, is_union
  11. from .._jit_internal import BroadcastingList1, BroadcastingList2, BroadcastingList3 # type: ignore[attr-defined]
  12. from ._state import _get_script_class
  13. from torch._C import TensorType, TupleType, FloatType, IntType, ComplexType, \
  14. ListType, StringType, DictType, BoolType, OptionalType, InterfaceType, AnyType, \
  15. NoneType, DeviceObjType, StreamObjType, FutureType, AwaitType, EnumType, UnionType, NumberType
  16. from textwrap import dedent
  17. from torch._sources import get_source_lines_and_file
  18. from typing import Type
  19. if torch.distributed.rpc.is_available():
  20. from .._jit_internal import RRef, is_rref
  21. from torch._C import RRefType
  22. from torch._ops import OpOverloadPacket
  23. class Module:
  24. def __init__(self, name, members):
  25. self.name = name
  26. self.members = members
  27. def __getattr__(self, name):
  28. try:
  29. return self.members[name]
  30. except KeyError:
  31. raise RuntimeError(f"Module {self.name} has no member called {name}") from None
  32. class EvalEnv:
  33. env = {
  34. 'torch': Module('torch', {'Tensor': torch.Tensor}),
  35. 'Tensor': torch.Tensor,
  36. 'typing': Module('typing', {'Tuple': Tuple}),
  37. 'Tuple': Tuple,
  38. 'List': List,
  39. 'Dict': Dict,
  40. 'Optional': Optional,
  41. 'Union': Union,
  42. 'Future': Future,
  43. 'Await': _Await
  44. }
  45. def __init__(self, rcb):
  46. self.rcb = rcb
  47. if torch.distributed.rpc.is_available():
  48. self.env['RRef'] = RRef
  49. def __getitem__(self, name):
  50. if name in self.env:
  51. return self.env[name]
  52. if self.rcb is not None:
  53. return self.rcb(name)
  54. return getattr(builtins, name, None)
  55. def get_signature(fn, rcb, loc, is_method):
  56. if isinstance(fn, OpOverloadPacket):
  57. signature = try_real_annotations(fn.op, loc)
  58. else:
  59. signature = try_real_annotations(fn, loc)
  60. if signature is not None and is_method:
  61. # If this is a method, then the signature will include a type for
  62. # `self`, but type comments do not contain a `self`. So strip it
  63. # away here so everything is consistent (`inspect.ismethod` does
  64. # not work here since `fn` is unbound at this point)
  65. param_types, return_type = signature
  66. param_types = param_types[1:]
  67. signature = (param_types, return_type)
  68. if signature is None:
  69. type_line, source = None, None
  70. try:
  71. source = dedent(''.join(get_source_lines_and_file(fn)[0]))
  72. type_line = get_type_line(source)
  73. except TypeError:
  74. pass
  75. # This might happen both because we failed to get the source of fn, or
  76. # because it didn't have any annotations.
  77. if type_line is not None:
  78. signature = parse_type_line(type_line, rcb, loc)
  79. return signature
  80. def is_function_or_method(the_callable):
  81. # A stricter version of `inspect.isroutine` that does not pass for built-in
  82. # functions
  83. return inspect.isfunction(the_callable) or inspect.ismethod(the_callable)
  84. def is_vararg(the_callable):
  85. if not is_function_or_method(the_callable) and hasattr(the_callable, '__call__'): # noqa: B004
  86. # If `the_callable` is a class, de-sugar the call so we can still get
  87. # the signature
  88. the_callable = the_callable.__call__
  89. if is_function_or_method(the_callable):
  90. return inspect.getfullargspec(the_callable).varargs is not None
  91. else:
  92. return False
  93. def get_param_names(fn, n_args):
  94. if isinstance(fn, OpOverloadPacket):
  95. fn = fn.op
  96. if not is_function_or_method(fn) and hasattr(fn, '__call__') and is_function_or_method(fn.__call__): # noqa: B004
  97. # De-sugar calls to classes
  98. fn = fn.__call__
  99. if is_function_or_method(fn):
  100. if is_ignored_fn(fn):
  101. fn = inspect.unwrap(fn)
  102. return inspect.getfullargspec(fn).args
  103. else:
  104. # The `fn` was not a method or function (maybe a class with a __call__
  105. # method, so use a default param name list)
  106. return [str(i) for i in range(n_args)]
  107. def check_fn(fn, loc):
  108. # Make sure the function definition is not a class instantiation
  109. try:
  110. source = dedent(''.join(get_source_lines_and_file(fn)[0]))
  111. except (TypeError, IOError):
  112. return
  113. if source is None:
  114. return
  115. py_ast = ast.parse(source)
  116. if len(py_ast.body) == 1 and isinstance(py_ast.body[0], ast.ClassDef):
  117. raise torch.jit.frontend.FrontendError(
  118. loc, f"Cannot instantiate class '{py_ast.body[0].name}' in a script function")
  119. if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
  120. raise torch.jit.frontend.FrontendError(loc, "Expected a single top-level function")
  121. def _eval_no_call(stmt, glob, loc):
  122. """Evaluate statement as long as it does not contain any method/function calls"""
  123. bytecode = compile(stmt, "", mode="eval")
  124. for insn in dis.get_instructions(bytecode):
  125. if "CALL" in insn.opname:
  126. raise RuntimeError(f"Type annotation should not contain calls, but '{stmt}' does")
  127. return eval(bytecode, glob, loc) # type: ignore[arg-type] # noqa: P204
  128. def parse_type_line(type_line, rcb, loc):
  129. """Parses a type annotation specified as a comment.
  130. Example inputs:
  131. # type: (Tensor, torch.Tensor) -> Tuple[Tensor]
  132. # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor
  133. """
  134. arg_ann_str, ret_ann_str = split_type_line(type_line)
  135. try:
  136. arg_ann = _eval_no_call(arg_ann_str, {}, EvalEnv(rcb))
  137. except (NameError, SyntaxError) as e:
  138. raise RuntimeError("Failed to parse the argument list of a type annotation") from e
  139. if not isinstance(arg_ann, tuple):
  140. arg_ann = (arg_ann,)
  141. try:
  142. ret_ann = _eval_no_call(ret_ann_str, {}, EvalEnv(rcb))
  143. except (NameError, SyntaxError) as e:
  144. raise RuntimeError("Failed to parse the return type of a type annotation") from e
  145. arg_types = [ann_to_type(ann, loc) for ann in arg_ann]
  146. return arg_types, ann_to_type(ret_ann, loc)
  147. def get_type_line(source):
  148. """Tries to find the line containing a comment with the type annotation."""
  149. type_comment = '# type:'
  150. lines = source.split('\n')
  151. lines = [(line_num, line) for line_num, line in enumerate(lines)]
  152. type_lines = list(filter(lambda line: type_comment in line[1], lines))
  153. # `type: ignore` comments may be needed in JIT'ed functions for mypy, due
  154. # to the hack in torch/_VF.py.
  155. # An ignore type comment can be of following format:
  156. # 1) type: ignore
  157. # 2) type: ignore[rule-code]
  158. # This ignore statement must be at the end of the line
  159. # adding an extra backslash before the space, to avoid triggering
  160. # one of the checks in .github/workflows/lint.yml
  161. type_pattern = re.compile("# type:\\ ignore(\\[[a-zA-Z-]+\\])?$")
  162. type_lines = list(filter(lambda line: not type_pattern.search(line[1]),
  163. type_lines))
  164. if len(type_lines) == 0:
  165. # Catch common typo patterns like extra spaces, typo in 'ignore', etc.
  166. wrong_type_pattern = re.compile("#[\t ]*type[\t ]*(?!: ignore(\\[.*\\])?$):")
  167. wrong_type_lines = list(filter(lambda line: wrong_type_pattern.search(line[1]), lines))
  168. if len(wrong_type_lines) > 0:
  169. raise RuntimeError("The annotation prefix in line " + str(wrong_type_lines[0][0])
  170. + " is probably invalid.\nIt must be '# type:'"
  171. + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)" # noqa: B950
  172. + "\nfor examples")
  173. return None
  174. elif len(type_lines) == 1:
  175. # Only 1 type line, quit now
  176. return type_lines[0][1].strip()
  177. # Parse split up argument types according to PEP 484
  178. # https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code
  179. return_line = None
  180. parameter_type_lines = []
  181. for line_num, line in type_lines:
  182. if '# type: (...) -> ' in line:
  183. return_line = (line_num, line)
  184. break
  185. elif type_comment in line:
  186. parameter_type_lines.append(line)
  187. if return_line is None:
  188. raise RuntimeError(
  189. "Return type line '# type: (...) -> ...' not found on multiline "
  190. "type annotation\nfor type lines:\n" +
  191. '\n'.join([line[1] for line in type_lines]) +
  192. "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)")
  193. def get_parameter_type(line):
  194. item_type = line[line.find(type_comment) + len(type_comment):]
  195. return item_type.strip()
  196. types = map(get_parameter_type, parameter_type_lines)
  197. parameter_types = ", ".join(types)
  198. return return_line[1].replace("...", parameter_types)
  199. def split_type_line(type_line):
  200. """Splits the comment with the type annotation into parts for argument and return types.
  201. For example, for an input of:
  202. # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]
  203. This function will return:
  204. ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]")
  205. """
  206. start_offset = len('# type:')
  207. try:
  208. arrow_pos = type_line.index('->')
  209. except ValueError:
  210. raise RuntimeError("Syntax error in type annotation (cound't find `->`)") from None
  211. return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2:].strip()
  212. def try_real_annotations(fn, loc):
  213. """Tries to use the Py3.5+ annotation syntax to get the type."""
  214. try:
  215. # Note: anything annotated as `Optional[T]` will automatically
  216. # be returned as `Union[T, None]` per
  217. # https://github.com/python/typing/blob/master/src/typing.py#L850
  218. sig = inspect.signature(fn)
  219. except ValueError:
  220. return None
  221. all_annots = [sig.return_annotation] + [p.annotation for p in sig.parameters.values()]
  222. if all(ann is sig.empty for ann in all_annots):
  223. return None
  224. arg_types = [ann_to_type(p.annotation, loc)
  225. for p in sig.parameters.values()]
  226. return_type = ann_to_type(sig.return_annotation, loc)
  227. return arg_types, return_type
  228. # Finds common type for enum values belonging to an Enum class. If not all
  229. # values have the same type, AnyType is returned.
  230. def get_enum_value_type(e: Type[enum.Enum], loc):
  231. enum_values: List[enum.Enum] = list(e)
  232. if not enum_values:
  233. raise ValueError(f"No enum values defined for: '{e.__class__}'")
  234. types = {type(v.value) for v in enum_values}
  235. ir_types = [try_ann_to_type(t, loc) for t in types]
  236. # If Enum values are of different types, an exception will be raised here.
  237. # Even though Python supports this case, we chose to not implement it to
  238. # avoid overcomplicate logic here for a rare use case. Please report a
  239. # feature request if you find it necessary.
  240. res = torch._C.unify_type_list(ir_types)
  241. if not res:
  242. return AnyType.get()
  243. return res
  244. def is_tensor(ann):
  245. if issubclass(ann, torch.Tensor):
  246. return True
  247. if issubclass(ann, (torch.LongTensor, torch.DoubleTensor, torch.FloatTensor,
  248. torch.IntTensor, torch.ShortTensor, torch.HalfTensor,
  249. torch.CharTensor, torch.ByteTensor, torch.BoolTensor)):
  250. warnings.warn("TorchScript will treat type annotations of Tensor "
  251. "dtype-specific subtypes as if they are normal Tensors. "
  252. "dtype constraints are not enforced in compilation either.")
  253. return True
  254. return False
  255. def try_ann_to_type(ann, loc):
  256. if ann is inspect.Signature.empty:
  257. return TensorType.getInferred()
  258. if ann is None:
  259. return NoneType.get()
  260. if inspect.isclass(ann) and is_tensor(ann):
  261. return TensorType.get()
  262. if is_tuple(ann):
  263. # Special case for the empty Tuple type annotation `Tuple[()]`
  264. if len(ann.__args__) == 1 and ann.__args__[0] == ():
  265. return TupleType([])
  266. return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
  267. if is_list(ann):
  268. elem_type = try_ann_to_type(ann.__args__[0], loc)
  269. if elem_type:
  270. return ListType(elem_type)
  271. if is_dict(ann):
  272. key = try_ann_to_type(ann.__args__[0], loc)
  273. value = try_ann_to_type(ann.__args__[1], loc)
  274. # Raise error if key or value is None
  275. if key is None:
  276. raise ValueError(f"Unknown type annotation: '{ann.__args__[0]}' at {loc.highlight()}")
  277. if value is None:
  278. raise ValueError(f"Unknown type annotation: '{ann.__args__[1]}' at {loc.highlight()}")
  279. return DictType(key, value)
  280. if is_optional(ann):
  281. if issubclass(ann.__args__[1], type(None)):
  282. contained = ann.__args__[0]
  283. else:
  284. contained = ann.__args__[1]
  285. valid_type = try_ann_to_type(contained, loc)
  286. msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
  287. assert valid_type, msg.format(repr(ann), repr(contained))
  288. return OptionalType(valid_type)
  289. if is_union(ann):
  290. # TODO: this is hack to recognize NumberType
  291. if set(ann.__args__) == {int, float, complex}:
  292. return NumberType.get()
  293. inner: List = []
  294. # We need these extra checks because both `None` and invalid
  295. # values will return `None`
  296. # TODO: Determine if the other cases need to be fixed as well
  297. for a in ann.__args__:
  298. if a is None:
  299. inner.append(NoneType.get())
  300. maybe_type = try_ann_to_type(a, loc)
  301. msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
  302. assert maybe_type, msg.format(repr(ann), repr(maybe_type))
  303. inner.append(maybe_type)
  304. return UnionType(inner) # type: ignore[arg-type]
  305. if torch.distributed.rpc.is_available() and is_rref(ann):
  306. return RRefType(try_ann_to_type(ann.__args__[0], loc))
  307. if is_future(ann):
  308. return FutureType(try_ann_to_type(ann.__args__[0], loc))
  309. if is_await(ann):
  310. elementType = try_ann_to_type(ann.__args__[0], loc) if hasattr(ann, "__args__") else AnyType.get()
  311. return AwaitType(elementType)
  312. if ann is float:
  313. return FloatType.get()
  314. if ann is complex:
  315. return ComplexType.get()
  316. if ann is int:
  317. return IntType.get()
  318. if ann is str:
  319. return StringType.get()
  320. if ann is bool:
  321. return BoolType.get()
  322. if ann is Any:
  323. return AnyType.get()
  324. if ann is type(None):
  325. return NoneType.get()
  326. if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
  327. return InterfaceType(ann.__torch_script_interface__)
  328. if ann is torch.device:
  329. return DeviceObjType.get()
  330. if ann is torch.Stream:
  331. return StreamObjType.get()
  332. if ann is torch.dtype:
  333. return IntType.get() # dtype not yet bound in as its own type
  334. if inspect.isclass(ann) and issubclass(ann, enum.Enum):
  335. if _get_script_class(ann) is None:
  336. scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
  337. name = scripted_class.qualified_name()
  338. else:
  339. name = _qualified_name(ann)
  340. return EnumType(name, get_enum_value_type(ann, loc), list(ann))
  341. if inspect.isclass(ann):
  342. maybe_script_class = _get_script_class(ann)
  343. if maybe_script_class is not None:
  344. return maybe_script_class
  345. if torch._jit_internal.can_compile_class(ann):
  346. return torch.jit._script._recursive_compile_class(ann, loc)
  347. # Maybe resolve a NamedTuple to a Tuple Type
  348. def fake_rcb(key):
  349. return None
  350. return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
  351. def ann_to_type(ann, loc):
  352. the_type = try_ann_to_type(ann, loc)
  353. if the_type is not None:
  354. return the_type
  355. raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")
# Names exported by a wildcard import of this module. Several entries (e.g.
# `List`, `Tuple`, `TensorType`) are re-exports of names imported at the top
# of the file rather than definitions made here.
__all__ = [
    'Any',
    'List',
    'BroadcastingList1',
    'BroadcastingList2',
    'BroadcastingList3',
    'Tuple',
    'is_tuple',
    'is_list',
    'Dict',
    'is_dict',
    'is_optional',
    'is_union',
    'TensorType',
    'TupleType',
    'FloatType',
    'ComplexType',
    'IntType',
    'ListType',
    'StringType',
    'DictType',
    'AnyType',
    'Module',
    # TODO: Consider not exporting these during wildcard import (reserve
    # that for the types; for idiomatic typing code.)
    'get_signature',
    'check_fn',
    'get_param_names',
    'parse_type_line',
    'get_type_line',
    'split_type_line',
    'try_real_annotations',
    'try_ann_to_type',
    'ann_to_type',
]