import builtins
import copy
import functools
import inspect
import math
import os
import warnings
import collections
from itertools import chain
from types import CodeType, FunctionType, ModuleType
from typing import (
    Any,
    Callable,
    Dict,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
)

import torch
import torch.utils._pytree as pytree
from torch._C import ScriptObject  # type: ignore[attr-defined]

from ._compatibility import compatibility
from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph
from .graph_module import GraphModule
from .node import Argument, base_types, map_aggregate
from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager

HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS

# These need to run in global scope to handle nested calls correctly
_orig_module_call: Callable = torch.nn.Module.__call__
_orig_module_getattr: Callable = torch.nn.Module.__getattr__

_proxyable_classes: Dict[Type, None] = {}

_is_fx_tracing_flag = False


def is_fx_tracing():
    return _is_fx_tracing_flag
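
# Example (illustrative sketch, not part of this module's API surface): user
# code can consult ``is_fx_tracing()`` to special-case behavior under tracing,
# e.g. to skip a data-dependent assert that FX cannot trace through. ``M`` is
# a hypothetical module used only for illustration:
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             if not is_fx_tracing():
#                 assert x.shape[0] > 0, "expected a non-empty batch"
#             return x * 2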


@compatibility(is_backward_compatible=True)
class ProxyableClassMeta(type):
    """
    ProxyableClassMeta allows you to make construction of a given Python class
    symbolically traceable. For example::

        import torch
        import torch.fx

        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, left, right):
                self.left, self.right = left, right

            def add(self, other):
                l = self.left + other.left
                r = self.right + other.right
                return TensorPair(l, r)

            def mul(self, other):
                l = self.left * other.left
                r = self.right * other.right
                return TensorPair(l, r)

        def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
            s = x.add(TensorPair(y, y))
            return s.mul(x)

        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
        y = torch.randn(5, 3)
        ref_out = use_tensor_pair_ctor(x, y)

        traced = torch.fx.symbolic_trace(use_tensor_pair_ctor)
        print(traced.code)
        '''
        def forward(self, x : __main___TensorPair, y : torch.Tensor):
            tensor_pair = __main___TensorPair(y, y);  y = None
            add = x.add(tensor_pair);  tensor_pair = None
            mul = add.mul(x);  add = x = None
            return mul
        '''

    From this example, we can see that construction of a class (``TensorPair``)
    defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic
    tracing.
    """

    def __init__(cls, name, bases, attrs):
        _proxyable_classes.setdefault(cls)
        super().__init__(name, bases, attrs)

    def __call__(cls, *args, **kwargs):
        instance = cls.__new__(cls)  # type: ignore[call-overload]

        found_proxies = []

        def check_proxy(a):
            if isinstance(a, Proxy):
                found_proxies.append(a)

        map_aggregate(args, check_proxy)
        map_aggregate(kwargs, check_proxy)

        if len(found_proxies) != 0:
            tracer = found_proxies[0].tracer
            return tracer.create_proxy("call_function", cls, args, kwargs)
        else:
            cls.__init__(instance, *args, **kwargs)  # type: ignore[misc]
            return instance


def _patch_function(fn: FunctionType, nargs: int) -> FunctionType:
    co = fn.__code__
    co_flags = co.co_flags & ~HAS_VARSTUFF
    co_args: tuple
    if hasattr(co, "co_qualname"):
        # Python-3.11+ code signature
        co_args = (
            nargs,
            0,
            0,
            co.co_nlocals,
            co.co_stacksize,
            co_flags,
            co.co_code,
            co.co_consts,
            co.co_names,
            co.co_varnames,
            co.co_filename,
            co.co_name,
            co.co_qualname,  # type: ignore[attr-defined]
            co.co_firstlineno,
            co.co_lnotab,
            co.co_exceptiontable,  # type: ignore[attr-defined]
            co.co_freevars,
            co.co_cellvars,
        )
    elif hasattr(co, "co_posonlyargcount"):
        co_args = (
            nargs,
            0,
            0,
            co.co_nlocals,
            co.co_stacksize,
            co_flags,
            co.co_code,
            co.co_consts,
            co.co_names,
            co.co_varnames,
            co.co_filename,
            co.co_name,
            co.co_firstlineno,
            co.co_lnotab,
            co.co_freevars,
            co.co_cellvars,
        )
    else:
        co_args = (
            nargs,
            0,
            co.co_nlocals,
            co.co_stacksize,
            co_flags,
            co.co_code,
            co.co_consts,
            co.co_names,
            co.co_varnames,
            co.co_filename,
            co.co_name,
            co.co_firstlineno,
            co.co_lnotab,
            co.co_freevars,
            co.co_cellvars,
        )
    new_code = CodeType(*co_args)  # type: ignore[arg-type]
    return FunctionType(
        new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__
    )
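
# Example (illustrative sketch): after patching, Python binds positional
# arguments directly to the (former) varargs/varkwargs slots, with no tuple or
# dict packing. This is how the tracer can pass a single Proxy standing in for
# all of ``*args``:
#
#     def f(x, *args):
#         return x + args     # after patching, ``args`` is a single value
#
#     g = _patch_function(f, 2)
#     # g(a, b) binds a -> x and b -> args; no packing into a tuple occurs.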

# we need to insert placeholder nodes for *args and **kwargs
# we can't call this function normally, otherwise it would try to unpack them
# instead, let's make python think that args and kwargs are normal variables


@compatibility(is_backward_compatible=False)
class PHBase:
    """
    Object representing an input placeholder to `concrete_args`
    """

    def __repr__(self):
        return "PH"


PH = PHBase()


@compatibility(is_backward_compatible=True)
class Tracer(TracerBase):
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `math`s path from the
    # build environment (e.g. `<module 'math' from '/leaked/path'>`).

    """Tracer(autowrap_modules=(math,), autowrap_functions=())

    ``Tracer`` is the class that implements the symbolic tracing functionality
    of ``torch.fx.symbolic_trace``. A call to ``symbolic_trace(m)`` is equivalent
    to ``Tracer().trace(m)``.

    Tracer can be subclassed to override various behaviors of the tracing
    process. The different behaviors that can be overridden are described
    in the docstrings of the methods on this class.
    """

    # Not checking BC on this API because the default value for `autowrap_modules`
    # includes the local filepath to the `math` module, which would jitter
    # across machines.
    @compatibility(is_backward_compatible=True)
    def __init__(
        self,
        autowrap_modules: Tuple[ModuleType] = (math,),
        autowrap_functions: Tuple[Callable, ...] = (),
        param_shapes_constant: bool = False,
    ) -> None:
        # This method's signature is overridden by the first line of this class'
        # docstring. If this method's signature is modified, the signature that
        # overrides it also should be modified accordingly.
        """
        Construct a Tracer object.

        Args:

            autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`,
                Python modules whose functions should be wrapped automatically
                without needing to use fx.wrap(). Backward-compatibility for
                this parameter is guaranteed.

            autowrap_functions (Tuple[Callable, ...]): defaults to `()`,
                Python functions that should be wrapped automatically without
                needing to use fx.wrap(). Backward compatibility for this
                parameter is guaranteed.

            param_shapes_constant (bool): When this flag is set, calls to shape,
                size and a few other shape-like attributes of a module's parameter
                will be evaluated directly, rather than returning a new Proxy value
                for an attribute access. Backward compatibility for this parameter
                is guaranteed.
        """
        super().__init__()

        # Functions we will eagerly wrap when we see them while tracing;
        # this captures both `math.sqrt()` and `from math import sqrt` automatically
        self._autowrap_function_ids: Set[int] = {
            id(value)
            for name, value in chain(*[m.__dict__.items() for m in autowrap_modules])
            if not name.startswith("_") and callable(value)
        }
        self._autowrap_function_ids.update({id(f) for f in autowrap_functions})

        # Python modules to apply autowrap to at the start, in addition to
        # modules we see while tracing
        self._autowrap_search: List[ModuleType] = list(autowrap_modules)
        self.param_shapes_constant = param_shapes_constant

        self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None
        self.root_module_name: str = ""
        # Maps the containing module's name to the operator name
        self.scope = Scope("", None)
        # Records the module call stack
        self.module_stack = collections.OrderedDict()
        # Mapping of node name to module scope
        self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
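
    # Example (illustrative sketch): keep a data-dependent helper out of the
    # trace by registering it at construction time. ``my_untraceable`` and
    # ``MyModule`` are hypothetical:
    #
    #     def my_untraceable(x):
    #         return int(x.sum())    # data-dependent; cannot be traced through
    #
    #     tracer = Tracer(autowrap_functions=(my_untraceable,))
    #     graph = tracer.trace(MyModule())
    #     # calls to my_untraceable appear in ``graph`` as call_function nodes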

    @compatibility(is_backward_compatible=True)
    def create_arg(self, a: Any) -> "Argument":
        """
        A method to specify the behavior of tracing when preparing values to
        be used as arguments to nodes in the ``Graph``.

        By default, the behavior includes:

        #. Iterate through collection types (e.g. tuple, list, dict) and recursively
           call ``create_arg`` on the elements.
        #. Given a Proxy object, return a reference to the underlying IR ``Node``
        #. Given a non-Proxy Tensor object, emit IR for various cases:

            * For a Parameter, emit a ``get_attr`` node referring to that Parameter
            * For a non-Parameter Tensor, store the Tensor in a special attribute
              on the root module and emit a ``get_attr`` node referring to that
              attribute.

        This method can be overridden to support more types.

        Args:

            a (Any): The value to be emitted as an ``Argument`` in the ``Graph``.

        Returns:

            The value ``a`` converted into the appropriate ``Argument``
        """
        # The base tracer is used to construct Graphs when there is no associated
        # module hierarchy, so it can never create parameter references.
        # The default tracer adds the ability to refer to parameters when
        # tracing modules.
        if isinstance(a, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if a is p:
                    return self.create_node("get_attr", n, (), {})
            raise NameError("parameter is not a member of this module")
        elif isinstance(a, torch.Tensor):
            for n_, p_ in self.root.named_buffers():
                if a is p_:
                    return self.create_node("get_attr", n_, (), {})
        elif isinstance(a, torch.nn.Module):
            for n_, p_ in self.root.named_modules():
                if a is p_:
                    return self.create_node("get_attr", n_, (), {})

        # For NamedTuple instances that appear literally as args, we emit
        # a node to construct the NamedTuple and use that Node as the argument.
        if isinstance(a, tuple) and hasattr(a, "_fields"):
            args = tuple(self.create_arg(elem) for elem in a)
            return self.create_node("call_function", a.__class__, args, {})

        # Tensors do not have a reliable string repr() from which they can be
        # constructed (and we probably don't want to rely on that, either), so
        # for any constant Tensor values we encounter, first search for if they
        # are an attribute of some module in the module hierarchy. If so, emit
        # a get_attr to retrieve that tensor. Otherwise, we'll store away the
        # tensor value into a special attribute on the Module s.t. we can
        # retrieve it with a get_attr.
        if isinstance(a, (torch.Tensor, ScriptObject)):
            qualname: Optional[str] = self.tensor_attrs.get(a)

            # Tensor was not found in the Module hierarchy, stow it away in a
            # special attribute and set the qualname to refer to that
            if not qualname:
                i = 0
                while True:
                    qualname = f"_tensor_constant{i}"
                    if not hasattr(self.root, qualname):
                        break
                    i += 1
                self.tensor_attrs[a] = qualname
                setattr(self.root, qualname, a)

            return self.create_node("get_attr", qualname, (), {})

        if type(a) in _proxyable_classes:
            # This is an instance of a proxyable class for which we did not
            # witness its construction. Intern this as a constant attribute

            # TODO: binary search
            i = 0
            while True:
                qualname = f"_{a.__class__.__name__}_constant_{i}"
                if not hasattr(self.root, qualname):
                    break
                i += 1
            setattr(self.root, qualname, a)

            return self.create_node("get_attr", qualname, (), {})

        return super().create_arg(a)
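
    # Example (illustrative sketch): a subclass can extend ``create_arg`` to
    # lower a custom container type into graph arguments. ``Pair`` is a
    # hypothetical class with ``x`` and ``y`` fields:
    #
    #     class PairTracer(Tracer):
    #         def create_arg(self, a):
    #             if isinstance(a, Pair):
    #                 args = (self.create_arg(a.x), self.create_arg(a.y))
    #                 return self.create_node("call_function", Pair, args, {})
    #             return super().create_arg(a)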

    @compatibility(is_backward_compatible=True)
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """
        A method to specify whether a given ``nn.Module`` is a "leaf" module.

        Leaf modules are the atomic units that appear in
        the IR, referenced by ``call_module`` calls. By default,
        Modules in the PyTorch standard library namespace (torch.nn)
        are leaf modules. All other modules are traced through and
        their constituent ops are recorded, unless specified otherwise
        via this parameter.

        Args:

            m (Module): The module being queried about
            module_qualified_name (str): The path of this module from the root.
                For example, if you have a module hierarchy where submodule ``foo``
                contains submodule ``bar``, which contains submodule ``baz``, that
                module will appear with the qualified name ``foo.bar.baz`` here.
        """
        return (
            (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn"))
            and not isinstance(m, torch.nn.Sequential)
        )
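
    # Example (illustrative sketch): treat a custom module type as a leaf so
    # that it is recorded as a single call_module node rather than traced
    # through. ``MyCustomBlock`` is hypothetical:
    #
    #     class CustomTracer(Tracer):
    #         def is_leaf_module(self, m, module_qualified_name):
    #             if isinstance(m, MyCustomBlock):
    #                 return True
    #             return super().is_leaf_module(m, module_qualified_name)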

    @compatibility(is_backward_compatible=True)
    def path_of_module(self, mod: torch.nn.Module) -> str:
        """
        Helper method to find the qualified name of ``mod`` in the Module hierarchy
        of ``root``. For example, if ``root`` has a submodule named ``foo``, which has
        a submodule named ``bar``, passing ``bar`` into this function will return
        the string "foo.bar".

        Args:

            mod (Module): The ``Module`` to retrieve the qualified name for.
        """
        # Prefer the O(1) algorithm
        if self.submodule_paths:
            path = self.submodule_paths.get(mod)
            if path is None:
                raise NameError("module is not installed as a submodule")
            assert isinstance(path, str)
            return path
        # O(N^2) fallback in the case that we didn't store the submodule
        # paths.
        else:
            for n, p in self.root.named_modules():
                if mod is p:
                    return n
            raise NameError("module is not installed as a submodule")

    @compatibility(is_backward_compatible=True)
    def call_module(
        self,
        m: torch.nn.Module,
        forward: Callable[..., Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> Any:
        """
        Method that specifies the behavior of this ``Tracer`` when it encounters
        a call to an ``nn.Module`` instance.

        By default, the behavior is to check if the called module is a leaf module
        via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to
        ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through
        the operations in its ``forward`` function.

        This method can be overridden to, for example, create nested traced
        GraphModules, or to implement any other behavior you would want while
        tracing across ``Module`` boundaries.

        Args:

            m (Module): The module for which a call is being emitted
            forward (Callable): The forward() method of the ``Module`` to be invoked
            args (Tuple): args of the module callsite
            kwargs (Dict): kwargs of the module callsite

        Return:

            The return value from the Module call. In the case that a ``call_module``
            node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
            value was returned from the ``Module`` invocation.
        """
        module_qualified_name = self.path_of_module(m)
        with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope:
            # module_stack is an ordered dict so writing then deleting the
            # entry is equivalent to push/pop on a list
            self.module_stack[_scope.module_path] = _scope.module_type
            if not self.is_leaf_module(m, module_qualified_name):
                ret_val = forward(*args, **kwargs)
            else:
                ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs)
            key, _ = self.module_stack.popitem(last=True)
            assert key == _scope.module_path, f"Unexpected key {key}"

        return ret_val
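
    # Example (illustrative sketch): an override that inlines every module,
    # ignoring leaf-module status, so no call_module nodes are ever emitted:
    #
    #     class InlineEverythingTracer(Tracer):
    #         def call_module(self, m, forward, args, kwargs):
    #             return forward(*args, **kwargs)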

    @compatibility(is_backward_compatible=False)
    def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
        """
        Method that specifies the behavior of this ``Tracer`` when we call getattr
        on a call to an ``nn.Module`` instance.

        By default, the behavior is to return a proxy value for the attribute. It
        also stores the proxy value in the ``parameter_proxy_cache``, so that future
        calls will reuse the proxy rather than creating a new one.

        This method can be overridden to, for example, not return proxies when
        querying parameters.

        Args:

            attr (str): The name of the attribute being queried
            attr_val (Any): The value of the attribute
            parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies

        Return:

            The return value from the getattr call.
        """

        def maybe_get_proxy_for_attr(
            attr_val, collection_to_search, parameter_proxy_cache
        ):
            for n, p in collection_to_search:
                if attr_val is p:
                    if n not in parameter_proxy_cache:
                        kwargs = {}
                        if (
                            "proxy_factory_fn"
                            in inspect.signature(self.create_proxy).parameters
                        ):
                            kwargs["proxy_factory_fn"] = (
                                None
                                if not self.param_shapes_constant
                                else lambda node: ParameterProxy(
                                    self, node, n, attr_val
                                )
                            )
                        val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)  # type: ignore[arg-type]
                        parameter_proxy_cache[n] = val_proxy
                    return parameter_proxy_cache[n]
            return None

        if isinstance(attr_val, torch.nn.Parameter):
            maybe_parameter_proxy = maybe_get_proxy_for_attr(
                attr_val, self.root.named_parameters(), parameter_proxy_cache
            )
            if maybe_parameter_proxy is not None:
                return maybe_parameter_proxy

        if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
            maybe_buffer_proxy = maybe_get_proxy_for_attr(
                attr_val, self.root.named_buffers(), parameter_proxy_cache
            )
            if maybe_buffer_proxy is not None:
                return maybe_buffer_proxy

        return attr_val

    # This method will be refactored
    @compatibility(is_backward_compatible=False)
    def create_args_for_root(self, root_fn, is_module, concrete_args=None):
        """
        Create ``placeholder`` nodes corresponding to the signature of the ``root``
        Module. This method introspects root's signature and emits those
        nodes accordingly, also supporting ``*args`` and ``**kwargs``.
        """
        # In some cases, a function or method has been decorated with a wrapper
        # defined via ``functools.wraps``. In this case, the outer code object
        # will likely not contain the actual parameters we care about, so unwrap
        # the function to get to the innermost callable.
        fn_for_analysis = inspect.unwrap(root_fn)
        co = fn_for_analysis.__code__
        total_args = co.co_argcount + co.co_kwonlyargcount
        orig_args = list(co.co_varnames)
        names_iter = iter(co.co_varnames)
        args: List[Any] = []
        skip_arg_idx = 0
        if is_module:
            if total_args == 0:
                raise RuntimeError(
                    "``self`` argument cannot be part of *args expansion!"
                )
            skip_arg_idx = 1
            next(names_iter)  # skip self
            args.append(self.root)

        sig = inspect.signature(fn_for_analysis)

        def proxy_placeholder(name: str):
            if concrete_args is not None and name in concrete_args:
                cnt = 0

                def replace_ph(x):
                    nonlocal cnt
                    cnt += 1
                    param = sig.parameters[name]
                    default = (
                        ()
                        if param.default is inspect.Parameter.empty
                        else (param.default,)
                    )
                    out = self.create_proxy(
                        "placeholder", f"{name}_{str(cnt)}", default, {}
                    )
                    if x == PH:
                        return out
                    # Union[int, bool] == bool in Python <= 3.6
                    if (
                        type(x) == bool
                        or type(x) in base_types
                        and type(x) != torch.Tensor
                    ):
                        torch._assert(
                            out == x,
                            f"{name} has been specialized to have value {x} but got another value",
                        )
                    elif type(x) == type(None):
                        args = (
                            out,
                            f"{name} has been specialized to have value None but got another value",
                        )
                        self.create_proxy("call_function", _assert_is_none, args, {})
                    else:
                        warnings.warn(
                            f"Was not able to add assertion to guarantee correct input {name} to "
                            f"specialized function. It is up to the user to make sure that your inputs match the "
                            f"inputs you specialized the function with."
                        )

                    return x

                return pytree.tree_map(replace_ph, concrete_args[name])
            if name[0] == "*":
                default = ()
            else:
                param = sig.parameters[name]
                default = () if param.default is inspect.Parameter.empty else (param.default,)  # type: ignore[assignment]
            return self.create_proxy(
                "placeholder",
                name,
                default,
                {},
                type_expr=fn_for_analysis.__annotations__.get(name, None)
            )

        arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)]
        if isinstance(concrete_args, tuple):
            if len(arg_names) != len(concrete_args):
                raise RuntimeError(
                    f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments"
                )
            concrete_args = {name: val for name, val in zip(arg_names, concrete_args)}
        args.extend(proxy_placeholder(names) for names in arg_names)

        if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF:
            # TODO: type annotations for *args and **kwargs
            if co.co_flags & inspect.CO_VARARGS:
                args.append(proxy_placeholder("*" + next(names_iter)))
            if co.co_flags & inspect.CO_VARKEYWORDS:
                args.append(proxy_placeholder("**" + next(names_iter)))
            root_fn = _patch_function(root_fn, len(args))

        flat_args, in_spec = pytree.tree_flatten(tuple(args))
        if any(not isinstance(i, pytree.LeafSpec) for i in in_spec.children_specs):
            # In the case that we have pytree-flattened inputs in
            # `concrete_args`, generate a flattening wrapper around the
            # original root function and return that.
            self.graph._codegen = _PyTreeCodeGen(
                _PyTreeInfo(orig_args[:total_args], in_spec, None)
            )

            def flatten_fn(*args):
                tree_args = pytree.tree_unflatten(list(args), in_spec)
                tree_out = root_fn(*tree_args)
                out_args, out_spec = pytree.tree_flatten(tree_out)
                assert isinstance(self.graph._codegen, _PyTreeCodeGen)
                self.graph._codegen.pytree_info = (
                    self.graph._codegen.pytree_info._replace(out_spec=out_spec)
                )
                return out_args

            return flatten_fn, flat_args
        return root_fn, args

    @compatibility(is_backward_compatible=True)
    def trace(
        self,
        root: Union[torch.nn.Module, Callable[..., Any]],
        concrete_args: Optional[Dict[str, Any]] = None,
    ) -> Graph:
        """
        Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
        can either be an ``nn.Module`` instance or a Python callable.

        Note that after this call, ``self.root`` may be different from the ``root`` passed
        in here. For example, when a free function is passed to ``trace()``, we will
        create an ``nn.Module`` instance to use as the root, onto which embedded
        constants are added.

        Args:

            root (Union[Module, Callable]): Either a ``Module`` or a function to be
                traced through. Backwards-compatibility for this parameter is
                guaranteed.
            concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
                not be treated as Proxies. This parameter is experimental and
                its backwards-compatibility is *NOT* guaranteed.

        Returns:

            A ``Graph`` representing the semantics of the passed-in ``root``.
        """
        global _is_fx_tracing_flag
        old_is_fx_tracing_flag = _is_fx_tracing_flag
        _is_fx_tracing_flag = True
        try:
            if isinstance(root, torch.nn.Module):
                self.root = root

                assert hasattr(
                    type(root), self.traced_func_name
                ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"

                fn = getattr(type(root), self.traced_func_name)
                self.root_module_name = root._get_name()
                self.submodule_paths = {mod: name for name, mod in root.named_modules()}
            else:
                self.root = torch.nn.Module()
                fn = root

            tracer_cls: Optional[Type["Tracer"]] = getattr(self, "__class__", None)
            self.graph = Graph(tracer_cls=tracer_cls)

            # When we encounter a Tensor value that's not a parameter, we look if it
            # is some other attribute on the model. Construct a dict mapping Tensor
            # values to the qualified name here for efficiency. This is used downstream
            # in create_arg
            self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {}

            def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]):
                for k, v in m.__dict__.items():
                    if isinstance(v, (torch.Tensor, ScriptObject)):
                        self.tensor_attrs[v] = ".".join(prefix_atoms + [k])
                for k, v in m.named_children():
                    collect_tensor_attrs(v, prefix_atoms + [k])

            collect_tensor_attrs(self.root, [])

            assert isinstance(fn, FunctionType)

            fn_globals = fn.__globals__  # run before it gets patched
            fn, args = self.create_args_for_root(
                fn, isinstance(root, torch.nn.Module), concrete_args
            )

            parameter_proxy_cache: Dict[
                str, Proxy
            ] = {}  # Reduce number of get_attr calls

            # Method dispatch on parameters is not recorded unless it's directly used.
            # Thus, we need to insert a proxy when __getattr__ requests a parameter.
            @functools.wraps(_orig_module_getattr)
            def module_getattr_wrapper(mod, attr):
                attr_val = _orig_module_getattr(mod, attr)
                return self.getattr(attr, attr_val, parameter_proxy_cache)

            @functools.wraps(_orig_module_call)
            def module_call_wrapper(mod, *args, **kwargs):
                def forward(*args, **kwargs):
                    return _orig_module_call(mod, *args, **kwargs)

                _autowrap_check(
                    patcher,
                    getattr(getattr(mod, "forward", mod), "__globals__", {}),
                    self._autowrap_function_ids,
                )
                return self.call_module(mod, forward, args, kwargs)

            with _Patcher() as patcher:
                # allow duplicate patches to support the case of nested calls
                patcher.patch_method(
                    torch.nn.Module,
                    "__getattr__",
                    module_getattr_wrapper,
                    deduplicate=False,
                )
                patcher.patch_method(
                    torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False
                )
                _patch_wrapped_functions(patcher)
                _autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
                for module in self._autowrap_search:
                    _autowrap_check(
                        patcher, module.__dict__, self._autowrap_function_ids
                    )
                self.create_node(
                    "output",
                    "output",
                    (self.create_arg(fn(*args)),),
                    {},
                    type_expr=fn.__annotations__.get("return", None),
                )

            self.submodule_paths = None
        finally:
            _is_fx_tracing_flag = old_is_fx_tracing_flag
        return self.graph

    def __deepcopy__(self, memo):
        # _autowrap_search contains modules, which cannot be deepcopied.
        new_tracer = Tracer.__new__(Tracer)

        for k, v in self.__dict__.items():
            if k in {'_autowrap_search'}:
                new_obj = copy.copy(v)
            else:
                new_obj = copy.deepcopy(v, memo)

            new_tracer.__dict__[k] = new_obj

        return new_tracer
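

# Example (illustrative sketch): ``symbolic_trace`` at the bottom of this file
# is essentially this recipe, which can be reproduced with a custom Tracer
# subclass (``module_or_fn`` is a placeholder name):
#
#     tracer = Tracer()
#     graph = tracer.trace(module_or_fn)
#     gm = GraphModule(tracer.root, graph)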


# List of pairs of (global dict, function name) functions
# to patch for the purposes of the wrap() API.
_wrapped_fns_to_patch: List[Tuple[dict, str]] = []

# List of methods on classes to wrap (class type, function name)
# this currently only works for Tensor.* methods that aren't traced properly
_wrapped_methods_to_patch: List[Tuple[type, str]] = []

if os.environ.get("FX_PATCH_GETITEM") == "1":
    # This change is needed to trace models like PositionalEmbedding from BERT:
    # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py
    # but causes issues in quantization documented here:
    # https://github.com/pytorch/pytorch/issues/50710
    # once that is fixed we can make this the default behavior.
    _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))


def _find_proxy(*objects_to_search):
    """
    Recursively search a data structure for a Proxy() and return it,
    return None if not found.
    """
    proxy = None

    def find_proxy(x):
        nonlocal proxy
        if isinstance(x, Proxy):
            proxy = x

    map_aggregate(objects_to_search, find_proxy)
    return proxy


def _create_wrapped_func(orig_fn):
    @functools.wraps(orig_fn)
    def wrapped(*args, **kwargs):
        """
        Given a closed-over ``orig_fn`` to invoke, search the args and kwargs for
        a Proxy object. If there is one, emit a ``call_function`` node to preserve the
        call to this leaf function directly. Otherwise, just return the results of
        this function call, as this function is not being traced.
        """
        proxy = _find_proxy(args, kwargs)
        if proxy is not None:
            return_proxy = proxy.tracer.create_proxy(
                "call_function", orig_fn, args, kwargs
            )
            return_proxy.node.meta["is_wrapped"] = True
            return return_proxy
        return orig_fn(*args, **kwargs)

    return wrapped


def _create_wrapped_method(cls, name):
    orig_fn = getattr(cls, name)

    @functools.wraps(orig_fn)
    def wrapped(*args, **kwargs):
        """
        Search the args and kwargs for a Proxy object. If there is one,
        emit a ``call_method`` node to preserve the call to this method
        directly. Otherwise, just return the results of this function
        call, as this function is not being traced.
        """
        proxy = _find_proxy(args, kwargs)
        if proxy is not None:
            return proxy.tracer.create_proxy("call_method", name, args, kwargs)
        return orig_fn(*args, **kwargs)

    return wrapped


class _PatchedFn(NamedTuple):
    frame_dict: Any
    fn_name: str
    orig_fn: Any

    def revert(self):
        raise NotImplementedError()


class _PatchedFnSetItem(_PatchedFn):
    def revert(self):
        self.frame_dict[self.fn_name] = self.orig_fn


class _PatchedFnDel(_PatchedFn):
    def revert(self):
        del self.frame_dict[self.fn_name]


class _PatchedFnSetAttr(_PatchedFn):
    def revert(self):
        setattr(self.frame_dict, self.fn_name, self.orig_fn)


class _Patcher:
    def __init__(self):
        super().__init__()
        self.patches_made: List[_PatchedFn] = []
        self.visited: Set[int] = set()

    def patch(
        self,
        frame_dict: Dict[str, Any],
        name: str,
        new_fn: Callable,
        deduplicate: bool = True,
    ):
        """
        Replace frame_dict[name] with new_fn until we exit the context manager.
        """
        new_fn.__fx_already_patched = deduplicate  # type: ignore[attr-defined]
        if name not in frame_dict and hasattr(builtins, name):
            self.patches_made.append(_PatchedFnDel(frame_dict, name, None))
        elif getattr(frame_dict[name], "__fx_already_patched", False):
            return  # already patched, no need to do it again
        else:
            self.patches_made.append(
                _PatchedFnSetItem(frame_dict, name, frame_dict[name])
            )
        frame_dict[name] = new_fn

    def patch_method(
        self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True
    ):
        """
        Replace ``cls.name`` with ``new_fn`` until we exit the context manager.
        """
        new_fn.__fx_already_patched = deduplicate  # type: ignore[attr-defined]
        orig_fn = getattr(cls, name)
        if getattr(orig_fn, "__fx_already_patched", False):
            return  # already patched, no need to do it again
        self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn))
        setattr(cls, name, new_fn)

    def visit_once(self, thing: Any):
        """Return True on the first call with ``thing``, otherwise False"""
        idx = id(thing)
        if idx in self.visited:
            return False
        self.visited.add(idx)
        return True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Undo all the changes made via self.patch() and self.patch_method()
        """
        while self.patches_made:
            # unpatch in reverse order to handle duplicates correctly
            self.patches_made.pop().revert()
        self.visited.clear()
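

# Example (illustrative sketch): ``Tracer.trace()`` above uses ``_Patcher``
# roughly like this to swap a global for the duration of a block and restore
# it on exit:
#
#     with _Patcher() as patcher:
#         patcher.patch(math.__dict__, "sqrt", _create_wrapped_func(math.sqrt))
#         ...  # calls to math.sqrt made here go through the wrapper
#     # on __exit__, the original math.sqrt is restored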


def _patch_wrapped_functions(patcher: _Patcher):
    """
    Go through ``_wrapped_fns_to_patch`` and, for each frame object, wrap
    the listed global functions in the `_create_wrapped_func` wrapper.
    """
    for frame_dict, name in _wrapped_fns_to_patch:
        if name not in frame_dict and hasattr(builtins, name):
            orig_fn = getattr(builtins, name)
        else:
            orig_fn = frame_dict[name]
        patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn))

    for cls, name in _wrapped_methods_to_patch:
        patcher.patch_method(cls, name, _create_wrapped_method(cls, name))


def _autowrap_check(
    patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int]
):
    """
    Some functions, like `math.sqrt`, are common enough that we want to
    automatically wrap them as we see them. This method searches a scope for
    them and patches them if found.
    """
    if patcher.visit_once(frame_dict):
        for name, value in frame_dict.items():
            if (
                not name.startswith("_")
                and callable(value)
                and id(value) in function_ids
            ):
                patcher.patch(frame_dict, name, _create_wrapped_func(value))


@compatibility(is_backward_compatible=True)
def wrap(fn_or_name: Union[str, Callable]):
    """
    This function can be called at module-level scope to register fn_or_name as a "leaf function".
    A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being
    traced through::

        # foo/bar/baz.py
        def my_custom_function(x, y):
            return x * x + y * y

        torch.fx.wrap('my_custom_function')

        def fn_to_be_traced(x, y):
            # When symbolic tracing, the below call to my_custom_function will be inserted into
            # the graph rather than traced through.
            return my_custom_function(x, y)

    This function can also equivalently be used as a decorator::

        # foo/bar/baz.py
        @torch.fx.wrap
        def my_custom_function(x, y):
            return x * x + y * y

    A wrapped function can be thought of as a "leaf function", analogous to the concept of
    "leaf modules"; that is, it is a function that is left as a call in the FX trace
    rather than traced through.

    Args:

        fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the
            graph when it's called
    """
    if not callable(fn_or_name) and not isinstance(fn_or_name, str):
        raise RuntimeError(
            "Unsupported type for global function! Must be either a callable or "
            "string name"
        )

    if callable(fn_or_name):
        assert not isinstance(fn_or_name, str)  # to make mypy happy
        fn_name = fn_or_name.__name__
    else:
        assert isinstance(
            fn_or_name, str
        ), "fn_or_name must be a global function or string name"
        fn_name = fn_or_name

    currentframe = inspect.currentframe()
    assert currentframe is not None
    f = currentframe.f_back
    assert f is not None
    if f.f_code.co_name != "<module>":
        raise NotImplementedError("wrap must be called at the top level of a module")

    # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search
    # semantics would be slightly different, but would add support `from x import wrapped_function`
    _wrapped_fns_to_patch.append((f.f_globals, fn_name))
    return fn_or_name


@compatibility(is_backward_compatible=True)
def symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
) -> GraphModule:
    """
    Symbolic tracing API

    Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
    constructed by recording operations seen while tracing through ``root``.

    ``concrete_args`` allows you to partially specialize your function, whether it's to remove
    control flow or data structures.

    For example::

        def f(a, b):
            if b == True:
                return a
            else:
                return a*2

    FX typically cannot trace through this due to the presence of control
    flow. However, we can use `concrete_args` to specialize on the value of
    `b` to trace through this::

        f = fx.symbolic_trace(f, concrete_args={'b': False})
        assert f(3, False) == 6

    Note that although you can still pass in different values of `b`, they will be ignored.

    We can also use `concrete_args` to eliminate data-structure handling from
    our function. This will use pytrees to flatten your input. To avoid
    overspecializing, pass in `fx.PH` for values that shouldn't be
    specialized. For example::

        def f(x):
            out = 0
            for v in x.values():
                out += v
            return out

        f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
        assert f({'a': 1, 'b': 2, 'c': 4}) == 7

    Args:

        root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted
            into a Graph representation.
        concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized

    Returns:

        GraphModule: a Module created from the recorded operations from ``root``.
    """
    tracer = Tracer()
    graph = tracer.trace(root, concrete_args)
    name = (
        root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    )
    return GraphModule(tracer.root, graph, name)


@wrap
def _assert_is_none(value, msg):
    assert value is None, msg
|