123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486 |
- from .graph_module import GraphModule
- from .graph import Graph
- from .node import Argument, Node, Target, map_arg, map_aggregate
- from .proxy import Proxy
- from ._symbolic_trace import Tracer
- from ._compatibility import compatibility
- from . import config
- import torch.fx.traceback as fx_traceback
- from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
- import inspect
- from contextlib import contextmanager
- from torch.hub import tqdm
- __all__ = ['Interpreter', 'Transformer']
@compatibility(is_backward_compatible=True)
class Interpreter:
    """
    An Interpreter executes an FX graph Node-by-Node. This pattern
    can be useful for many things, including writing code
    transformations as well as analysis passes.

    Methods in the Interpreter class can be overridden to customize
    the behavior of execution. The map of overrideable methods
    in terms of call hierarchy::

        run()
            +-- run_node
                +-- placeholder()
                +-- get_attr()
                +-- call_function()
                +-- call_method()
                +-- call_module()
                +-- output()

    Example:

        Suppose we want to swap all instances of ``torch.neg`` with
        ``torch.sigmoid`` and vice versa (including their ``Tensor``
        method equivalents). We could subclass Interpreter like so::

            class NegSigmSwapInterpreter(Interpreter):
                def call_function(self, target : Target,
                                  args : Tuple, kwargs : Dict) -> Any:
                    if target == torch.sigmoid:
                        return torch.neg(*args, **kwargs)
                    return super().call_function(target, args, kwargs)

                def call_method(self, target : Target,
                                args : Tuple, kwargs : Dict) -> Any:
                    if target == 'neg':
                        call_self, *args_tail = args
                        return call_self.sigmoid(*args_tail, **kwargs)
                    return super().call_method(target, args, kwargs)

            def fn(x):
                return torch.sigmoid(x).neg()

            gm = torch.fx.symbolic_trace(fn)
            input = torch.randn(3, 4)
            result = NegSigmSwapInterpreter(gm).run(input)
            torch.testing.assert_close(result, torch.neg(input).sigmoid())

    Args:
        module (GraphModule): The module to be executed
        garbage_collect_values (bool): Whether to delete values after their last
            use within the Module's execution. This ensures optimal memory usage during
            execution. This can be disabled to, for example, examine all of the intermediate
            values in the execution by looking at the ``Interpreter.env`` attribute.
    """
    @compatibility(is_backward_compatible=True)
    def __init__(self, module : GraphModule, garbage_collect_values : bool = True):
        """
        Bind the interpreter to ``module`` and, when garbage collection is
        enabled, precompute which values can be freed after each node runs.
        """
        assert isinstance(module, GraphModule)
        self.module = module
        self.submodules = dict(self.module.named_modules())
        self.env : Dict[Node, Any] = {}
        self.name = "Interpreter"
        self.garbage_collect_values = garbage_collect_values

        if self.garbage_collect_values:
            # Run through reverse nodes and record the first instance of a use
            # of a given node. This represents the *last* use of the node in the
            # execution order of the program, which we will use to free unused
            # values
            node_to_last_use : Dict[Node, Node] = {}
            self.user_to_last_uses : Dict[Node, List[Node]] = {}

            def register_last_uses(n : Node, user : Node):
                # Only the first sighting (in reverse order) counts: that is
                # the final consumer of `n` in forward execution order.
                if n not in node_to_last_use:
                    node_to_last_use[n] = user
                    self.user_to_last_uses.setdefault(user, []).append(n)

            for node in reversed(self.module.graph.nodes):
                # The lambdas are invoked immediately by `map_arg`, so closing
                # over the loop variable `node` is safe here.
                map_arg(node.args, lambda n: register_last_uses(n, node))
                map_arg(node.kwargs, lambda n: register_last_uses(n, node))

    @compatibility(is_backward_compatible=True)
    def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any:
        """
        Run `module` via interpretation and return the result.

        Args:
            *args: The arguments to the Module to run, in positional order
            initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
                This is a dict mapping `Node` to any value. This can be used, for example, to
                pre-populate results for certain `Nodes` so as to do only partial evaluation within
                the interpreter. Note that the dict is used directly as ``self.env`` (it is not
                copied), so it is updated in place during execution.
            enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and
                process_outputs function first before using them.

        Returns:
            Any: The value returned from executing the Module. ``None`` is returned
            if the graph contains no ``output`` node.

        Raises:
            RuntimeError: If executing a node raises ``KeyError``; the original error is chained.
            Exception: Any other exception raised while executing a node is re-raised with
                the failing node and its original stack trace added to the message.
        """
        self.env = initial_env if initial_env is not None else {}

        # Positional function args are consumed left-to-right by
        # `placeholder` nodes. Use an iterator to keep track of
        # position and extract those values.
        if enable_io_processing:
            args = self.module.graph.process_inputs(*args)
        self.args_iter : Iterator[Any] = iter(args)
        pbar = tqdm(total=len(self.module.graph.nodes),
                    desc=f"{self.name}: {str(list(self.module.graph.nodes)) if config.verbose_progress else ''}",
                    initial=0, position=0, leave=True, disable=config.disable_progress, delay=0)

        for node in self.module.graph.nodes:
            pbar.update(1)
            if node in self.env:
                # Short circuit if we have this value. This could
                # be used, for example, for partial evaluation
                # where the caller has pre-populated `env` with
                # values for a subset of the program.
                continue

            try:
                self.env[node] = self.run_node(node)
            except Exception as e:
                # Prefix the original message with the failing node so the
                # error can be traced back to the graph and the user's code.
                msg = f"While executing {node.format_node()}"
                msg = '{}\n\n{}'.format(e.args[0], msg) if e.args else str(msg)
                msg += f"\nOriginal traceback:\n{node.stack_trace}"
                e.args = (msg,) + e.args[1:]
                if isinstance(e, KeyError):
                    # Surface KeyError as RuntimeError, keeping the original as the cause.
                    raise RuntimeError(*e.args) from e
                raise

            if self.garbage_collect_values:
                # Drop every value whose final consumer was the node just executed.
                for to_delete in self.user_to_last_uses.get(node, []):
                    del self.env[to_delete]

            if node.op == 'output':
                output_val = self.env[node]
                return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val

    @contextmanager
    def _set_current_node(self, node):
        """
        Context manager that publishes ``node.meta`` as the current FX
        traceback metadata for the duration of the ``with`` block.
        """
        with fx_traceback.set_current_meta(node.meta):
            yield

    @compatibility(is_backward_compatible=True)
    def run_node(self, n : Node) -> Any:
        """
        Run a specific node ``n`` and return the result.
        Calls into placeholder, get_attr, call_function,
        call_method, call_module, or output depending
        on ``node.op``

        Args:
            n (Node): The Node to execute

        Returns:
            Any: The result of executing ``n``
        """
        with self._set_current_node(n):
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            assert isinstance(args, tuple)
            assert isinstance(kwargs, dict)
            # Dispatch on the opcode: each op name is also a method name.
            return getattr(self, n.op)(n.target, args, kwargs)

    # Main Node running APIs
    @compatibility(is_backward_compatible=True)
    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``placeholder`` node. Note that this is stateful:
        ``Interpreter`` maintains an internal iterator over
        arguments passed to ``run`` and this method returns
        next() on that iterator.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The argument value that was retrieved.

        Raises:
            RuntimeError: If the run arguments are exhausted and the node
                carries no default value in ``args``.
        """
        assert isinstance(target, str)
        if target.startswith('*'):
            # For a starred parameter e.g. `*args`, retrieve all
            # remaining values from the args list.
            return list(self.args_iter)
        else:
            try:
                return next(self.args_iter)
            except StopIteration as si:
                # No run-time argument left: fall back to the node's default
                # value, which is stored as its first positional arg.
                if len(args) > 0:
                    return args[0]
                else:
                    raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si

    @compatibility(is_backward_compatible=True)
    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``get_attr`` node. Will retrieve an attribute
        value from the ``Module`` hierarchy of ``self.module``.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The value of the attribute that was retrieved
        """
        assert isinstance(target, str)
        return self.fetch_attr(target)

    @compatibility(is_backward_compatible=True)
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_function`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The value returned by the function invocation
        """
        assert not isinstance(target, str)

        # Execute the function and return the result
        return target(*args, **kwargs)

    @compatibility(is_backward_compatible=True)
    def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_method`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The value returned by the method invocation
        """
        # args[0] is the `self` object for this method call
        self_obj, *args_tail = args

        # Execute the method and return the result
        assert isinstance(target, str)
        return getattr(self_obj, target)(*args_tail, **kwargs)

    @compatibility(is_backward_compatible=True)
    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_module`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The value returned by the module invocation
        """
        # Look up the submodule by its qualified name, then invoke it with
        # the already-resolved args and kwargs.
        assert isinstance(target, str)
        submod = self.fetch_attr(target)

        return submod(*args, **kwargs)

    @compatibility(is_backward_compatible=True)
    def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute an ``output`` node. This really just retrieves
        the value referenced by the ``output`` node and returns it.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The return value referenced by the output node
        """
        return args[0]

    # Helper methods
    @compatibility(is_backward_compatible=True)
    def fetch_attr(self, target : str):
        """
        Fetch an attribute from the ``Module`` hierarchy of ``self.module``.

        Args:
            target (str): The fully-qualified name of the attribute to fetch

        Returns:
            Any: The value of the attribute.

        Raises:
            RuntimeError: If any component of ``target`` does not exist. The
                message names the qualified path up to and including the
                first missing component.
        """
        target_atoms = target.split('.')
        attr_itr = self.module
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                # Slice through `i + 1` so the missing atom itself appears in
                # the message; `[:i]` would stop one short and report an
                # empty name for a missing top-level attribute.
                raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    @compatibility(is_backward_compatible=True)
    def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]:
        """
        Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
        from the current execution environment.

        Args:
            n (Node): The node for which ``args`` and ``kwargs`` should be fetched.

        Returns:
            Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``.
        """
        args = self.map_nodes_to_values(n.args, n)
        assert isinstance(args, tuple)
        kwargs = self.map_nodes_to_values(n.kwargs, n)
        assert isinstance(kwargs, dict)
        return args, kwargs

    @compatibility(is_backward_compatible=True)
    def map_nodes_to_values(self, args : Argument, n : Node) -> Argument:
        """
        Recursively descend through ``args`` and look up the concrete value
        for each ``Node`` in the current execution environment.

        Args:
            args (Argument): Data structure within which to look up concrete values

            n (Node): Node to which ``args`` belongs. This is only used for error reporting.

        Returns:
            Argument: ``args`` with every ``Node`` replaced by its value from ``self.env``.

        Raises:
            RuntimeError: If a referenced ``Node`` has no value in ``self.env``.
        """
        def load_arg(n_arg : Node) -> Any:
            if n_arg not in self.env:
                raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() '
                                   f'to diagnose such issues')
            return self.env[n_arg]
        return map_arg(args, load_arg)
@compatibility(is_backward_compatible=True)
class Transformer(Interpreter):
    """
    ``Transformer`` is a special type of interpreter that produces a
    new ``Module``. It exposes a ``transform()`` method that returns
    the transformed ``Module``. ``Transformer`` does not require
    arguments to run, as ``Interpreter`` does. ``Transformer`` works
    entirely symbolically.

    Example:

        Suppose we want to swap all instances of ``torch.neg`` with
        ``torch.sigmoid`` and vice versa (including their ``Tensor``
        method equivalents). We could subclass ``Transformer`` like so::

            class NegSigmSwapXformer(Transformer):
                def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
                    if target == torch.sigmoid:
                        return torch.neg(*args, **kwargs)
                    return super().call_function(target, args, kwargs)

                def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
                    if target == 'neg':
                        call_self, *args_tail = args
                        return call_self.sigmoid(*args_tail, **kwargs)
                    return super().call_method(target, args, kwargs)

            def fn(x):
                return torch.sigmoid(x).neg()

            gm = torch.fx.symbolic_trace(fn)

            transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
            input = torch.randn(3, 4)
            torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())

    Args:
        module (GraphModule): The ``Module`` to be transformed.
    """

    @compatibility(is_backward_compatible=True)
    def __init__(self, module):
        super().__init__(module)

        class TransformerTracer(Tracer):
            # Tracer that records into a caller-supplied graph and treats
            # every submodule as a leaf, so module calls are never inlined.
            def __init__(self, graph: Graph):
                super().__init__()
                self.graph = graph

            def is_leaf_module(self, _, __) -> bool:
                return True

        # The output graph reuses the source graph's codegen.
        out_graph = Graph()
        out_graph.set_codegen(module.graph._codegen)
        self.new_graph = out_graph

        # Every proxy produced by this transformer records into `out_graph`.
        recorder = TransformerTracer(out_graph)
        recorder.root = module
        self.tracer = recorder

    @compatibility(is_backward_compatible=True)
    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
        """
        Execute a ``placeholder`` node. In ``Transformer``, this is
        overridden to insert a new ``placeholder`` into the output
        graph.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Proxy: A proxy wrapping the newly inserted ``placeholder`` node.
        """
        assert isinstance(target, str)
        # The first positional arg, when present, is carried over as the
        # placeholder's default value; otherwise it has no default.
        if args:
            default_value = next(iter(args))
        else:
            default_value = inspect.Signature.empty
        new_node = self.new_graph.placeholder(target, default_value=default_value)
        return Proxy(new_node, self.tracer)

    @compatibility(is_backward_compatible=True)
    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
        """
        Execute a ``get_attr`` node. In ``Transformer``, this is
        overridden to insert a new ``get_attr`` node into the output
        graph.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Proxy: A proxy wrapping the newly inserted ``get_attr`` node.
        """
        assert isinstance(target, str)
        attr_node = self.new_graph.get_attr(target)
        return Proxy(attr_node, self.tracer)

    @compatibility(is_backward_compatible=True)
    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        # Route through the tracer (rather than calling the submodule
        # directly) so that the tracer's leaf-module policy is respected.
        assert isinstance(target, str)
        callee = self.fetch_attr(target)
        return self.tracer.call_module(callee, callee.forward, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        # Emit a `call_function` node instead of invoking `target`, so that
        # functions that were wrapped are still wrapped.
        recorder = self.tracer
        return recorder.create_proxy('call_function', target, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def transform(self) -> GraphModule:
        """
        Transform ``self.module`` and return the transformed
        ``GraphModule``.

        Returns:
            GraphModule: A new module built from ``self.module`` and the
            graph recorded while interpreting it.
        """
        with fx_traceback.preserve_node_meta():
            traced_result = super().run(enable_io_processing=False)

        if traced_result is not None:
            # The output node must reference Nodes, not the Proxies that wrap them.
            def strip_proxy(a : Union[Argument, Proxy]) -> Any:
                if isinstance(a, Proxy):
                    return a.node
                return a

            output_nodes = map_aggregate(traced_result, strip_proxy)
            self.new_graph.output(output_nodes)

        return GraphModule(self.module, self.new_graph)
|