1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
- r'''
- FX is a toolkit for developers to use to transform ``nn.Module``
- instances. FX consists of three main components: a **symbolic tracer,**
- an **intermediate representation**, and **Python code generation**. A
- demonstration of these components in action:
- ::
- import torch
- # Simple module for demonstration
- class MyModule(torch.nn.Module):
- def __init__(self):
- super().__init__()
- self.param = torch.nn.Parameter(torch.rand(3, 4))
- self.linear = torch.nn.Linear(4, 5)
- def forward(self, x):
- return self.linear(x + self.param).clamp(min=0.0, max=1.0)
- module = MyModule()
- from torch.fx import symbolic_trace
- # Symbolic tracing frontend - captures the semantics of the module
- symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
- # High-level intermediate representation (IR) - Graph representation
- print(symbolic_traced.graph)
- """
- graph():
- %x : [#users=1] = placeholder[target=x]
- %param : [#users=1] = get_attr[target=param]
- %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
- %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
- %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
- return clamp
- """
- # Code generation - valid Python code
- print(symbolic_traced.code)
- """
- def forward(self, x):
- param = self.param
- add = x + param; x = param = None
- linear = self.linear(add); add = None
- clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
- return clamp
- """
- The **symbolic tracer** performs "symbolic execution" of the Python
- code. It feeds fake values, called Proxies, through the code. Operations
- on these Proxies are recorded. More information about symbolic tracing
- can be found in the :func:`symbolic_trace` and :class:`Tracer`
- documentation.
- The **intermediate representation** is the container for the operations
- that were recorded during symbolic tracing. It consists of a list of
- Nodes that represent function inputs, callsites (to functions, methods,
- or :class:`torch.nn.Module` instances), and return values. More information
- about the IR can be found in the documentation for :class:`Graph`. The
- IR is the format on which transformations are applied.
- **Python code generation** is what makes FX a Python-to-Python (or
- Module-to-Module) transformation toolkit. For each Graph IR, we can
- create valid Python code matching the Graph's semantics. This
- functionality is wrapped up in :class:`GraphModule`, which is a
- :class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a
- ``forward`` method generated from the Graph.
- Taken together, this pipeline of components (symbolic tracing ->
- intermediate representation -> transforms -> Python code generation)
- constitutes the Python-to-Python transformation pipeline of FX. In
- addition, these components can be used separately. For example,
- symbolic tracing can be used in isolation to capture a form of
- the code for analysis (and not transformation) purposes. Code
- generation can be used for programmatically generating models, for
- example from a config file. There are many uses for FX!
- Several example transformations can be found at the
- `examples <https://github.com/pytorch/examples/tree/master/fx>`__
- repository.
- '''
- from .graph_module import GraphModule
- from ._symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta
- from .graph import Graph, CodeGen
- from .node import Node, map_arg
- from .proxy import Proxy
- from .interpreter import Interpreter as Interpreter, Transformer as Transformer
- from .subgraph_rewriter import replace_pattern
|