from typing import Optional

import torch.fx
from torch.fx import Node
from torch.fx._compatibility import compatibility
from torch._subclasses.fake_tensor import FakeTensorMode

__all__ = ['FakeTensorProp']


@compatibility(is_backward_compatible=False)
class FakeTensorProp(torch.fx.Interpreter):
    """
    Execute an FX graph Node-by-Node and record a fake tensor representing
    the metadata for the node. Unlike ShapeProp, (1) this propagation
    is cheap--it does the propagation with meta tensors which do not actually
    store data, and (2) the fake tensors have much more fine-grained information,
    e.g., they have accurate alias information that can be consulted by looking
    at the storages.

    Args:
        module (GraphModule): The module to be executed
        mode (Optional[FakeTensorMode]): The dispatch mode used to execute computation indicated by each FX Node.
    """
    def __init__(self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None):
        super().__init__(module)
        if mode is None:
            mode = FakeTensorMode()
        self._mode = mode

    def run_node(self, n: Node):
        result = super().run_node(n)
        n.meta['val'] = result
        return result

    def propagate(self, *args):
        with self._mode:
            fake_args = [
                self._mode.from_tensor(a) if isinstance(a, torch.Tensor) else a
                for a in args
            ]
            return super().run(*fake_args)
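A minimal usage sketch (not part of the module above): it assumes a toy module defined here only for illustration, traces it with torch.fx.symbolic_trace, and runs the propagation so each node's meta['val'] holds a fake tensor describing its output.

import torch
import torch.fx
from torch.fx.passes.fake_tensor_prop import FakeTensorProp


# Hypothetical example module, used only to demonstrate the pass.
class MyModule(torch.nn.Module):
    def forward(self, x):
        return x.relu() + 1


gm = torch.fx.symbolic_trace(MyModule())

# Propagate with example inputs; the inputs are converted to fake tensors,
# so no real tensor data is allocated or computed during the run.
FakeTensorProp(gm).propagate(torch.randn(4, 8))

# Each node now records a fake tensor in node.meta['val'] that exposes
# shape, dtype, device, and (via its storage) aliasing metadata.
for node in gm.graph.nodes:
    print(node.name, node.meta.get('val'))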