123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459 |
- import ast
- import dis
- import enum
- import inspect
- import re
- import builtins
- import torch
- import warnings
- from .._jit_internal import List, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
- is_optional, _qualified_name, Any, Future, is_future, _Await, is_await, is_ignored_fn, Union, is_union
- from .._jit_internal import BroadcastingList1, BroadcastingList2, BroadcastingList3 # type: ignore[attr-defined]
- from ._state import _get_script_class
- from torch._C import TensorType, TupleType, FloatType, IntType, ComplexType, \
- ListType, StringType, DictType, BoolType, OptionalType, InterfaceType, AnyType, \
- NoneType, DeviceObjType, StreamObjType, FutureType, AwaitType, EnumType, UnionType, NumberType
- from textwrap import dedent
- from torch._sources import get_source_lines_and_file
- from typing import Type
- if torch.distributed.rpc.is_available():
- from .._jit_internal import RRef, is_rref
- from torch._C import RRefType
- from torch._ops import OpOverloadPacket
class Module:
    """Lightweight namespace object used while evaluating type comments.

    The entries of ``members`` are exposed as attributes, so a dotted name such
    as ``torch.Tensor`` inside a type comment can be resolved against a small,
    explicit table of allowed names.
    """

    def __init__(self, name, members):
        # `name` is only used in the error message below; `members` maps an
        # attribute name to the object it should resolve to.
        self.name = name
        self.members = members

    def __getattr__(self, name):
        """Return ``members[name]``; raise ``RuntimeError`` when it is absent."""
        try:
            member = self.members[name]
        except KeyError:
            # `from None` hides the internal KeyError from the traceback.
            raise RuntimeError(f"Module {self.name} has no member called {name}") from None
        else:
            return member
class EvalEnv:
    """Mapping-like name resolver used when evaluating type-comment strings.

    Lookup order in ``__getitem__``:
      1. the fixed ``env`` table of typing / torch names below,
      2. the resolution callback ``rcb`` (when one was supplied),
      3. the ``builtins`` module (``None`` when the name is unknown there).
    """

    # Shared, class-level table of the names a type comment may refer to.
    env = {
        'torch': Module('torch', {'Tensor': torch.Tensor}),
        'Tensor': torch.Tensor,
        'typing': Module('typing', {'Tuple': Tuple}),
        'Tuple': Tuple,
        'List': List,
        'Dict': Dict,
        'Optional': Optional,
        'Union': Union,
        'Future': Future,
        'Await': _Await,
    }

    def __init__(self, rcb):
        self.rcb = rcb
        # NOTE: `env` is a class attribute, so this registration is visible to
        # every EvalEnv instance, not just the one being constructed.
        if torch.distributed.rpc.is_available():
            self.env['RRef'] = RRef

    def __getitem__(self, name):
        """Resolve ``name`` using the lookup order described on the class."""
        try:
            return self.env[name]
        except KeyError:
            pass
        if self.rcb is None:
            # No callback available: fall back to Python builtins.
            return getattr(builtins, name, None)
        return self.rcb(name)
def get_signature(fn, rcb, loc, is_method):
    """Return ``(param_types, return_type)`` for ``fn``, or ``None``.

    Real (Py3) annotations are tried first. When there are none, the source of
    ``fn`` is searched for a ``# type:`` comment, which is then parsed with
    ``rcb`` as the name-resolution callback.
    """
    target = fn.op if isinstance(fn, OpOverloadPacket) else fn
    signature = try_real_annotations(target, loc)

    if signature is not None:
        if is_method:
            # If this is a method, then the signature will include a type for
            # `self`, but type comments do not contain a `self`. So strip it
            # away here so everything is consistent (`inspect.ismethod` does
            # not work here since `fn` is unbound at this point)
            param_types, return_type = signature
            signature = (param_types[1:], return_type)
        return signature

    # No usable annotations: fall back to a `# type:` comment in the source.
    type_line = None
    try:
        source = dedent(''.join(get_source_lines_and_file(fn)[0]))
        type_line = get_type_line(source)
    except TypeError:
        pass

    # `type_line` may be None both because we failed to get the source of fn,
    # or because it didn't have any annotations.
    if type_line is None:
        return None
    return parse_type_line(type_line, rcb, loc)
def is_function_or_method(the_callable):
    """Return True for Python functions and bound methods only.

    A stricter version of `inspect.isroutine` that does not pass for built-in
    functions.
    """
    checks = (inspect.isfunction, inspect.ismethod)
    return any(check(the_callable) for check in checks)
def is_vararg(the_callable):
    """Return True when ``the_callable`` declares a ``*args`` parameter.

    Objects that are not plain functions/methods are inspected through their
    ``__call__`` attribute; anything that still is not a function or method
    yields ``False``.
    """
    target = the_callable
    if not is_function_or_method(target) and hasattr(target, '__call__'):  # noqa: B004
        # If `the_callable` is a class, de-sugar the call so we can still get
        # the signature
        target = target.__call__

    if not is_function_or_method(target):
        return False
    return inspect.getfullargspec(target).varargs is not None
def get_param_names(fn, n_args):
    """Return the positional parameter names of ``fn``.

    When ``fn`` (after unwrapping an ``OpOverloadPacket`` and de-sugaring
    ``__call__``) is not a Python function or method, the placeholder names
    ``"0" .. str(n_args - 1)`` are returned instead.
    """
    if isinstance(fn, OpOverloadPacket):
        fn = fn.op

    inspectable = is_function_or_method(fn)
    if not inspectable and hasattr(fn, '__call__') and is_function_or_method(fn.__call__):  # noqa: B004
        # De-sugar calls to classes
        fn = fn.__call__
        inspectable = True

    if not inspectable:
        # The `fn` was not a method or function (maybe a class with a __call__
        # method, so use a default param name list)
        return list(map(str, range(n_args)))

    if is_ignored_fn(fn):
        fn = inspect.unwrap(fn)
    return inspect.getfullargspec(fn).args
def check_fn(fn, loc):
    """Verify that the source of ``fn`` is a single top-level function.

    Returns silently when the source cannot be retrieved. Raises
    ``torch.jit.frontend.FrontendError`` (reported at ``loc``) when the source
    is a class definition or is anything other than exactly one ``def``.
    """
    # Make sure the function definition is not a class instantiation
    try:
        source_lines = get_source_lines_and_file(fn)[0]
        source = dedent(''.join(source_lines))
    except (TypeError, IOError):
        return
    if source is None:
        return

    top_level = ast.parse(source).body
    # `sole` is the single top-level statement, or None if there is not
    # exactly one.
    sole = top_level[0] if len(top_level) == 1 else None

    if isinstance(sole, ast.ClassDef):
        raise torch.jit.frontend.FrontendError(
            loc, f"Cannot instantiate class '{sole.name}' in a script function")
    if not isinstance(sole, ast.FunctionDef):
        raise torch.jit.frontend.FrontendError(loc, "Expected a single top-level function")
- def _eval_no_call(stmt, glob, loc):
- """Evaluate statement as long as it does not contain any method/function calls"""
- bytecode = compile(stmt, "", mode="eval")
- for insn in dis.get_instructions(bytecode):
- if "CALL" in insn.opname:
- raise RuntimeError(f"Type annotation should not contain calls, but '{stmt}' does")
- return eval(bytecode, glob, loc) # type: ignore[arg-type] # noqa: P204
def parse_type_line(type_line, rcb, loc):
    """Parses a type annotation specified as a comment.

    Example inputs:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor]
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor

    Returns ``(arg_types, return_type)`` as produced by ``ann_to_type``.
    Raises ``RuntimeError`` when either half of the annotation cannot be
    evaluated.
    """
    arg_ann_str, ret_ann_str = split_type_line(type_line)

    def evaluate(expression, failure_message):
        # Names are resolved through a fresh EvalEnv wrapping `rcb`.
        try:
            return _eval_no_call(expression, {}, EvalEnv(rcb))
        except (NameError, SyntaxError) as e:
            raise RuntimeError(failure_message) from e

    arg_ann = evaluate(arg_ann_str, "Failed to parse the argument list of a type annotation")
    # A single parenthesised argument evaluates to a bare object, not a tuple.
    arg_ann = arg_ann if isinstance(arg_ann, tuple) else (arg_ann,)

    ret_ann = evaluate(ret_ann_str, "Failed to parse the return type of a type annotation")

    arg_types = [ann_to_type(ann, loc) for ann in arg_ann]
    return arg_types, ann_to_type(ret_ann, loc)
def get_type_line(source):
    """Tries to find the line containing a comment with the type annotation.

    Args:
        source: source text of a function, as a single string.

    Returns:
        ``None`` when no ``# type:`` annotation comment exists; otherwise the
        annotation line with surrounding whitespace stripped. For the
        multi-line PEP 484 form (one ``# type: T`` comment per parameter plus
        a ``# type: (...) -> R`` line) the per-parameter types are spliced
        into the ``(...)`` placeholder of the return line.

    Raises:
        RuntimeError: if a comment looks like a mistyped ``# type:`` prefix,
            or if several type comments exist but none of them is the
            ``# type: (...) -> `` return line.
    """
    type_comment = '# type:'
    return_marker = '# type: (...) -> '

    lines = list(enumerate(source.split('\n')))

    # `type: ignore` comments may be needed in JIT'ed functions for mypy, due
    # to the hack in torch/_VF.py.

    # An ignore type comment can be of following format:
    #   1) type: ignore
    #   2) type: ignore[rule-code]
    # This ignore statement must be at the end of the line

    # adding an extra backslash before the space, to avoid triggering
    # one of the checks in .github/workflows/lint.yml
    type_pattern = re.compile("# type:\\ ignore(\\[[a-zA-Z-]+\\])?$")
    type_lines = [
        entry for entry in lines
        if type_comment in entry[1] and not type_pattern.search(entry[1])
    ]

    if not type_lines:
        # Catch common typo patterns like extra spaces, typo in 'ignore', etc.
        wrong_type_pattern = re.compile("#[\t ]*type[\t ]*(?!: ignore(\\[.*\\])?$):")
        wrong_type_lines = [entry for entry in lines if wrong_type_pattern.search(entry[1])]
        if wrong_type_lines:
            raise RuntimeError("The annotation prefix in line " + str(wrong_type_lines[0][0])
                               + " is probably invalid.\nIt must be '# type:'"
                               + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"  # noqa: B950
                               + "\nfor examples")
        return None

    if len(type_lines) == 1:
        # Only 1 type line, quit now
        return type_lines[0][1].strip()

    # Parse split up argument types according to PEP 484
    # https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code
    return_line = None
    parameter_type_lines = []
    for line_num, line in type_lines:
        if return_marker in line:
            return_line = (line_num, line)
            break
        # Every remaining entry contains `# type:` (it passed the filter
        # above), so it is a per-parameter annotation.
        parameter_type_lines.append(line)

    if return_line is None:
        raise RuntimeError(
            "Return type line '# type: (...) -> ...' not found on multiline "
            "type annotation\nfor type lines:\n" +
            '\n'.join([line[1] for line in type_lines]) +
            "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)")

    def get_parameter_type(line):
        item_type = line[line.find(type_comment) + len(type_comment):]
        return item_type.strip()

    parameter_types = ", ".join(get_parameter_type(line) for line in parameter_type_lines)

    # Only the `(...)` placeholder is substituted, so an ellipsis inside the
    # return type (e.g. `Tuple[int, ...]`) is left alone. The result is
    # stripped like the single-line case above, because the return line sits
    # inside the function body and therefore carries leading indentation.
    merged = return_line[1].replace(return_marker, f"# type: ({parameter_types}) -> ", 1)
    return merged.strip()
def split_type_line(type_line):
    """Splits the comment with the type annotation into parts for argument and return types.

    For example, for an input of:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]

    This function will return:
        ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]")

    The ``# type:`` marker is located inside ``type_line``, so leading
    indentation or code preceding the comment does not end up in the argument
    string.

    Raises:
        RuntimeError: if ``type_line`` contains no ``->``.
    """
    marker = '# type:'
    marker_pos = type_line.find(marker)
    if marker_pos == -1:
        # No marker found: keep the historical behaviour of assuming the
        # annotation prefix occupies the first len(marker) characters.
        marker_pos = 0
    start_offset = marker_pos + len(marker)

    try:
        # Search from the marker onwards so a `->` in code preceding the
        # comment is not mistaken for the annotation arrow.
        arrow_pos = type_line.index('->', marker_pos)
    except ValueError:
        raise RuntimeError("Syntax error in type annotation (cound't find `->`)") from None

    return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2:].strip()
def try_real_annotations(fn, loc):
    """Tries to use the Py3.5+ annotation syntax to get the type.

    Returns ``(arg_types, return_type)``, or ``None`` when ``inspect`` cannot
    produce a signature for ``fn`` or when nothing in the signature is
    annotated.
    """
    try:
        # Note: anything annotated as `Optional[T]` will automatically
        # be returned as `Union[T, None]` per
        # https://github.com/python/typing/blob/master/src/typing.py#L850
        sig = inspect.signature(fn)
    except ValueError:
        return None

    params = list(sig.parameters.values())
    annotations = [sig.return_annotation]
    annotations.extend(p.annotation for p in params)
    if not any(ann is not sig.empty for ann in annotations):
        # Completely unannotated: let the caller fall back to type comments.
        return None

    arg_types = [ann_to_type(p.annotation, loc) for p in params]
    return arg_types, ann_to_type(sig.return_annotation, loc)
def get_enum_value_type(e: Type[enum.Enum], loc):
    """Find the common IR type of the values of the Enum class ``e``.

    Args:
        e: the Enum class whose member values are inspected.
        loc: source location forwarded to ``try_ann_to_type``.

    Returns:
        The result of ``torch._C.unify_type_list`` over the IR types of the
        member values, or ``AnyType`` when that result is falsy (i.e. the
        value types could not be unified into a single type).

    Raises:
        ValueError: if ``e`` defines no members.
    """
    enum_values: List[enum.Enum] = list(e)
    if not enum_values:
        # Report the enum class itself; `e.__class__` would only name the
        # metaclass and give no hint about which enum is empty.
        raise ValueError(f"No enum values defined for: '{e}'")

    types = {type(v.value) for v in enum_values}
    ir_types = [try_ann_to_type(t, loc) for t in types]

    # Enum values of different Python types are legal, but mixed-type enums
    # are deliberately not modelled precisely to avoid overcomplicating the
    # logic for a rare use case: when unification yields nothing we fall back
    # to AnyType. Please report a feature request if you find it necessary.
    res = torch._C.unify_type_list(ir_types)
    if not res:
        return AnyType.get()
    return res
def is_tensor(ann):
    """Return True when the class ``ann`` should be treated as a Tensor type.

    Subclasses of ``torch.Tensor`` qualify directly. The legacy dtype-specific
    tensor classes also qualify, but a warning is emitted first because they
    are handled like plain Tensors.
    """
    if issubclass(ann, torch.Tensor):
        return True

    dtype_specific = (
        torch.LongTensor, torch.DoubleTensor, torch.FloatTensor,
        torch.IntTensor, torch.ShortTensor, torch.HalfTensor,
        torch.CharTensor, torch.ByteTensor, torch.BoolTensor,
    )
    if not issubclass(ann, dtype_specific):
        return False

    warnings.warn("TorchScript will treat type annotations of Tensor "
                  "dtype-specific subtypes as if they are normal Tensors. "
                  "dtype constraints are not enforced in compilation either.")
    return True
def try_ann_to_type(ann, loc):
    """Translate the Python annotation ``ann`` into a TorchScript IR type.

    The checks run in a fixed order: missing annotation, ``None``, tensors,
    the ``typing`` containers (Tuple, List, Dict, Optional, Union, RRef,
    Future, Await), scalar/builtin types, interfaces, device/stream/dtype,
    enums, and finally arbitrary classes (already scripted, compilable, or
    resolvable from the object).

    Args:
        ann: the annotation object to translate.
        loc: source location, used in error messages and forwarded to the
            class compiler / type resolver.

    Returns:
        The IR type. For annotations that are not classes and match none of
        the cases, the function falls off the end and returns ``None``.

    Raises:
        ValueError: if the key or value type of a ``Dict`` is unknown.
        AssertionError: if the contained type of an ``Optional`` or a member
            of a ``Union`` cannot be resolved.
    """
    if ann is inspect.Signature.empty:
        return TensorType.getInferred()
    if ann is None:
        return NoneType.get()
    if inspect.isclass(ann) and is_tensor(ann):
        return TensorType.get()
    if is_tuple(ann):
        # Special case for the empty Tuple type annotation `Tuple[()]`
        if len(ann.__args__) == 1 and ann.__args__[0] == ():
            return TupleType([])
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        # An unresolved element type deliberately falls through to the
        # remaining checks instead of returning here.
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        # Raise error if key or value is None
        if key is None:
            raise ValueError(f"Unknown type annotation: '{ann.__args__[0]}' at {loc.highlight()}")
        if value is None:
            raise ValueError(f"Unknown type annotation: '{ann.__args__[1]}' at {loc.highlight()}")
        return DictType(key, value)
    if is_optional(ann):
        # Identity comparison instead of `issubclass`: the other argument may
        # be a generic alias such as `List[int]`, which is not a class and
        # makes `issubclass` raise TypeError (e.g. `Union[None, List[int]]`).
        second = ann.__args__[1]
        if second is None or second is type(None):
            contained = ann.__args__[0]
        else:
            contained = second
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
        assert valid_type, msg.format(repr(ann), repr(contained))
        return OptionalType(valid_type)
    if is_union(ann):
        # TODO: this is hack to recognize NumberType
        if set(ann.__args__) == {int, float, complex}:
            return NumberType.get()
        inner: List = []
        # We need these extra checks because both `None` and invalid
        # values will return `None`
        # TODO: Determine if the other cases need to be fixed as well
        for a in ann.__args__:
            if a is None:
                inner.append(NoneType.get())
                # Already handled; without this the member would be resolved
                # and appended a second time below.
                continue
            maybe_type = try_ann_to_type(a, loc)
            msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
            # Report the annotation that failed (`a`), not the falsy result.
            assert maybe_type, msg.format(repr(ann), repr(a))
            inner.append(maybe_type)
        return UnionType(inner)    # type: ignore[arg-type]
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if is_await(ann):
        elementType = try_ann_to_type(ann.__args__[0], loc) if hasattr(ann, "__args__") else AnyType.get()
        return AwaitType(elementType)
    if ann is float:
        return FloatType.get()
    if ann is complex:
        return ComplexType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(ann.__torch_script_interface__)
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.Stream:
        return StreamObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        if _get_script_class(ann) is None:
            scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
            name = scripted_class.qualified_name()
        else:
            name = _qualified_name(ann)
        return EnumType(name, get_enum_value_type(ann, loc), list(ann))
    if inspect.isclass(ann):
        maybe_script_class = _get_script_class(ann)
        if maybe_script_class is not None:
            return maybe_script_class
        if torch._jit_internal.can_compile_class(ann):
            return torch.jit._script._recursive_compile_class(ann, loc)

        # Maybe resolve a NamedTuple to a Tuple Type
        def fake_rcb(key):
            return None
        return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
def ann_to_type(ann, loc):
    """Strict variant of ``try_ann_to_type``.

    Returns the IR type for ``ann``; raises ``ValueError`` (with the
    highlighted source location) when the annotation cannot be translated.
    """
    resolved = try_ann_to_type(ann, loc)
    if resolved is None:
        raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")
    return resolved
# Public names of this module for `from ... import *`: first the typing
# helpers and IR type classes that are imported at the top of the file and
# re-exported here, then `Module` and the annotation-parsing functions that
# are defined in this file.
__all__ = [
    'Any',
    'List',
    'BroadcastingList1',
    'BroadcastingList2',
    'BroadcastingList3',
    'Tuple',
    'is_tuple',
    'is_list',
    'Dict',
    'is_dict',
    'is_optional',
    'is_union',
    'TensorType',
    'TupleType',
    'FloatType',
    'ComplexType',
    'IntType',
    'ListType',
    'StringType',
    'DictType',
    'AnyType',
    'Module',
    # TODO: Consider not exporting these during wildcard import (reserve
    # that for the types; for idiomatic typing code.)
    'get_signature',
    'check_fn',
    'get_param_names',
    'parse_type_line',
    'get_type_line',
    'split_type_line',
    'try_real_annotations',
    'try_ann_to_type',
    'ann_to_type',
]
|