interpreter.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. from .graph_module import GraphModule
  2. from .graph import Graph
  3. from .node import Argument, Node, Target, map_arg, map_aggregate
  4. from .proxy import Proxy
  5. from ._symbolic_trace import Tracer
  6. from ._compatibility import compatibility
  7. from . import config
  8. import torch.fx.traceback as fx_traceback
  9. from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
  10. import inspect
  11. from contextlib import contextmanager
  12. from torch.hub import tqdm
  13. __all__ = ['Interpreter', 'Transformer']
  14. @compatibility(is_backward_compatible=True)
  15. class Interpreter:
  16. """
  17. An Interpreter executes an FX graph Node-by-Node. This pattern
  18. can be useful for many things, including writing code
  19. transformations as well as analysis passes.
  20. Methods in the Interpreter class can be overridden to customize
  21. the behavior of execution. The map of overrideable methods
  22. in terms of call hierarchy::
  23. run()
  24. +-- run_node
  25. +-- placeholder()
  26. +-- get_attr()
  27. +-- call_function()
  28. +-- call_method()
  29. +-- call_module()
  30. +-- output()
  31. Example:
  32. Suppose we want to swap all instances of ``torch.neg`` with
  33. ``torch.sigmoid`` and vice versa (including their ``Tensor``
  34. method equivalents). We could subclass Interpreter like so::
  35. class NegSigmSwapInterpreter(Interpreter):
  36. def call_function(self, target : Target,
  37. args : Tuple, kwargs : Dict) -> Any:
  38. if target == torch.sigmoid:
  39. return torch.neg(*args, **kwargs)
  40. return super().call_function(n)
  41. def call_method(self, target : Target,
  42. args : Tuple, kwargs : Dict) -> Any:
  43. if target == 'neg':
  44. call_self, *args_tail = args
  45. return call_self.sigmoid(*args_tail, **kwargs)
  46. return super().call_method(n)
  47. def fn(x):
  48. return torch.sigmoid(x).neg()
  49. gm = torch.fx.symbolic_trace(fn)
  50. input = torch.randn(3, 4)
  51. result = NegSigmSwapInterpreter(gm).run(input)
  52. torch.testing.assert_close(result, torch.neg(input).sigmoid())
  53. Args:
  54. module (GraphModule): The module to be executed
  55. garbage_collect_values (bool): Whether to delete values after their last
  56. use within the Module's execution. This ensures optimal memory usage during
  57. execution. This can be disabled to, for example, examine all of the intermediate
  58. values in the execution by looking at the ``Interpreter.env`` attribute.
  59. """
  60. @compatibility(is_backward_compatible=True)
  61. def __init__(self, module : GraphModule, garbage_collect_values : bool = True):
  62. assert isinstance(module, GraphModule)
  63. self.module = module
  64. self.submodules = dict(self.module.named_modules())
  65. self.env : Dict[Node, Any] = {}
  66. self.name = "Interpreter"
  67. self.garbage_collect_values = garbage_collect_values
  68. if self.garbage_collect_values:
  69. # Run through reverse nodes and record the first instance of a use
  70. # of a given node. This represents the *last* use of the node in the
  71. # execution order of the program, which we will use to free unused
  72. # values
  73. node_to_last_use : Dict[Node, Node] = {}
  74. self.user_to_last_uses : Dict[Node, List[Node]] = {}
  75. def register_last_uses(n : Node, user : Node):
  76. if n not in node_to_last_use:
  77. node_to_last_use[n] = user
  78. self.user_to_last_uses.setdefault(user, []).append(n)
  79. for node in reversed(self.module.graph.nodes):
  80. map_arg(node.args, lambda n: register_last_uses(n, node))
  81. map_arg(node.kwargs, lambda n: register_last_uses(n, node))
  82. @compatibility(is_backward_compatible=True)
  83. def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any:
  84. """
  85. Run `module` via interpretation and return the result.
  86. Args:
  87. *args: The arguments to the Module to run, in positional order
  88. initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
  89. This is a dict mapping `Node` to any value. This can be used, for example, to
  90. pre-populate results for certain `Nodes` so as to do only partial evaluation within
  91. the interpreter.
  92. enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and
  93. process_outputs function first before using them.
  94. Returns:
  95. Any: The value returned from executing the Module
  96. """
  97. self.env = initial_env if initial_env is not None else {}
  98. # Positional function args are consumed left-to-right by
  99. # `placeholder` nodes. Use an iterator to keep track of
  100. # position and extract those values.
  101. if enable_io_processing:
  102. args = self.module.graph.process_inputs(*args)
  103. self.args_iter : Iterator[Any] = iter(args)
  104. pbar = tqdm(total=len(self.module.graph.nodes),
  105. desc=f"{self.name}: {str(list(self.module.graph.nodes)) if config.verbose_progress else ''}",
  106. initial=0, position=0, leave=True, disable=config.disable_progress, delay=0)
  107. for node in self.module.graph.nodes:
  108. pbar.update(1)
  109. if node in self.env:
  110. # Short circuit if we have this value. This could
  111. # be used, for example, for partial evaluation
  112. # where the caller has pre-populated `env` with
  113. # values for a subset of the program.
  114. continue
  115. try:
  116. self.env[node] = self.run_node(node)
  117. except Exception as e:
  118. msg = f"While executing {node.format_node()}"
  119. msg = '{}\n\n{}'.format(e.args[0], msg) if e.args else str(msg)
  120. msg += f"\nOriginal traceback:\n{node.stack_trace}"
  121. e.args = (msg,) + e.args[1:]
  122. if isinstance(e, KeyError):
  123. raise RuntimeError(*e.args) from e
  124. raise
  125. if self.garbage_collect_values:
  126. for to_delete in self.user_to_last_uses.get(node, []):
  127. del self.env[to_delete]
  128. if node.op == 'output':
  129. output_val = self.env[node]
  130. return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val
  131. @contextmanager
  132. def _set_current_node(self, node):
  133. with fx_traceback.set_current_meta(node.meta):
  134. yield
  135. @compatibility(is_backward_compatible=True)
  136. def run_node(self, n : Node) -> Any:
  137. """
  138. Run a specific node ``n`` and return the result.
  139. Calls into placeholder, get_attr, call_function,
  140. call_method, call_module, or output depending
  141. on ``node.op``
  142. Args:
  143. n (Node): The Node to execute
  144. Returns:
  145. Any: The result of executing ``n``
  146. """
  147. with self._set_current_node(n):
  148. args, kwargs = self.fetch_args_kwargs_from_env(n)
  149. assert isinstance(args, tuple)
  150. assert isinstance(kwargs, dict)
  151. return getattr(self, n.op)(n.target, args, kwargs)
  152. # Main Node running APIs
  153. @compatibility(is_backward_compatible=True)
  154. def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  155. """
  156. Execute a ``placeholder`` node. Note that this is stateful:
  157. ``Interpreter`` maintains an internal iterator over
  158. arguments passed to ``run`` and this method returns
  159. next() on that iterator.
  160. Args:
  161. target (Target): The call target for this node. See
  162. `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
  163. details on semantics
  164. args (Tuple): Tuple of positional args for this invocation
  165. kwargs (Dict): Dict of keyword arguments for this invocation
  166. Returns:
  167. Any: The argument value that was retrieved.
  168. """
  169. assert isinstance(target, str)
  170. if target.startswith('*'):
  171. # For a starred parameter e.g. `*args`, retrieve all
  172. # remaining values from the args list.
  173. return list(self.args_iter)
  174. else:
  175. try:
  176. return next(self.args_iter)
  177. except StopIteration as si:
  178. if len(args) > 0:
  179. return args[0]
  180. else:
  181. raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si
  182. @compatibility(is_backward_compatible=True)
  183. def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  184. """
  185. Execute a ``get_attr`` node. Will retrieve an attribute
  186. value from the ``Module`` hierarchy of ``self.module``.
  187. Args:
  188. target (Target): The call target for this node. See
  189. `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
  190. details on semantics
  191. args (Tuple): Tuple of positional args for this invocation
  192. kwargs (Dict): Dict of keyword arguments for this invocation
  193. Return:
  194. Any: The value of the attribute that was retrieved
  195. """
  196. assert isinstance(target, str)
  197. return self.fetch_attr(target)
  198. @compatibility(is_backward_compatible=True)
  199. def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  200. """
  201. Execute a ``call_function`` node and return the result.
  202. Args:
  203. target (Target): The call target for this node. See
  204. `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
  205. details on semantics
  206. args (Tuple): Tuple of positional args for this invocation
  207. kwargs (Dict): Dict of keyword arguments for this invocation
  208. Return
  209. Any: The value returned by the function invocation
  210. """
  211. assert not isinstance(target, str)
  212. # Execute the function and return the result
  213. return target(*args, **kwargs)
  214. @compatibility(is_backward_compatible=True)
  215. def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  216. """
  217. Execute a ``call_method`` node and return the result.
  218. Args:
  219. target (Target): The call target for this node. See
  220. `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
  221. details on semantics
  222. args (Tuple): Tuple of positional args for this invocation
  223. kwargs (Dict): Dict of keyword arguments for this invocation
  224. Return
  225. Any: The value returned by the method invocation
  226. """
  227. # args[0] is the `self` object for this method call
  228. self_obj, *args_tail = args
  229. # Execute the method and return the result
  230. assert isinstance(target, str)
  231. return getattr(self_obj, target)(*args_tail, **kwargs)
  232. @compatibility(is_backward_compatible=True)
  233. def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  234. """
  235. Execute a ``call_module`` node and return the result.
  236. Args:
  237. target (Target): The call target for this node. See
  238. `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
  239. details on semantics
  240. args (Tuple): Tuple of positional args for this invocation
  241. kwargs (Dict): Dict of keyword arguments for this invocation
  242. Return
  243. Any: The value returned by the module invocation
  244. """
  245. # Retrieve executed args and kwargs values from the environment
  246. # Execute the method and return the result
  247. assert isinstance(target, str)
  248. submod = self.fetch_attr(target)
  249. return submod(*args, **kwargs)
  250. @compatibility(is_backward_compatible=True)
  251. def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  252. """
  253. Execute an ``output`` node. This really just retrieves
  254. the value referenced by the ``output`` node and returns it.
  255. Args:
  256. target (Target): The call target for this node. See
  257. `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
  258. details on semantics
  259. args (Tuple): Tuple of positional args for this invocation
  260. kwargs (Dict): Dict of keyword arguments for this invocation
  261. Return:
  262. Any: The return value referenced by the output node
  263. """
  264. return args[0]
  265. # Helper methods
  266. @compatibility(is_backward_compatible=True)
  267. def fetch_attr(self, target : str):
  268. """
  269. Fetch an attribute from the ``Module`` hierarchy of ``self.module``.
  270. Args:
  271. target (str): The fully-qualified name of the attribute to fetch
  272. Return:
  273. Any: The value of the attribute.
  274. """
  275. target_atoms = target.split('.')
  276. attr_itr = self.module
  277. for i, atom in enumerate(target_atoms):
  278. if not hasattr(attr_itr, atom):
  279. raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
  280. attr_itr = getattr(attr_itr, atom)
  281. return attr_itr
  282. @compatibility(is_backward_compatible=True)
  283. def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]:
  284. """
  285. Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
  286. from the current execution environment.
  287. Args:
  288. n (Node): The node for which ``args`` and ``kwargs`` should be fetched.
  289. Return:
  290. Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``.
  291. """
  292. args = self.map_nodes_to_values(n.args, n)
  293. assert isinstance(args, tuple)
  294. kwargs = self.map_nodes_to_values(n.kwargs, n)
  295. assert isinstance(kwargs, dict)
  296. return args, kwargs
  297. @compatibility(is_backward_compatible=True)
  298. def map_nodes_to_values(self, args : Argument, n : Node) -> Argument:
  299. """
  300. Recursively descend through ``args`` and look up the concrete value
  301. for each ``Node`` in the current execution environment.
  302. Args:
  303. args (Argument): Data structure within which to look up concrete values
  304. n (Node): Node to which ``args`` belongs. This is only used for error reporting.
  305. """
  306. def load_arg(n_arg : Node) -> Any:
  307. if n_arg not in self.env:
  308. raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() '
  309. f'to diagnose such issues')
  310. return self.env[n_arg]
  311. return map_arg(args, load_arg)
  312. @compatibility(is_backward_compatible=True)
  313. class Transformer(Interpreter):
  314. """
  315. ``Transformer`` is a special type of interpreter that produces a
  316. new ``Module``. It exposes a ``transform()`` method that returns
  317. the transformed ``Module``. ``Transformer`` does not require
  318. arguments to run, as ``Interpreter`` does. ``Transformer`` works
  319. entirely symbolically.
  320. Example:
  321. Suppose we want to swap all instances of ``torch.neg`` with
  322. ``torch.sigmoid`` and vice versa (including their ``Tensor``
  323. method equivalents). We could subclass ``Transformer`` like so::
  324. class NegSigmSwapXformer(Transformer):
  325. def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  326. if target == torch.sigmoid:
  327. return torch.neg(*args, **kwargs)
  328. return super().call_function(n)
  329. def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  330. if target == 'neg':
  331. call_self, *args_tail = args
  332. return call_self.sigmoid(*args_tail, **kwargs)
  333. return super().call_method(n)
  334. def fn(x):
  335. return torch.sigmoid(x).neg()
  336. gm = torch.fx.symbolic_trace(fn)
  337. transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
  338. input = torch.randn(3, 4)
  339. torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
  340. Args:
  341. module (GraphModule): The ``Module`` to be transformed.
  342. """
  343. @compatibility(is_backward_compatible=True)
  344. def __init__(self, module):
  345. super().__init__(module)
  346. self.new_graph = Graph()
  347. self.new_graph.set_codegen(module.graph._codegen)
  348. class TransformerTracer(Tracer):
  349. def __init__(self, graph: Graph):
  350. super().__init__()
  351. self.graph = graph
  352. def is_leaf_module(self, _, __) -> bool:
  353. return True
  354. self.tracer = TransformerTracer(self.new_graph)
  355. self.tracer.root = module
  356. @compatibility(is_backward_compatible=True)
  357. def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
  358. """
  359. Execute a ``placeholder`` node. In ``Transformer``, this is
  360. overridden to insert a new ``placeholder`` into the output
  361. graph.
  362. Args:
  363. target (Target): The call target for this node. See
  364. `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
  365. details on semantics
  366. args (Tuple): Tuple of positional args for this invocation
  367. kwargs (Dict): Dict of keyword arguments for this invocation
  368. """
  369. assert isinstance(target, str)
  370. default_value = next(iter(args)) if args else inspect.Signature.empty
  371. return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer)
  372. @compatibility(is_backward_compatible=True)
  373. def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
  374. """
  375. Execute a ``get_attr`` node. In ``Transformer``, this is
  376. overridden to insert a new ``get_attr`` node into the output
  377. graph.
  378. Args:
  379. target (Target): The call target for this node. See
  380. `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
  381. details on semantics
  382. args (Tuple): Tuple of positional args for this invocation
  383. kwargs (Dict): Dict of keyword arguments for this invocation
  384. """
  385. assert isinstance(target, str)
  386. return Proxy(self.new_graph.get_attr(target), self.tracer)
  387. @compatibility(is_backward_compatible=True)
  388. def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  389. # Override so that the leaf module policy from `self.tracer` is respected.
  390. assert isinstance(target, str)
  391. submod = self.fetch_attr(target)
  392. return self.tracer.call_module(submod, submod.forward, args, kwargs)
  393. @compatibility(is_backward_compatible=True)
  394. def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  395. # Override so that functions that were wrapped are still wrapped.
  396. return self.tracer.create_proxy('call_function', target, args, kwargs)
  397. @compatibility(is_backward_compatible=True)
  398. def transform(self) -> GraphModule:
  399. """
  400. Transform ``self.module`` and return the transformed
  401. ``GraphModule``.
  402. """
  403. with fx_traceback.preserve_node_meta():
  404. result = super().run(enable_io_processing=False)
  405. if result is not None:
  406. def strip_proxy(a : Union[Argument, Proxy]) -> Any:
  407. return a.node if isinstance(a, Proxy) else a
  408. self.new_graph.output(map_aggregate(result, strip_proxy))
  409. return GraphModule(self.module, self.new_graph)