fake_tensor_prop.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. from typing import Optional
  2. import torch.fx
  3. from torch.fx import Node
  4. from torch.fx._compatibility import compatibility
  5. from torch._subclasses.fake_tensor import FakeTensorMode
  6. __all__ = ['FakeTensorProp']
  7. @compatibility(is_backward_compatible=False)
  8. class FakeTensorProp(torch.fx.Interpreter):
  9. """
  10. Execute an FX graph Node-by-Node and record a fake tensor representing
  11. the metadata for the node. Unlike ShapeProp, (1) this propagation
  12. is cheap--it does the propagation with meta tensors which do not actually
  13. store data, and (2) the fake tensors have much more fine grained information,
  14. e.g., they have accurate alias information that can be consulted by looking
  15. at the storages.
  16. Args:
  17. module (GraphModule): The module to be executed
  18. mode (Optional[FakeTensorMode]): The dispatch mode used to execute computation indicated by each FX Node.
  19. """
  20. def __init__(self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None):
  21. super().__init__(module)
  22. if mode is None:
  23. mode = FakeTensorMode()
  24. self._mode = mode
  25. def run_node(self, n: Node):
  26. result = super().run_node(n)
  27. n.meta['val'] = result
  28. return result
  29. def propagate(self, *args):
  30. with self._mode:
  31. fake_args = [self._mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
  32. return super().run(*fake_args)