- import torch
- import inspect
- import numbers
- import types
- import typing
- import enum
- import warnings
- from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING
- from torch._jit_internal import boolean_dispatched
- from ._compatibility import compatibility
- from torch._ops import OpOverloadPacket, OpOverload
- if TYPE_CHECKING:
- from .node import Argument
- __all__ = ["ArgsKwargsPair", "check_for_mutable_operation", "get_signature_for_torch_op", "create_type_hint",
- "type_matches", "normalize_function", "normalize_module"]
- @compatibility(is_backward_compatible=False)
- class ArgsKwargsPair(NamedTuple):
- """
- Simple named tuple for wrapping args/kwargs pairs.
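- Example (illustrative)::
- >>> pair = ArgsKwargsPair(args=(1, 2), kwargs={'alpha': 0.5})
- >>> pair.kwargs['alpha']
- 0.5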
- """
- args: Tuple[Any, ...]
- kwargs: Dict[str, Any]
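- # Signatures registered by hand for ops that need special handling; consulted
- # before the JIT builtin lookup in `get_signature_for_torch_op` below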
- _manual_overrides : Dict[Callable, List[inspect.Signature]] = {}
- def _nonzero_schemas():
- signatures = []
- def nonzero(self):
- pass
- signatures.append(inspect.signature(nonzero))
- def nonzero(self, *, as_tuple : bool): # type: ignore[no-redef]
- pass
- signatures.append(inspect.signature(nonzero))
- return signatures
- _manual_overrides[torch.nonzero] = _nonzero_schemas()
- class _FakeGlobalNamespace:
- def __getattr__(self, name):
- if name == 'torch':
- return torch
- raise RuntimeError('Expected a torch namespace lookup')
- _type_eval_globals = {'Tensor' : torch.Tensor, 'Device' : torch.device, 'Layout' : torch.layout,
- 'number' : numbers.Number, 'Future' : torch.jit.Future,
- 'AnyEnumType' : enum.Enum, 'QScheme' : torch.qscheme,
- '__torch__': _FakeGlobalNamespace(), 'NoneType': type(None),
- 't': typing.TypeVar('t')}
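- # Pull in the names from `typing` (List, Optional, Union, ...) so that eval'd
- # annotation strings can reference them directly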
- for k in dir(typing):
- _type_eval_globals[k] = getattr(typing, k)
- def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any:
- """
- Convert a TorchScript type to a Python type (including subtypes) by
- eval'ing the annotation_str. _type_eval_globals sets up expressions
- like "List" and "Future" to map to actual types (typing.List and jit.Future).
- """
- return eval(ts_type.annotation_str, _type_eval_globals)
- def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Signature:
- from inspect import Parameter
- parameters : List[Parameter] = []
- for arg in ts_schema.arguments:
- arg_type = _torchscript_type_to_python_type(arg.type)
- default = arg.default_value if arg.has_default_value() else Parameter.empty
- # TODO: Figure out if this is safe. It seems like when generating the type signatures for
- # PythonArgParser, we emit signatures with `input` instead of `self` as the first tensor
- # argument name. Downstream, if someone converts that positional argument to a keyword
- # argument, the name mismatch will break things, so here we're going to normalize the
- # name to "input"
- name = arg.name if arg.name != 'self' else 'input'
- kind = Parameter.KEYWORD_ONLY if arg.kwarg_only else Parameter.POSITIONAL_OR_KEYWORD
- # "from" is a keyword therefore it must be a POSITIONAL_ONLY argument
- if name == "from":
- assert kind == Parameter.POSITIONAL_OR_KEYWORD
- # The ParameterKind type is an internal implementation detail of the
- # inspect package, which makes it hard to annotate this assignment
- kind = Parameter.POSITIONAL_ONLY # type: ignore[assignment]
- # This makes all previous arguments positional-only as well
- for idx, p in enumerate(parameters):
- assert p.kind == Parameter.POSITIONAL_OR_KEYWORD
- parameters[idx] = Parameter(name=p.name, kind=Parameter.POSITIONAL_ONLY, default=p.default, annotation=p.annotation)
- parameters.append(Parameter(name=name, kind=kind, default=default, annotation=arg_type))
- return_types = [_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns]
- if len(return_types) == 0:
- return_type = None
- elif len(return_types) == 1:
- return_type = return_types[0]
- else:
- return_type = tuple(return_types)
- return inspect.Signature(parameters, return_annotation=return_type)
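- # Example (a sketch; assumes `aten::add.Tensor` has its usual schema). Note
- # that `self` is renamed to `input` in the resulting signature:
- #
- # >>> schema = torch._C.parse_schema(
- # ... "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
- # >>> sig = _torchscript_schema_to_signature(schema)
- # >>> list(sig.parameters)
- # ['input', 'other', 'alpha']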
- @compatibility(is_backward_compatible=False)
- def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']):
- signatures, schemas = get_signature_for_torch_op(target, return_schemas=True)
- if signatures and schemas:
- matched_schemas = []
- # Iterate through all of the schemas and record each one whose signature
- # the given args/kwargs successfully bind to. The mutability check below
- # is only performed if exactly one schema matches
- for candidate_signature, schema in zip(signatures, schemas):
- try:
- candidate_signature.bind(*args, **kwargs)
- matched_schemas.append((candidate_signature, schema))
- except TypeError:
- continue
- def throw_if_mutable(schema):
- if schema.is_mutable:
- raise RuntimeError(f'Tried to trace mutable operation {schema}. FX only supports functional '
- f'code, so operations that mutate operands in-place (e.g. via `out` arguments) '
- f'are not supported')
- if len(matched_schemas) == 0:
- # Did not match any schema. Cannot check for mutation
- pass
- elif len(matched_schemas) == 1:
- # Matched exactly one schema, unambiguous
- _, schema_to_check = matched_schemas[0]
- throw_if_mutable(schema_to_check)
- else:
- # Ambiguous schema match. Since mutability checking is best effort,
- # do nothing.
- pass
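- # Example (a sketch; assumes the mutable `aten::add.out` overload is the only
- # schema that binds when an `out=` kwarg is passed, in which case the call
- # below raises a RuntimeError):
- #
- # >>> x, y = torch.randn(3), torch.randn(3)
- # >>> check_for_mutable_operation(torch.add, (x, y), {'out': x})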
- @compatibility(is_backward_compatible=False)
- def get_signature_for_torch_op(op : Callable, return_schemas : bool = False):
- """
- Given an operator on the `torch` namespace, return a list of `inspect.Signature`
- objects corresponding to the overloads of that op. May return `None` if a signature
- could not be retrieved.
- Args:
- op (Callable): An operator on the `torch` namespace to look up a signature for
- Returns:
- Optional[List[inspect.Signature]]: A list of signatures for the overloads of this
- operator, or None if the operator signatures could not be retrieved. If
- return_schemas=True, returns a tuple containing the optional Python signatures
- and the optional list of TorchScript `FunctionSchema` objects
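- Example (illustrative; the exact overload count varies across releases)::
- >>> sigs = get_signature_for_torch_op(torch.add)
- >>> all(isinstance(s, inspect.Signature) for s in sigs)
- True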
- """
- if isinstance(op, OpOverload):
- schemas = [op._schema]
- elif isinstance(op, OpOverloadPacket):
- schemas = [getattr(op, overload)._schema for overload in op.overloads()]
- else:
- override = _manual_overrides.get(op)
- if override:
- return (override, None) if return_schemas else None
- aten_fn = torch.jit._builtins._find_builtin(op)
- if aten_fn is None:
- return (None, None) if return_schemas else None
- schemas = torch._C._jit_get_schemas_for_operator(aten_fn)
- signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
- return (signatures, schemas) if return_schemas else signatures
- @compatibility(is_backward_compatible=False)
- def create_type_hint(x):
- try:
- if isinstance(x, (list, tuple)):
- # todo(chilli): Figure out the right way for mypy to handle this
- if isinstance(x, list):
- def ret_type(x):
- return List[x] # type: ignore[valid-type]
- else:
- def ret_type(x):
- return Tuple[x, ...]
- if len(x) == 0:
- return ret_type(Any)
- base_type = x[0]
- for t in x:
- if issubclass(t, base_type):
- continue
- elif issubclass(base_type, t):
- base_type = t
- else:
- return ret_type(Any)
- return ret_type(base_type)
- except Exception:
- # We tried to create a type hint for the container but failed
- warnings.warn(f"Unable to create a type hint from the type {x}")
- return x
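- # Examples (illustrative; note that `create_type_hint` expects a container
- # of *types*, not of values):
- #
- # >>> create_type_hint([int, int])
- # typing.List[int]
- # >>> create_type_hint((int, bool)) # bool is a subclass of int
- # typing.Tuple[int, ...]
- # >>> create_type_hint([int, str]) # no common base type, fall back to Any
- # typing.List[typing.Any]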
- @compatibility(is_backward_compatible=False)
- def type_matches(signature_type : Any, argument_type : Any):
- sig_origin_type = getattr(signature_type, '__origin__', signature_type)
- if signature_type is argument_type:
- return True
- # Union types in signature: the given type needs to match one of the
- # types contained in the Union
- if sig_origin_type is typing.Union and signature_type != argument_type:
- sig_contained = signature_type.__args__
- return any(type_matches(c, argument_type) for c in sig_contained)
- if signature_type is List[int] and argument_type is int:
- # int can be promoted to List[int]
- return True
- if getattr(signature_type, '__origin__', None) in {list, List}:
- sig_el_type = signature_type.__args__[0]
- if not inspect.isclass(sig_el_type):
- warnings.warn(
- f"Does not support nested parametric types, got {signature_type}. Please file a bug.")
- return False
- if getattr(argument_type, '__origin__', None) in {list, List}:
- return issubclass(argument_type.__args__[0], sig_el_type)
- def is_homogeneous_tuple(t):
- if getattr(t, '__origin__', None) not in {tuple, Tuple}:
- return False
- contained = t.__args__
- if t.__args__ == ((),): # Tuple[()].__args__ == ((),) for some reason
- return True
- return all((c is Ellipsis) or issubclass(c, sig_el_type) for c in contained)
- # Tuple[T] is accepted for List[T] parameters
- return is_homogeneous_tuple(argument_type)
- # Dtype is an int in schemas
- if signature_type is int and argument_type is torch.dtype:
- return True
- if signature_type is numbers.Number and argument_type in {int, float}:
- return True
- if inspect.isclass(argument_type) and inspect.isclass(signature_type):
- return issubclass(argument_type, signature_type)
- return False
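- # Examples (illustrative):
- #
- # >>> type_matches(List[int], int) # int is promoted to List[int]
- # True
- # >>> type_matches(typing.Union[int, float], float)
- # True
- # >>> type_matches(int, torch.dtype) # dtypes appear as ints in schemas
- # True
- # >>> type_matches(List[int], Tuple[int, ...]) # homogeneous tuples match lists
- # True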
- @compatibility(is_backward_compatible=False)
- def normalize_function(
- target: Callable, args: Tuple[Any, ...], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any, ...]] = None,
- kwarg_types : Optional[Dict[str, Any]] = None,
- normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
- """
- Returns normalized arguments to PyTorch functions. This means that
- `args/kwargs` will be matched up to the function's signature, and
- exclusively kwargs (in positional order) will be returned if
- `normalize_to_only_use_kwargs` is True.
- Also populates default values. Does not support positional-only
- parameters or varargs parameters (*args, **kwargs). Does not support modules.
- May require `arg_types` and `kwarg_types` in order to disambiguate overloads.
- Args:
- target (Callable): Function that we are normalizing
- args (Tuple[Any]): Tuple of args to the function
- kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
- arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
- kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
- normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
- Returns:
- Returns normalized_args_and_kwargs, or `None` if not successful.
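- Example (a sketch; `F.relu` is a plain Python function, so its signature
- is inspected directly, while the builtin `torch.add` needs `arg_types`
- to disambiguate its overloads)::
- >>> import torch.nn.functional as F
- >>> x = torch.randn(3)
- >>> pair = normalize_function(F.relu, (x,), normalize_to_only_use_kwargs=True)
- >>> list(pair.kwargs)
- ['input', 'inplace']
- >>> pair = normalize_function(torch.add, (x, x), arg_types=(torch.Tensor, torch.Tensor),
- ... normalize_to_only_use_kwargs=True)
- >>> list(pair.kwargs)
- ['input', 'other', 'alpha']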
- """
- if kwargs is None:
- kwargs = {}
- new_args_and_kwargs = None
- if not isinstance(target, types.BuiltinFunctionType) and not (
- isinstance(target, (OpOverloadPacket, OpOverload))
- ):
- target_for_analysis = target
- if target in boolean_dispatched:
- # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
- # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
- # branches of the dispatch have exactly the same signature. If they do, use the `true`
- # branch signature for analysis. Otherwise, leave this un-normalized
- assert not isinstance(target, str)
- dispatched = boolean_dispatched[target]
- if_true, if_false = dispatched['if_true'], dispatched['if_false']
- if inspect.signature(if_true).parameters != inspect.signature(if_false).parameters:
- return None
- target_for_analysis = if_true
- assert callable(target_for_analysis)
- sig = inspect.signature(inspect.unwrap(target_for_analysis))
- new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs)
- else:
- assert callable(target)
- torch_op_schemas = get_signature_for_torch_op(target)
- matched_schemas = []
- if torch_op_schemas:
- # Iterate through all of the schema until we find one that matches
- # If one matches, populate `new_args_and_kwargs` with the new args/kwargs
- # values. If none matches, `new_args_and_kwargs` will be None
- for candidate_signature in torch_op_schemas:
- try:
- candidate_signature.bind(*args, **kwargs)
- matched_schemas.append(candidate_signature)
- except TypeError:
- continue
- if len(matched_schemas) == 0:
- # Did not match any schema. Cannot normalize
- pass
- elif len(matched_schemas) == 1:
- # Matched exactly one schema, unambiguous
- new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(matched_schemas[0], args, kwargs,
- normalize_to_only_use_kwargs)
- else:
- if arg_types is not None or kwarg_types is not None:
- arg_types = arg_types if arg_types else cast(Tuple[Any], ())
- kwarg_types = kwarg_types if kwarg_types else {}
- for candidate_signature in torch_op_schemas:
- sig_matches = True
- try:
- bound_types = candidate_signature.bind(*arg_types, **kwarg_types)
- for arg_name, arg_type in bound_types.arguments.items():
- param = candidate_signature.parameters[arg_name]
- sig_matches = sig_matches and type_matches(param.annotation, arg_type)
- except TypeError:
- sig_matches = False
- if sig_matches:
- new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs,
- normalize_to_only_use_kwargs)
- break
- else:
- # Matched more than one schema. In this situation, the caller must provide the types of
- # the arguments of the overload they expect.
- schema_printouts = '\n'.join(str(schema) for schema in matched_schemas)
- raise RuntimeError(f'Tried to normalize arguments to {torch.typename(target)} but '
- f'the schema match was ambiguous! Please provide argument types to '
- f'the normalize_arguments() call. Available schemas:\n{schema_printouts}')
- return new_args_and_kwargs
- @compatibility(is_backward_compatible=False)
- def normalize_module(
- root: torch.nn.Module, target: str, args: Tuple[Any, ...], kwargs : Optional[Dict[str, Any]] = None,
- normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
- """
- Returns normalized arguments to PyTorch modules. This means that
- `args/kwargs` will be matched up to the module's forward() signature, and
- exclusively kwargs (in positional order) will be returned if
- `normalize_to_only_use_kwargs` is True.
- Also populates default values. Does not support positional-only
- parameters or varargs parameters (*args, **kwargs).
- Args:
- root (nn.Module): root module upon which we query modules
- target (str): Qualified path of the submodule whose arguments we are normalizing
- args (Tuple[Any]): Tuple of args to the function
- kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
- normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
- Returns:
- Returns normalized_args_and_kwargs, or `None` if not successful.
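- Example (a sketch using an `nn.Sequential` with a single `Linear` submodule)::
- >>> root = torch.nn.Sequential(torch.nn.Linear(4, 4))
- >>> pair = normalize_module(root, '0', (torch.randn(2, 4),),
- ... normalize_to_only_use_kwargs=True)
- >>> list(pair.kwargs)
- ['input']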
- """
- try:
- submod = root.get_submodule(target)
- except AttributeError as e:
- raise RuntimeError(f"Tried to normalize node with target {target} but root did not "
- f"have that target!") from e
- if hasattr(submod.__class__, '__name__'):
- classname = submod.__class__.__name__
- if getattr(torch.nn, classname, None) == submod.__class__:
- sig = inspect.signature(inspect.unwrap(submod.forward))
- if kwargs is None:
- kwargs = {}
- new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs,
- normalize_to_only_use_kwargs)
- return new_args_and_kwargs
- return None
- def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple[Any, ...],
- kwargs : Dict[str, Any],
- normalize_to_only_use_kwargs : bool) -> Optional[ArgsKwargsPair]:
- """
- Given a call target's signature, args, and kwargs, return the arguments normalized into
- an ArgsKwargsPair, or None if the type signature is not supported by
- this normalization.
- Args:
- sig (inspect.Signature): Signature object for the target
- args (Tuple): Arguments that appear at the callsite for `target`
- kwargs (Dict): Keyword arguments that appear at the callsite for `target`
- normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
- Returns:
- Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if
- this target is not supported.
- """
- # Don't currently support positional-only
- # or varargs (*args, **kwargs) signatures
- supported_parameter_types = {
- inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY}
- if any(p.kind not in supported_parameter_types for p in sig.parameters.values()):
- # Add an exception for one signature, which is common for random/uniform, i.e.:
- # Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None
- # `from` is a Python keyword, and as such functions with that signature should have
- # positional-only args, but at the same time they could be dispatched as kwargs
- if list(sig.parameters.keys()) != ['input', 'from', 'to', 'generator']:
- return None
- bound_args = sig.bind(*args, **kwargs)
- bound_args.apply_defaults()
- new_kwargs : Dict[str, Any] = {}
- new_args : List[Any] = []
- for i, param in enumerate(sig.parameters):
- if not normalize_to_only_use_kwargs and i < len(args):
- new_args.append(bound_args.arguments[param])
- else:
- new_kwargs[param] = bound_args.arguments[param]
- return ArgsKwargsPair(tuple(new_args), new_kwargs)
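- # Example (illustrative; uses a hand-built signature):
- #
- # >>> def f(a, b=2, *, c=3):
- # ... pass
- # >>> sig = inspect.signature(f)
- # >>> _args_kwargs_to_normalized_args_kwargs(sig, (1,), {'c': 4},
- # ... normalize_to_only_use_kwargs=False)
- # ArgsKwargsPair(args=(1,), kwargs={'b': 2, 'c': 4})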
|