123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162 |
- import operator
- from typing import Any, Callable, Dict, Tuple, Optional
- import torch
- import torch.fx
- import torch.fx as fx
- from torch.fx import Transformer, Proxy
- from torch.fx.node import Argument, Target, Node, map_aggregate
- from torch.fx.operator_schemas import (
- normalize_module,
- normalize_function,
- create_type_hint,
- )
- from .schema_type_annotation import AnnotateTypesWithSchema
class NormalizeArgs(Transformer):
    """
    Normalize arguments to Python targets. This means that
    `args/kwargs` will be matched up to the module/functional's
    signature and rewritten to exclusively kwargs in positional order
    if `normalize_to_only_use_kwargs` is true. Also populates default
    values. Does not support positional-only parameters or varargs
    parameters (*args, **kwargs).

    If the nodes have 'type' metadata, it will be used to disambiguate
    overloads. When the available types do not select an overload, the
    call is left un-normalized.

    Example usage:

        m = torchvision.models.resnet18()
        traced = torch.fx.symbolic_trace(m)
        traced = NormalizeArgs(traced).transform()
    """

    def __init__(
        self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True
    ):
        """
        Args:
            module: The traced module whose call sites should be normalized.
            normalize_to_only_use_kwargs: If true, rewrite every normalized
                call so that all arguments are passed by keyword.
        """
        super().__init__(module)
        # Maps each proxy emitted into the new graph back to the original node.
        self.node_map: Dict[Proxy, Node] = {}
        self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs

    def run_node(self, n: Node) -> Any:
        """
        Re-emit node ``n`` into the new graph, normalizing it when possible.

        ``call_function`` nodes are routed through :meth:`call_function`
        together with the type information gathered from the original
        node's arguments; every other op is delegated to the base class.
        For all nodes except ``output`` the original ``meta`` and ``type``
        are carried over to the new node.
        """
        args, kwargs = self.fetch_args_kwargs_from_env(n)

        def get_type(arg):
            if isinstance(arg, fx.Node):
                # Prefer the type recorded on the argument itself: overload
                # resolution needs the argument's type, not the result's.
                if "type" in arg.meta:
                    return arg.meta["type"]
                # Fall back to the consuming node's annotation so graphs
                # where only the call node was annotated behave as before.
                return n.meta.get("type")
            return type(arg)

        arg_types = map_aggregate(n.args, get_type)
        assert isinstance(arg_types, tuple)
        arg_types = tuple(create_type_hint(i) for i in arg_types)
        # Read types from the original node's kwargs: the env-resolved
        # `kwargs` hold Proxy objects, whose Python type says nothing about
        # the value they stand for.
        kwarg_types = {k: get_type(v) for k, v in n.kwargs.items()}

        if n.op == "call_function":
            out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types)
        else:
            out = super().run_node(n)

        if n.op != "output":
            self.node_map[out] = n
            out.node.meta = n.meta
            out.node.type = n.type
        return out

    def call_function(
        self,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Any],
        arg_types: Optional[Tuple[Any, ...]] = None,
        kwarg_types: Optional[Dict[str, Any]] = None,
    ):
        """
        Emit a ``call_function`` node whose arguments have been normalized
        against ``target``'s signature.

        ``arg_types`` / ``kwarg_types`` are forwarded to
        ``normalize_function``. If normalization yields nothing, the call
        is emitted unchanged through the base class.
        """
        assert callable(target)
        new_args_and_kwargs = normalize_function(
            target,
            args,  # type: ignore[arg-type]
            kwargs,
            arg_types,  # type: ignore[arg-type]
            kwarg_types,
            self.normalize_to_only_use_kwargs,
        )
        if new_args_and_kwargs:
            new_args, new_kwargs = new_args_and_kwargs
            return self.tracer.create_proxy(
                "call_function", target, new_args, new_kwargs
            )
        return super().call_function(target, args, kwargs)

    def call_module(
        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ):
        """
        Emit a ``call_module`` node whose arguments have been normalized
        against the submodule named by ``target``.

        If ``normalize_module`` yields nothing, the call is emitted with
        its original arguments.
        """
        assert isinstance(target, str)
        new_args_and_kwargs = normalize_module(
            self.module,
            target,
            args,  # type: ignore[arg-type]
            kwargs,
            self.normalize_to_only_use_kwargs,
        )
        if new_args_and_kwargs:
            new_args, new_kwargs = new_args_and_kwargs
            return super().call_module(target, new_args, new_kwargs)
        return super().call_module(target, args, kwargs)
class NormalizeOperators(AnnotateTypesWithSchema):
    """
    Normalize callsites that are different ways of "spelling" the same
    invocation into a single, canonical call. Currently supports:

    1. Rewriting calls to binary `torch` functions (e.g. `torch.add`) into
       the Python operator they are equivalent to (e.g. `operator.add`),
       when the call uses exactly two positional arguments and no keyword
       arguments, so the two spellings are interchangeable.

    Example usage:

        m = torchvision.models.resnet18()
        traced = torch.fx.symbolic_trace(m)
        traced = NormalizeOperators(traced).transform()
    """

    # torch function -> Python operator that replaces it in the graph.
    binary_magic_method_remap: Dict[
        Callable[[Any, Any], Any], Callable[[Any, Any], Any]
    ] = {
        torch.add: operator.add,
        torch.mul: operator.mul,
        torch.sub: operator.sub,
        torch.div: operator.truediv,
        torch.floor_divide: operator.floordiv,
        torch.remainder: operator.mod,
        torch.eq: operator.eq,
        torch.ne: operator.ne,
        torch.lt: operator.lt,
        torch.le: operator.le,
        torch.gt: operator.gt,
        torch.ge: operator.ge,
    }

    def call_function(
        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ):
        """
        Emit a ``call_function`` node, replacing ``target`` with its
        operator equivalent from ``binary_magic_method_remap`` when the
        call is a plain two-argument invocation.

        Calls with a different number of positional arguments, or with any
        keyword arguments, are emitted unchanged.
        """
        # Normalize operators according to the magic methods implemented on tensors here:
        # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950
        assert callable(target)

        # Keyword arguments (e.g. `alpha=`, `rounding_mode=`, `out=`) have no
        # counterpart on the Python operators; remapping such a call would
        # silently drop them and change the computed result.
        remappable = (
            target in self.binary_magic_method_remap
            and len(args) == 2
            and not kwargs
        )
        if not remappable:
            return super().call_function(target, args, kwargs)

        lhs, rhs = args
        return super().call_function(
            target=self.binary_magic_method_remap[target],
            args=(lhs, rhs),
            kwargs={},
        )
|