shape_prop.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import torch
  2. import torch.fx
  3. import traceback
  4. from torch.fx.node import Node, map_aggregate
  5. from typing import Any, Tuple, NamedTuple, Optional, Dict
  6. from torch.fx._compatibility import compatibility
  7. __all__ = ['TensorMetadata', 'ShapeProp']
  8. @compatibility(is_backward_compatible=True)
  9. class TensorMetadata(NamedTuple):
  10. # TensorMetadata is a structure containing pertinent information
  11. # about a tensor within a PyTorch program.
  12. # General Tensor metadata
  13. shape : torch.Size
  14. dtype : torch.dtype
  15. requires_grad : bool
  16. stride : Tuple[int, ...]
  17. memory_format : Optional[torch.memory_format]
  18. # Quantization metadata
  19. is_quantized : bool
  20. qparams: Dict[str, Any]
  21. def _extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata:
  22. """
  23. Extract a TensorMetadata NamedTuple describing `result`.
  24. """
  25. shape = result.shape
  26. dtype = result.dtype
  27. requires_grad = result.requires_grad
  28. stride = result.stride()
  29. memory_formats = {
  30. torch.contiguous_format,
  31. torch.channels_last,
  32. torch.channels_last_3d,
  33. }
  34. memory_format = None
  35. for query_format in memory_formats:
  36. if result.is_contiguous(memory_format=query_format):
  37. memory_format = query_format
  38. break
  39. is_quantized = result.is_quantized
  40. qparams: Dict[str, Any] = {}
  41. if is_quantized:
  42. qscheme = result.qscheme()
  43. qparams["qscheme"] = qscheme
  44. if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
  45. qparams["scale"] = result.q_scale() # type: ignore[assignment]
  46. qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment]
  47. elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}:
  48. # In this branch, scale and zero_point are expected to be tensors,
  49. # we store the values as immutable_list in TensorMetadata for
  50. # easier serialization downstream
  51. qparams["scale"] = result.q_per_channel_scales().tolist() # type: ignore[assignment]
  52. qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment]
  53. qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment]
  54. return TensorMetadata(
  55. shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams)
  56. @compatibility(is_backward_compatible=True)
  57. class ShapeProp(torch.fx.Interpreter):
  58. """
  59. Execute an FX graph Node-by-Node and
  60. record the shape and type of the result
  61. into the corresponding node.
  62. Example:
  63. In this example, we record the shape
  64. and data type of a module given
  65. an example input ``torch.randn(50, D_in)``.
  66. We print the name, shape and dtype of each node.
  67. class TwoLayerNet(torch.nn.Module):
  68. def __init__(self, D_in, H, D_out):
  69. super().__init__()
  70. self.linear1 = torch.nn.Linear(D_in, H)
  71. self.linear2 = torch.nn.Linear(H, D_out)
  72. def forward(self, x):
  73. h_relu = self.linear1(x).clamp(min=0)
  74. y_pred = self.linear2(h_relu)
  75. return y_pred
  76. N, D_in, H, D_out = 64, 1000, 100, 10
  77. x = torch.randn(N, D_in)
  78. y = torch.randn(N, D_out)
  79. model = TwoLayerNet(D_in, H, D_out)
  80. gm = torch.fx.symbolic_trace(model)
  81. sample_input = torch.randn(50, D_in)
  82. ShapeProp(gm).propagate(sample_input)
  83. for node in gm.graph.nodes:
  84. print(node.name, node.meta['tensor_meta'].dtype,
  85. node.meta['tensor_meta'].shape)
  86. The output of this code is:
  87. x torch.float32 torch.Size([50, 1000])
  88. linear1 torch.float32 torch.Size([50, 100])
  89. clamp_1 torch.float32 torch.Size([50, 100])
  90. linear2 torch.float32 torch.Size([50, 10])
  91. output torch.float32 torch.Size([50, 10])
  92. Args:
  93. module (GraphModule): The module to be executed
  94. fake_mode (FakeTensorMode): A fake mode for copying the gm
  95. """
  96. def __init__(self, gm, fake_mode=None):
  97. super().__init__(gm)
  98. if fake_mode is not None:
  99. from torch._dynamo.utils import deepcopy_to_fake_tensor
  100. # Note:
  101. # We need fake execution cause the inputs are fake, however, we cannot fakify the module
  102. # - because we need to write to the tensor_meta of the real module. So we fakify to
  103. # produce a result (L131 below), to extract tensor meta, and then keep going.
  104. #
  105. # If we were to fakify, we would write to the wrong node, and then downstream fusion
  106. # would be missing the tensor_meta.
  107. #
  108. # See torch/_inductor/overrides.py for where this is called upstream of fusion.
  109. self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode)
  110. self.fake_mode = fake_mode
  111. else:
  112. self.fake_module = None
  113. self.fake_mode = None
  114. self.real_module = self.module
  115. def run_node(self, n : Node) -> Any:
  116. try:
  117. if self.fake_module is not None:
  118. # Hacky swap. Alternatively, we could do this with overriding
  119. # call_module and get_attr.
  120. self.module = self.fake_module
  121. try:
  122. if self.fake_mode is not None:
  123. with self.fake_mode:
  124. result = super().run_node(n)
  125. else:
  126. result = super().run_node(n)
  127. finally:
  128. self.module = self.real_module
  129. except Exception as e:
  130. traceback.print_exc()
  131. raise RuntimeError(
  132. f"ShapeProp error for: node={n.format_node()} with "
  133. f"meta={n.meta}"
  134. ) from e
  135. found_tensor = False
  136. def extract_tensor_meta(obj):
  137. if isinstance(obj, torch.Tensor):
  138. nonlocal found_tensor
  139. found_tensor = True
  140. return _extract_tensor_metadata(obj)
  141. else:
  142. return obj
  143. meta = map_aggregate(result, extract_tensor_meta)
  144. if found_tensor:
  145. n.meta['tensor_meta'] = meta
  146. n.meta['type'] = type(result)
  147. return result
  148. def propagate(self, *args):
  149. """
  150. Run `module` via interpretation and return the result and
  151. record the shape and type of each node.
  152. Args:
  153. *args (Tensor): the sample input.
  154. Returns:
  155. Any: The value returned from executing the Module
  156. """
  157. return super().run(*args)