_symbolic_trace.py

import builtins
import copy
import functools
import inspect
import math
import os
import warnings
import collections
from itertools import chain
from types import CodeType, FunctionType, ModuleType
from typing import (
    Any,
    Callable,
    Dict,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
)

import torch
import torch.utils._pytree as pytree
from torch._C import ScriptObject  # type: ignore[attr-defined]

from ._compatibility import compatibility
from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph
from .graph_module import GraphModule
from .node import Argument, base_types, map_aggregate
from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager

HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS

# These need to run in global scope to handle nested calls correctly
_orig_module_call: Callable = torch.nn.Module.__call__
_orig_module_getattr: Callable = torch.nn.Module.__getattr__

_proxyable_classes: Dict[Type, None] = {}

_is_fx_tracing_flag = False


def is_fx_tracing():
    return _is_fx_tracing_flag

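# Illustrative note (editorial, not upstream documentation): user code that
# must behave differently under tracing can consult ``is_fx_tracing()`` to
# skip constructs symbolic tracing cannot handle, e.g. data-dependent
# assertions. A minimal sketch, assuming this module is importable as
# ``torch.fx._symbolic_trace``:
#
#     from torch.fx._symbolic_trace import is_fx_tracing
#
#     def forward(self, x):
#         if not is_fx_tracing():
#             assert x.shape[0] > 0, "expected a non-empty batch"
#         return x * 2
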
@compatibility(is_backward_compatible=True)
class ProxyableClassMeta(type):
    """
    ProxyableClassMeta allows you to make construction of a given Python class
    symbolically traceable. For example::

        import torch
        import torch.fx

        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, left, right):
                self.left, self.right = left, right

            def add(self, other):
                l = self.left + other.left
                r = self.right + other.right
                return TensorPair(l, r)

            def mul(self, other):
                l = self.left * other.left
                r = self.right * other.right
                return TensorPair(l, r)

        def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
            s = x.add(TensorPair(y, y))
            return s.mul(x)

        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
        y = torch.randn(5, 3)
        ref_out = use_tensor_pair_ctor(x, y)

        traced = torch.fx.symbolic_trace(use_tensor_pair_ctor)
        print(traced.code)
        '''
        def forward(self, x : __main___TensorPair, y : torch.Tensor):
            tensor_pair = __main___TensorPair(y, y); y = None
            add = x.add(tensor_pair); tensor_pair = None
            mul = add.mul(x); add = x = None
            return mul
        '''

    From this example, we can see that construction of a class (``TensorPair``)
    defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic
    tracing.
    """

    def __init__(cls, name, bases, attrs):
        _proxyable_classes.setdefault(cls)
        super().__init__(name, bases, attrs)

    def __call__(cls, *args, **kwargs):
        instance = cls.__new__(cls)  # type: ignore[call-overload]

        found_proxies = []

        def check_proxy(a):
            if isinstance(a, Proxy):
                found_proxies.append(a)

        map_aggregate(args, check_proxy)
        map_aggregate(kwargs, check_proxy)

        if len(found_proxies) != 0:
            tracer = found_proxies[0].tracer
            return tracer.create_proxy("call_function", cls, args, kwargs)
        else:
            cls.__init__(instance, *args, **kwargs)  # type: ignore[misc]
            return instance

def _patch_function(fn: FunctionType, nargs: int) -> FunctionType:
    co = fn.__code__
    co_flags = co.co_flags & ~HAS_VARSTUFF
    co_args: tuple
    if hasattr(co, "co_qualname"):
        # Python-3.11+ code signature
        co_args = (
            nargs,
            0,
            0,
            co.co_nlocals,
            co.co_stacksize,
            co_flags,
            co.co_code,
            co.co_consts,
            co.co_names,
            co.co_varnames,
            co.co_filename,
            co.co_name,
            co.co_qualname,  # type: ignore[attr-defined]
            co.co_firstlineno,
            co.co_lnotab,
            co.co_exceptiontable,  # type: ignore[attr-defined]
            co.co_freevars,
            co.co_cellvars,
        )
    elif hasattr(co, "co_posonlyargcount"):
        co_args = (
            nargs,
            0,
            0,
            co.co_nlocals,
            co.co_stacksize,
            co_flags,
            co.co_code,
            co.co_consts,
            co.co_names,
            co.co_varnames,
            co.co_filename,
            co.co_name,
            co.co_firstlineno,
            co.co_lnotab,
            co.co_freevars,
            co.co_cellvars,
        )
    else:
        co_args = (
            nargs,
            0,
            co.co_nlocals,
            co.co_stacksize,
            co_flags,
            co.co_code,
            co.co_consts,
            co.co_names,
            co.co_varnames,
            co.co_filename,
            co.co_name,
            co.co_firstlineno,
            co.co_lnotab,
            co.co_freevars,
            co.co_cellvars,
        )
    new_code = CodeType(*co_args)  # type: ignore[arg-type]
    return FunctionType(
        new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__
    )

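# Editorial sketch of what ``_patch_function`` produces (the function ``f``
# below is a hypothetical example, not part of this module): given
#
#     def f(a, *args, **kwargs):
#         ...
#
# ``_patch_function(f, 3)`` returns a copy of ``f`` whose code object reports
# three ordinary positional arguments and has the CO_VARARGS/CO_VARKEYWORDS
# flags cleared, so the tracer can bind one placeholder proxy to each of
# ``a``, ``args`` and ``kwargs`` as if they were plain parameters.
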
# we need to insert placeholder nodes for *args and **kwargs
# we can't call this function normally, otherwise it would try to unpack them
# instead, let's make python think that args and kwargs are normal variables


@compatibility(is_backward_compatible=False)
class PHBase:
    """
    Object representing an input placeholder to `concrete_args`
    """

    def __repr__(self):
        return "PH"


PH = PHBase()

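# Minimal usage sketch for ``PH`` with ``concrete_args`` (mirrors the
# ``symbolic_trace`` docstring further below; ``f`` is a made-up example):
#
#     def f(x):
#         return x["a"] + x["b"]
#
#     gm = symbolic_trace(f, concrete_args={"x": {"a": PH, "b": PH}})
#     assert gm({"a": 1, "b": 2}) == 3
#
# The dict structure of ``x`` is specialized away, while the entries marked
# with ``PH`` remain symbolic placeholders in the traced graph.
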
@compatibility(is_backward_compatible=True)
class Tracer(TracerBase):
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `math`s path from the
    # build environment (e.g. `<module 'math' from '/leaked/path').
    """Tracer(autowrap_modules=(math,), autowrap_functions=())

    ``Tracer`` is the class that implements the symbolic tracing functionality
    of ``torch.fx.symbolic_trace``. A call to ``symbolic_trace(m)`` is equivalent
    to ``Tracer().trace(m)``.

    Tracer can be subclassed to override various behaviors of the tracing
    process. The different behaviors that can be overridden are described
    in the docstrings of the methods on this class.
    """

    # Not checking BC on this API because the default value for `autowrap_modules`
    # includes the local filepath to the `math` module, which would jitter
    # across machines.
    @compatibility(is_backward_compatible=True)
    def __init__(
        self,
        autowrap_modules: Tuple[ModuleType] = (math,),
        autowrap_functions: Tuple[Callable, ...] = (),
        param_shapes_constant: bool = False,
    ) -> None:
        # This method's signature is overridden by the first line of this class'
        # docstring. If this method's signature is modified, the signature that
        # overrides it also should be modified accordingly.
        """
        Construct a Tracer object.

        Args:

            autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`,
                Python modules whose functions should be wrapped automatically
                without needing to use fx.wrap(). Backward-compatibility for
                this parameter is guaranteed.

            autowrap_functions (Tuple[Callable, ...]): defaults to `()`,
                Python functions that should be wrapped automatically without
                needing to use fx.wrap(). Backward compatibility for this
                parameter is guaranteed.

            param_shapes_constant (bool): When this flag is set, calls to shape,
                size and a few other shape-like attributes of a module's parameter
                will be evaluated directly, rather than returning a new Proxy value
                for an attribute access. Backward compatibility for this parameter
                is guaranteed.
        """

        super().__init__()

        # Functions we will eagerly wrap when we see them while tracing
        # this captures both `math.sqrt()` and `from math import sqrt` automatically
        self._autowrap_function_ids: Set[int] = {
            id(value)
            for name, value in chain(*[m.__dict__.items() for m in autowrap_modules])
            if not name.startswith("_") and callable(value)
        }
        self._autowrap_function_ids.update({id(f) for f in autowrap_functions})

        # Python modules to apply autowrap to at the start, in addition to
        # modules we see while tracing
        self._autowrap_search: List[ModuleType] = list(autowrap_modules)
        self.param_shapes_constant = param_shapes_constant

        self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None
        self.root_module_name: str = ""
        # Maps the containing module's name to the operator name
        self.scope = Scope("", None)
        # Records the module call stack
        self.module_stack = collections.OrderedDict()
        # Mapping of node name to module scope
        self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}

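    # Usage sketch (editorial): per the class docstring, ``symbolic_trace(m)``
    # is equivalent to tracing with a default ``Tracer`` and wrapping the
    # result in a ``GraphModule``. ``my_module`` and ``my_helper`` are
    # stand-in names:
    #
    #     tracer = Tracer(autowrap_functions=(my_helper,))
    #     graph = tracer.trace(my_module)
    #     gm = GraphModule(tracer.root, graph)
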
    @compatibility(is_backward_compatible=True)
    def create_arg(self, a: Any) -> "Argument":
        """
        A method to specify the behavior of tracing when preparing values to
        be used as arguments to nodes in the ``Graph``.

        By default, the behavior includes:

        #. Iterate through collection types (e.g. tuple, list, dict) and recursively
           call ``create_arg`` on the elements.
        #. Given a Proxy object, return a reference to the underlying IR ``Node``
        #. Given a non-Proxy Tensor object, emit IR for various cases:

            * For a Parameter, emit a ``get_attr`` node referring to that Parameter
            * For a non-Parameter Tensor, store the Tensor away in a special
              attribute referring to that attribute.

        This method can be overridden to support more types.

        Args:

            a (Any): The value to be emitted as an ``Argument`` in the ``Graph``.

        Returns:

            The value ``a`` converted into the appropriate ``Argument``
        """
        # The base tracer is used to construct Graphs when there is no associated
        # module hierarchy, so it can never create parameter references.
        # The default tracer adds the ability to refer to parameters when
        # tracing modules.
        if isinstance(a, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if a is p:
                    return self.create_node("get_attr", n, (), {})
            raise NameError("parameter is not a member of this module")
        elif isinstance(a, torch.Tensor):
            for n_, p_ in self.root.named_buffers():
                if a is p_:
                    return self.create_node("get_attr", n_, (), {})
        elif isinstance(a, torch.nn.Module):
            for n_, p_ in self.root.named_modules():
                if a is p_:
                    return self.create_node("get_attr", n_, (), {})
        # For NamedTuple instances that appear literally as args, we emit
        # a node to construct the NamedTuple and use that Node as the argument.
        if isinstance(a, tuple) and hasattr(a, "_fields"):
            args = tuple(self.create_arg(elem) for elem in a)
            return self.create_node("call_function", a.__class__, args, {})

        # Tensors do not have a reliable string repr() from which they can be
        # constructed (and we probably don't want to rely on that, either), so
        # for any constant Tensor values we encounter, first search for if they
        # are an attribute of some module in the module hierarchy. If so, emit
        # a get_attr to retrieve that tensor. Otherwise, we'll store away the
        # tensor value into a special attribute on the Module s.t. we can
        # retrieve it with a get_attr.
        if isinstance(a, (torch.Tensor, ScriptObject)):
            qualname: Optional[str] = self.tensor_attrs.get(a)

            # Tensor was not found in the Module hierarchy, stow it away in a
            # special attribute and set the qualname to refer to that
            if not qualname:
                i = 0
                while True:
                    qualname = f"_tensor_constant{i}"
                    if not hasattr(self.root, qualname):
                        break
                    i += 1
                self.tensor_attrs[a] = qualname
                setattr(self.root, qualname, a)

            return self.create_node("get_attr", qualname, (), {})

        if type(a) in _proxyable_classes:
            # This is an instance of a proxyable class for which we did not
            # witness its construction. Intern this as a constant attribute

            # TODO: binary search
            i = 0
            while True:
                qualname = f"_{a.__class__.__name__}_constant_{i}"
                if not hasattr(self.root, qualname):
                    break
                i += 1
            setattr(self.root, qualname, a)

            return self.create_node("get_attr", qualname, (), {})

        return super().create_arg(a)

    @compatibility(is_backward_compatible=True)
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """
        A method to specify whether a given ``nn.Module`` is a "leaf" module.

        Leaf modules are the atomic units that appear in
        the IR, referenced by ``call_module`` calls. By default,
        Modules in the PyTorch standard library namespace (torch.nn)
        are leaf modules. All other modules are traced through and
        their constituent ops are recorded, unless specified otherwise
        via this parameter.

        Args:

            m (Module): The module being queried about
            module_qualified_name (str): The path to root of this module. For example,
                if you have a module hierarchy where submodule ``foo`` contains
                submodule ``bar``, which contains submodule ``baz``, that module will
                appear with the qualified name ``foo.bar.baz`` here.
        """
        return (
            (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn"))
            and not isinstance(m, torch.nn.Sequential)
        )

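    # Editorial sketch of the common override pattern ("MyTracer" and
    # "MyOpaqueBlock" are hypothetical names, shown because the class
    # docstring notes that Tracer is meant to be subclassed):
    #
    #     class MyTracer(Tracer):
    #         def is_leaf_module(self, m, module_qualified_name):
    #             if isinstance(m, MyOpaqueBlock):
    #                 return True  # keep the whole block as one call_module node
    #             return super().is_leaf_module(m, module_qualified_name)
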
    @compatibility(is_backward_compatible=True)
    def path_of_module(self, mod: torch.nn.Module) -> str:
        """
        Helper method to find the qualified name of ``mod`` in the Module hierarchy
        of ``root``. For example, if ``root`` has a submodule named ``foo``, which has
        a submodule named ``bar``, passing ``bar`` into this function will return
        the string "foo.bar".

        Args:

            mod (Module): The ``Module`` to retrieve the qualified name for.
        """
        # Prefer the O(1) algorithm
        if self.submodule_paths:
            path = self.submodule_paths.get(mod)
            if path is None:
                raise NameError("module is not installed as a submodule")
            assert isinstance(path, str)
            return path

        # O(N^2) fallback in the case that we didn't store the submodule
        # paths.
        else:
            for n, p in self.root.named_modules():
                if mod is p:
                    return n
            raise NameError("module is not installed as a submodule")

    @compatibility(is_backward_compatible=True)
    def call_module(
        self,
        m: torch.nn.Module,
        forward: Callable[..., Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> Any:
        """
        Method that specifies the behavior of this ``Tracer`` when it encounters
        a call to an ``nn.Module`` instance.

        By default, the behavior is to check if the called module is a leaf module
        via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to
        ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through
        the operations in its ``forward`` function.

        This method can be overridden to--for example--create nested traced
        GraphModules, or any other behavior you would want while tracing across
        ``Module`` boundaries.

        Args:

            m (Module): The module for which a call is being emitted
            forward (Callable): The forward() method of the ``Module`` to be invoked
            args (Tuple): args of the module callsite
            kwargs (Dict): kwargs of the module callsite

        Return:

            The return value from the Module call. In the case that a ``call_module``
            node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
            value was returned from the ``Module`` invocation.
        """
        module_qualified_name = self.path_of_module(m)
        with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope:
            # module_stack is an ordered dict so writing then deleting the
            # entry is equivalent to push/pop on a list
            self.module_stack[_scope.module_path] = _scope.module_type
            if not self.is_leaf_module(m, module_qualified_name):
                ret_val = forward(*args, **kwargs)
            else:
                ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs)
            key, _ = self.module_stack.popitem(last=True)
            assert key == _scope.module_path, f" Unexpected key {key}"

        return ret_val

    @compatibility(is_backward_compatible=False)
    def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
        """
        Method that specifies the behavior of this ``Tracer`` when getattr is
        called on an ``nn.Module`` instance during tracing.

        By default, the behavior is to return a proxy value for the attribute. It
        also stores the proxy value in the ``parameter_proxy_cache``, so that future
        calls will reuse the proxy rather than creating a new one.

        This method can be overridden to --for example-- not return proxies when
        querying parameters.

        Args:

            attr (str): The name of the attribute being queried
            attr_val (Any): The value of the attribute
            parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies

        Return:

            The return value from the getattr call.
        """
        def maybe_get_proxy_for_attr(
            attr_val, collection_to_search, parameter_proxy_cache
        ):
            for n, p in collection_to_search:
                if attr_val is p:
                    if n not in parameter_proxy_cache:
                        kwargs = {}
                        if (
                            "proxy_factory_fn"
                            in inspect.signature(self.create_proxy).parameters
                        ):
                            kwargs["proxy_factory_fn"] = (
                                None
                                if not self.param_shapes_constant
                                else lambda node: ParameterProxy(
                                    self, node, n, attr_val
                                )
                            )
                        val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)  # type: ignore[arg-type]
                        parameter_proxy_cache[n] = val_proxy
                    return parameter_proxy_cache[n]
            return None

        if isinstance(attr_val, torch.nn.Parameter):
            maybe_parameter_proxy = maybe_get_proxy_for_attr(
                attr_val, self.root.named_parameters(), parameter_proxy_cache
            )
            if maybe_parameter_proxy is not None:
                return maybe_parameter_proxy

        if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
            maybe_buffer_proxy = maybe_get_proxy_for_attr(
                attr_val, self.root.named_buffers(), parameter_proxy_cache
            )
            if maybe_buffer_proxy is not None:
                return maybe_buffer_proxy

        return attr_val

    # This method will be refactored
    @compatibility(is_backward_compatible=False)
    def create_args_for_root(self, root_fn, is_module, concrete_args=None):
        """
        Create ``placeholder`` nodes corresponding to the signature of the ``root``
        Module. This method introspects root's signature and emits those
        nodes accordingly, also supporting ``*args`` and ``**kwargs``.
        """
        # In some cases, a function or method has been decorated with a wrapper
        # defined via ``functools.wraps``. In this case, the outer code object
        # will likely not contain the actual parameters we care about, so unwrap
        # the function to get to the innermost callable.
        fn_for_analysis = inspect.unwrap(root_fn)
        co = fn_for_analysis.__code__
        total_args = co.co_argcount + co.co_kwonlyargcount
        orig_args = list(co.co_varnames)
        names_iter = iter(co.co_varnames)
        args: List[Any] = []
        skip_arg_idx = 0
        if is_module:
            if total_args == 0:
                raise RuntimeError(
                    "``self`` argument cannot be part of *args expansion!"
                )
            skip_arg_idx = 1
            next(names_iter)  # skip self
            args.append(self.root)

        sig = inspect.signature(fn_for_analysis)

        def proxy_placeholder(name: str):
            if concrete_args is not None and name in concrete_args:
                cnt = 0

                def replace_ph(x):
                    nonlocal cnt
                    cnt += 1
                    param = sig.parameters[name]
                    default = (
                        ()
                        if param.default is inspect.Parameter.empty
                        else (param.default,)
                    )
                    out = self.create_proxy(
                        "placeholder", f"{name}_{str(cnt)}", default, {}
                    )
                    if x == PH:
                        return out
                    # Union[int, bool] == bool in Python <= 3.6
                    if (
                        type(x) == bool
                        or type(x) in base_types
                        and type(x) != torch.Tensor
                    ):
                        torch._assert(
                            out == x,
                            f"{name} has been specialized to have value {x} but got another value",
                        )
                    elif type(x) == type(None):
                        args = (
                            out,
                            f"{name} has been specialized to have value None but got another value",
                        )
                        self.create_proxy("call_function", _assert_is_none, args, {})
                    else:
                        warnings.warn(
                            f"Was not able to add assertion to guarantee correct input {name} to "
                            f"specialized function. It is up to the user to make sure that your inputs match the "
                            f"inputs you specialized the function with."
                        )

                    return x

                return pytree.tree_map(replace_ph, concrete_args[name])
            if name[0] == "*":
                default = ()
            else:
                param = sig.parameters[name]
                default = () if param.default is inspect.Parameter.empty else (param.default,)  # type: ignore[assignment]
            return self.create_proxy(
                "placeholder",
                name,
                default,
                {},
                type_expr=fn_for_analysis.__annotations__.get(name, None)
            )

        arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)]
        if isinstance(concrete_args, tuple):
            if len(arg_names) != len(concrete_args):
                raise RuntimeError(
                    f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments"
                )
            concrete_args = {name: val for name, val in zip(arg_names, concrete_args)}
        args.extend(proxy_placeholder(names) for names in arg_names)

        if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF:
            # TODO: type annotations for *args and **kwargs
            if co.co_flags & inspect.CO_VARARGS:
                args.append(proxy_placeholder("*" + next(names_iter)))
            if co.co_flags & inspect.CO_VARKEYWORDS:
                args.append(proxy_placeholder("**" + next(names_iter)))
            root_fn = _patch_function(root_fn, len(args))

        flat_args, in_spec = pytree.tree_flatten(tuple(args))
        if any(not isinstance(i, pytree.LeafSpec) for i in in_spec.children_specs):
            # In the case that we have pytree-flattened inputs in
            # `concrete_args`, generate a flattening wrapper around the
            # original root function and return that.
            self.graph._codegen = _PyTreeCodeGen(
                _PyTreeInfo(orig_args[:total_args], in_spec, None)
            )

            def flatten_fn(*args):
                tree_args = pytree.tree_unflatten(list(args), in_spec)
                tree_out = root_fn(*tree_args)
                out_args, out_spec = pytree.tree_flatten(tree_out)
                assert isinstance(self.graph._codegen, _PyTreeCodeGen)
                self.graph._codegen.pytree_info = (
                    self.graph._codegen.pytree_info._replace(out_spec=out_spec)
                )
                return out_args

            return flatten_fn, flat_args
        return root_fn, args

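    # Editorial note: for a variadic signature such as ``def f(x, *args, **kwargs)``,
    # this method emits placeholder nodes named ``x``, ``*args`` and ``**kwargs``
    # (the starred names come from the ``proxy_placeholder`` calls above), and
    # ``_patch_function`` rewrites ``f`` so those proxies can be passed to it as
    # ordinary positional arguments during tracing.
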
    @compatibility(is_backward_compatible=True)
    def trace(
        self,
        root: Union[torch.nn.Module, Callable[..., Any]],
        concrete_args: Optional[Dict[str, Any]] = None,
    ) -> Graph:
        """
        Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
        can either be an ``nn.Module`` instance or a Python callable.

        Note that after this call, ``self.root`` may be different from the ``root`` passed
        in here. For example, when a free function is passed to ``trace()``, we will
        create an ``nn.Module`` instance to use as the root and add embedded constants
        to it.

        Args:

            root (Union[Module, Callable]): Either a ``Module`` or a function to be
                traced through. Backwards-compatibility for this parameter is
                guaranteed.
            concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
                not be treated as Proxies. This parameter is experimental and
                its backwards-compatibility is *NOT* guaranteed.

        Returns:

            A ``Graph`` representing the semantics of the passed-in ``root``.
        """
        global _is_fx_tracing_flag
        old_is_fx_tracing_flag = _is_fx_tracing_flag
        _is_fx_tracing_flag = True
        try:
            if isinstance(root, torch.nn.Module):
                self.root = root

                assert hasattr(
                    type(root), self.traced_func_name
                ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"

                fn = getattr(type(root), self.traced_func_name)
                self.root_module_name = root._get_name()
                self.submodule_paths = {mod: name for name, mod in root.named_modules()}
            else:
                self.root = torch.nn.Module()
                fn = root

            tracer_cls: Optional[Type["Tracer"]] = getattr(self, "__class__", None)
            self.graph = Graph(tracer_cls=tracer_cls)

            # When we encounter a Tensor value that's not a parameter, we look if it
            # is some other attribute on the model. Construct a dict mapping Tensor
            # values to the qualified name here for efficiency. This is used downstream
            # in create_arg
            self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {}

            def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]):
                for k, v in m.__dict__.items():
                    if isinstance(v, (torch.Tensor, ScriptObject)):
                        self.tensor_attrs[v] = ".".join(prefix_atoms + [k])
                for k, v in m.named_children():
                    collect_tensor_attrs(v, prefix_atoms + [k])

            collect_tensor_attrs(self.root, [])

            assert isinstance(fn, FunctionType)

            fn_globals = fn.__globals__  # run before it gets patched
            fn, args = self.create_args_for_root(
                fn, isinstance(root, torch.nn.Module), concrete_args
            )

            parameter_proxy_cache: Dict[
                str, Proxy
            ] = {}  # Reduce number of get_attr calls

            # Method dispatch on parameters is not recorded unless it's directly used.
            # Thus, we need to insert a proxy when __getattr__ requests a parameter.
            @functools.wraps(_orig_module_getattr)
            def module_getattr_wrapper(mod, attr):
                attr_val = _orig_module_getattr(mod, attr)
                return self.getattr(attr, attr_val, parameter_proxy_cache)

            @functools.wraps(_orig_module_call)
            def module_call_wrapper(mod, *args, **kwargs):
                def forward(*args, **kwargs):
                    return _orig_module_call(mod, *args, **kwargs)

                _autowrap_check(
                    patcher,
                    getattr(getattr(mod, "forward", mod), "__globals__", {}),
                    self._autowrap_function_ids,
                )
                return self.call_module(mod, forward, args, kwargs)

            with _Patcher() as patcher:
                # allow duplicate patches to support the case of nested calls
                patcher.patch_method(
                    torch.nn.Module,
                    "__getattr__",
                    module_getattr_wrapper,
                    deduplicate=False,
                )
                patcher.patch_method(
                    torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False
                )
                _patch_wrapped_functions(patcher)
                _autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
                for module in self._autowrap_search:
                    _autowrap_check(
                        patcher, module.__dict__, self._autowrap_function_ids
                    )
                self.create_node(
                    "output",
                    "output",
                    (self.create_arg(fn(*args)),),
                    {},
                    type_expr=fn.__annotations__.get("return", None),
                )

            self.submodule_paths = None
        finally:
            _is_fx_tracing_flag = old_is_fx_tracing_flag
        return self.graph

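    # Hedged example of calling ``trace`` with ``concrete_args`` (mirrors the
    # ``symbolic_trace`` docstring further below; ``f`` is a made-up function):
    #
    #     def f(x, flag):
    #         return x if flag else x * 2
    #
    #     graph = Tracer().trace(f, concrete_args={"flag": False})
    #
    # ``flag`` is specialized to ``False`` during tracing, so the resulting
    # graph always follows the ``x * 2`` branch; values later passed for
    # ``flag`` are ignored, as the ``symbolic_trace`` docstring notes.
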
    def __deepcopy__(self, memo):
        # _autowrap_search contains modules, which cannot be deepcopied.
        new_tracer = Tracer.__new__(Tracer)

        for k, v in self.__dict__.items():
            if k in {'_autowrap_search'}:
                new_obj = copy.copy(v)
            else:
                new_obj = copy.deepcopy(v, memo)

            new_tracer.__dict__[k] = new_obj

        return new_tracer

# List of pairs of (global dict, function name) functions
# to patch for the purposes of the wrap() API.
_wrapped_fns_to_patch: List[Tuple[dict, str]] = []

# List of methods on classes to wrap (class type, function name)
# this currently only works for Tensor.* methods that aren't traced properly
_wrapped_methods_to_patch: List[Tuple[type, str]] = []

if os.environ.get("FX_PATCH_GETITEM") == "1":
    # This change is needed to trace models like PositionalEmbedding from BERT:
    # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py
    # but causes issues in quantization documented here:
    # https://github.com/pytorch/pytorch/issues/50710
    # once that is fixed we can make this the default behavior.
    _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))

def _find_proxy(*objects_to_search):
    """
    Recursively search a data structure for a Proxy() and return it,
    return None if not found.
    """
    proxy = None

    def find_proxy(x):
        nonlocal proxy
        if isinstance(x, Proxy):
            proxy = x

    map_aggregate(objects_to_search, find_proxy)
    return proxy

def _create_wrapped_func(orig_fn):
    @functools.wraps(orig_fn)
    def wrapped(*args, **kwargs):
        """
        Given a closed-over ``orig_fn`` to invoke, search the args and kwargs for
        a Proxy object. If there is one, emit a ``call_function`` node to preserve the
        call to this leaf function directly. Otherwise, just return the results of
        this function call, as this function is not being traced.
        """
        proxy = _find_proxy(args, kwargs)
        if proxy is not None:
            return_proxy = proxy.tracer.create_proxy(
                "call_function", orig_fn, args, kwargs
            )
            return_proxy.node.meta["is_wrapped"] = True
            return return_proxy
        return orig_fn(*args, **kwargs)

    return wrapped

def _create_wrapped_method(cls, name):
    orig_fn = getattr(cls, name)

    @functools.wraps(orig_fn)
    def wrapped(*args, **kwargs):
        """
        Search the args and kwargs for a Proxy object. If there is one,
        emit a ``call_method`` node to preserve the call to this method
        directly. Otherwise, just return the results of this function
        call, as this function is not being traced.
        """
        proxy = _find_proxy(args, kwargs)
        if proxy is not None:
            return proxy.tracer.create_proxy("call_method", name, args, kwargs)
        return orig_fn(*args, **kwargs)

    return wrapped

class _PatchedFn(NamedTuple):
    frame_dict: Any
    fn_name: str
    orig_fn: Any

    def revert(self):
        raise NotImplementedError()


class _PatchedFnSetItem(_PatchedFn):
    def revert(self):
        self.frame_dict[self.fn_name] = self.orig_fn


class _PatchedFnDel(_PatchedFn):
    def revert(self):
        del self.frame_dict[self.fn_name]


class _PatchedFnSetAttr(_PatchedFn):
    def revert(self):
        setattr(self.frame_dict, self.fn_name, self.orig_fn)

class _Patcher:
    def __init__(self):
        super().__init__()
        self.patches_made: List[_PatchedFn] = []
        self.visited: Set[int] = set()

    def patch(
        self,
        frame_dict: Dict[str, Any],
        name: str,
        new_fn: Callable,
        deduplicate: bool = True,
    ):
        """
        Replace frame_dict[name] with new_fn until we exit the context manager.
        """
        new_fn.__fx_already_patched = deduplicate  # type: ignore[attr-defined]
        if name not in frame_dict and hasattr(builtins, name):
            self.patches_made.append(_PatchedFnDel(frame_dict, name, None))
        elif getattr(frame_dict[name], "__fx_already_patched", False):
            return  # already patched, no need to do it again
        else:
            self.patches_made.append(
                _PatchedFnSetItem(frame_dict, name, frame_dict[name])
            )
        frame_dict[name] = new_fn

    def patch_method(
        self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True
    ):
        """
        Replace cls.name with new_fn until we exit the context manager.
        """
        new_fn.__fx_already_patched = deduplicate  # type: ignore[attr-defined]
        orig_fn = getattr(cls, name)
        if getattr(orig_fn, "__fx_already_patched", False):
            return  # already patched, no need to do it again
        self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn))
        setattr(cls, name, new_fn)

    def visit_once(self, thing: Any):
        """Return True on the first call with thing, otherwise false"""
        idx = id(thing)
        if idx in self.visited:
            return False
        self.visited.add(idx)
        return True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Undo all the changes made via self.patch() and self.patch_method()
        """
        while self.patches_made:
            # unpatch in reverse order to handle duplicates correctly
            self.patches_made.pop().revert()
        self.visited.clear()

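# Internal usage sketch for ``_Patcher`` (this mirrors how ``Tracer.trace``
# uses it above; ``wrapped_call`` is a stand-in name): patches are applied for
# the duration of the ``with`` block and reverted in reverse order on exit.
#
#     with _Patcher() as patcher:
#         patcher.patch_method(torch.nn.Module, "__call__", wrapped_call, deduplicate=False)
#         ...  # run tracing while the patch is active
#     # __exit__ restores torch.nn.Module.__call__
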
def _patch_wrapped_functions(patcher: _Patcher):
    """
    Go through ``_wrapped_fns_to_patch`` and, for each frame object, wrap
    the listed global functions in the `_create_wrapped_func` wrapper.
    """
    for frame_dict, name in _wrapped_fns_to_patch:
        if name not in frame_dict and hasattr(builtins, name):
            orig_fn = getattr(builtins, name)
        else:
            orig_fn = frame_dict[name]
        patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn))

    for cls, name in _wrapped_methods_to_patch:
        patcher.patch_method(cls, name, _create_wrapped_method(cls, name))

def _autowrap_check(
    patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int]
):
    """
    Some functions, like `math.sqrt`, are common enough that we want to automatically wrap them as we see them.
    This method searches a scope for them and patches them if found.
    """
    if patcher.visit_once(frame_dict):
        for name, value in frame_dict.items():
            if (
                not name.startswith("_")
                and callable(value)
                and id(value) in function_ids
            ):
                patcher.patch(frame_dict, name, _create_wrapped_func(value))

@compatibility(is_backward_compatible=True)
def wrap(fn_or_name: Union[str, Callable]):
    """
    This function can be called at module-level scope to register fn_or_name as a "leaf function".
    A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being
    traced through::

        # foo/bar/baz.py
        def my_custom_function(x, y):
            return x * x + y * y

        torch.fx.wrap('my_custom_function')

        def fn_to_be_traced(x, y):
            # When symbolic tracing, the below call to my_custom_function will be inserted into
            # the graph rather than tracing it.
            return my_custom_function(x, y)

    This function can also equivalently be used as a decorator::

        # foo/bar/baz.py
        @torch.fx.wrap
        def my_custom_function(x, y):
            return x * x + y * y

    A wrapped function can be thought of as a "leaf function", analogous to the concept of
    "leaf modules", that is, they are functions that are left as calls in the FX trace
    rather than traced through.

    Args:

        fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the
            graph when it's called
    """
    if not callable(fn_or_name) and not isinstance(fn_or_name, str):
        raise RuntimeError(
            "Unsupported type for global function! Must be either a callable or "
            "string name"
        )

    if callable(fn_or_name):
        assert not isinstance(fn_or_name, str)  # to make mypy happy
        fn_name = fn_or_name.__name__
    else:
        assert isinstance(
            fn_or_name, str
        ), "fn_or_name must be a global function or string name"
        fn_name = fn_or_name

    currentframe = inspect.currentframe()
    assert currentframe is not None
    f = currentframe.f_back
    assert f is not None
    if f.f_code.co_name != "<module>":
        raise NotImplementedError("wrap must be called at the top level of a module")

    # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search
    # semantics would be slightly different, but would add support `from x import wrapped_function`
    _wrapped_fns_to_patch.append((f.f_globals, fn_name))
    return fn_or_name

@compatibility(is_backward_compatible=True)
def symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
) -> GraphModule:
    """
    Symbolic tracing API

    Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
    constructed by recording operations seen while tracing through ``root``.

    ``concrete_args`` allows you to partially specialize your function, whether it's to remove control flow or data structures.

    For example::

        def f(a, b):
            if b == True:
                return a
            else:
                return a*2

    FX can typically not trace through this due to the presence of control
    flow. However, we can use `concrete_args` to specialize on the value of
    `b` to trace through this::

        f = fx.symbolic_trace(f, concrete_args={'b': False})
        assert f(3, False) == 6

    Note that although you can still pass in different values of `b`, they will be ignored.

    We can also use `concrete_args` to eliminate data-structure handling from
    our function. This will use pytrees to flatten your input. To avoid
    overspecializing, pass in `fx.PH` for values that shouldn't be
    specialized. For example::

        def f(x):
            out = 0
            for v in x.values():
                out += v
            return out

        f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
        assert f({'a': 1, 'b': 2, 'c': 4}) == 7

    Args:
        root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted
            into a Graph representation.
        concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized

    Returns:
        GraphModule: a Module created from the recorded operations from ``root``.
    """
    tracer = Tracer()
    graph = tracer.trace(root, concrete_args)
    name = (
        root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    )
    return GraphModule(tracer.root, graph, name)

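# Minimal end-to-end sketch for ``symbolic_trace`` on an ``nn.Module``
# (editorial example; ``M`` is a made-up module):
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             return torch.relu(x) + 1
#
#     gm = symbolic_trace(M())
#     print(gm.graph)  # placeholder -> call_function(relu) -> add -> output
#     print(gm.code)   # generated Python source for the traced forward
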
@wrap
def _assert_is_none(value, msg):
    assert value is None, msg