import torch
import torch.fx
import traceback

from torch.fx.node import Node, map_aggregate
from typing import Any, Tuple, NamedTuple, Optional, Dict
from torch.fx._compatibility import compatibility

__all__ = ['TensorMetadata', 'ShapeProp']

@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
    # TensorMetadata is a structure containing pertinent information
    # about a tensor within a PyTorch program.

    # General tensor metadata
    shape : torch.Size
    dtype : torch.dtype
    requires_grad : bool
    stride : Tuple[int, ...]
    memory_format : Optional[torch.memory_format]

    # Quantization metadata
    is_quantized : bool
    qparams: Dict[str, Any]

def _extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata:
    """
    Extract a TensorMetadata NamedTuple describing `result`.
    """
    shape = result.shape
    dtype = result.dtype
    requires_grad = result.requires_grad
    stride = result.stride()

    memory_formats = {
        torch.contiguous_format,
        torch.channels_last,
        torch.channels_last_3d,
    }

    memory_format = None

    for query_format in memory_formats:
        if result.is_contiguous(memory_format=query_format):
            memory_format = query_format
            break

    is_quantized = result.is_quantized
    qparams: Dict[str, Any] = {}
    if is_quantized:
        qscheme = result.qscheme()
        qparams["qscheme"] = qscheme
        if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
            qparams["scale"] = result.q_scale()  # type: ignore[assignment]
            qparams["zero_point"] = result.q_zero_point()  # type: ignore[assignment]
        elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}:
            # In this branch, scale and zero_point are per-channel tensors;
            # we store their values as plain lists in TensorMetadata for
            # easier serialization downstream.
            qparams["scale"] = result.q_per_channel_scales().tolist()  # type: ignore[assignment]
            qparams["zero_point"] = result.q_per_channel_zero_points().tolist()  # type: ignore[assignment]
            qparams["axis"] = result.q_per_channel_axis()  # type: ignore[assignment]

    return TensorMetadata(
        shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams)
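
# A minimal usage sketch (hedged; exact values depend on your PyTorch build),
# illustrating what _extract_tensor_metadata records for a channels_last
# tensor and for a per-tensor quantized tensor:
#
#     t = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
#     meta = _extract_tensor_metadata(t)
#     # meta.shape          -> torch.Size([2, 3, 4, 5])
#     # meta.stride         -> (60, 1, 15, 3)
#     # meta.memory_format  -> torch.channels_last
#     # meta.is_quantized   -> False, so meta.qparams is {}
#
#     q = torch.quantize_per_tensor(torch.randn(4, 4), 0.1, 0, torch.qint8)
#     qmeta = _extract_tensor_metadata(q)
#     # qmeta.qparams -> {'qscheme': torch.per_tensor_affine,
#     #                   'scale': 0.1, 'zero_point': 0}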

@compatibility(is_backward_compatible=True)
class ShapeProp(torch.fx.Interpreter):
    """
    Execute an FX graph Node-by-Node and
    record the shape and type of the result
    into the corresponding node.

    Example:
        In this example, we record the shape
        and data type of a module given
        an example input ``torch.randn(50, D_in)``.
        We print the name, shape and dtype of each node.

        class TwoLayerNet(torch.nn.Module):
            def __init__(self, D_in, H, D_out):
                super().__init__()
                self.linear1 = torch.nn.Linear(D_in, H)
                self.linear2 = torch.nn.Linear(H, D_out)

            def forward(self, x):
                h_relu = self.linear1(x).clamp(min=0)
                y_pred = self.linear2(h_relu)
                return y_pred

        N, D_in, H, D_out = 64, 1000, 100, 10
        x = torch.randn(N, D_in)
        y = torch.randn(N, D_out)
        model = TwoLayerNet(D_in, H, D_out)
        gm = torch.fx.symbolic_trace(model)
        sample_input = torch.randn(50, D_in)
        ShapeProp(gm).propagate(sample_input)

        for node in gm.graph.nodes:
            print(node.name, node.meta['tensor_meta'].dtype,
                  node.meta['tensor_meta'].shape)

        The output of this code is:

        x torch.float32 torch.Size([50, 1000])
        linear1 torch.float32 torch.Size([50, 100])
        clamp_1 torch.float32 torch.Size([50, 100])
        linear2 torch.float32 torch.Size([50, 10])
        output torch.float32 torch.Size([50, 10])

    Args:
        module (GraphModule): The module to be executed
        fake_mode (FakeTensorMode): A fake mode for copying the gm
    """
    def __init__(self, gm, fake_mode=None):
        super().__init__(gm)
        if fake_mode is not None:
            from torch._dynamo.utils import deepcopy_to_fake_tensor
            # Note:
            # We need fake execution because the inputs are fake. However, we cannot
            # fakify the module in place, because we need to write to the tensor_meta
            # of the real module's nodes. So we fakify a copy to produce a result
            # (in run_node below), extract the tensor metadata from it, and keep going.
            #
            # If we fakified the module in place, we would write to the wrong nodes,
            # and downstream fusion would be missing the tensor_meta.
            #
            # See torch/_inductor/overrides.py for where this is called upstream of fusion.
            self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode)
            self.fake_mode = fake_mode
        else:
            self.fake_module = None
            self.fake_mode = None

        self.real_module = self.module
    def run_node(self, n : Node) -> Any:
        try:
            if self.fake_module is not None:
                # Hacky swap. Alternatively, we could do this by overriding
                # call_module and get_attr.
                self.module = self.fake_module
            try:
                if self.fake_mode is not None:
                    with self.fake_mode:
                        result = super().run_node(n)
                else:
                    result = super().run_node(n)
            finally:
                self.module = self.real_module
        except Exception as e:
            traceback.print_exc()
            raise RuntimeError(
                f"ShapeProp error for: node={n.format_node()} with "
                f"meta={n.meta}"
            ) from e

        found_tensor = False

        def extract_tensor_meta(obj):
            if isinstance(obj, torch.Tensor):
                nonlocal found_tensor
                found_tensor = True
                return _extract_tensor_metadata(obj)
            else:
                return obj

        meta = map_aggregate(result, extract_tensor_meta)
        if found_tensor:
            n.meta['tensor_meta'] = meta

        n.meta['type'] = type(result)
        return result
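
    # Note: because `map_aggregate` recurses into tuples, lists, and dicts, a
    # node whose operation returns a container of tensors gets a matching
    # container of TensorMetadata in `n.meta['tensor_meta']`. A hedged sketch
    # (the `forward` below is illustrative, not part of this module):
    #
    #     def forward(self, x):
    #         return x + 1, x - 1
    #
    #     # After ShapeProp(gm).propagate(x), the output node's
    #     # meta['tensor_meta'] is a 2-tuple of TensorMetadata,
    #     # one entry per returned tensor.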

    def propagate(self, *args):
        """
        Run `module` via interpretation and return the result, recording the
        shape and type of each node into that node's `meta` dictionary.

        Args:
            *args (Tensor): The sample input.

        Returns:
            Any: The value returned from executing the Module
        """
        return super().run(*args)
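
# A hedged usage sketch of fake-mode propagation. FakeTensorMode lives in an
# internal namespace (torch._subclasses.fake_tensor) and its API may change;
# the shapes below are illustrative only:
#
#     from torch._subclasses.fake_tensor import FakeTensorMode
#
#     gm = torch.fx.symbolic_trace(model)
#     fake_mode = FakeTensorMode()
#     with fake_mode:
#         # Fake inputs carry shape/dtype but allocate no real storage.
#         fake_input = torch.randn(50, 1000)
#     ShapeProp(gm, fake_mode=fake_mode).propagate(fake_input)
#     # tensor_meta is written to the *real* gm's nodes, as explained in
#     # __init__ above.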
|