_trace_utils.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
import functools
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple

import torch
import torch.nn as nn
  7. @dataclass
  8. class TracingConfig:
  9. """
  10. This represents a symbolic tracing configuration.
  11. Args:
  12. tracer (torch.fx.Tracer): An instance of :class:`torch.fx.Tracer` to
  13. use for symbolic tracing. The default value is the native
  14. :class:`torch.fx.Tracer` constructed with default arguments.
  15. However, the user may want to pass a different value such as the
  16. ``HFTracer`` for models in the HuggingFace Transformers_ library.
  17. .. _Transformers: https://huggingface.co/docs/transformers/index
  18. concrete_args (Optional[Dict[str, Any]]): Concrete arguments that
  19. should not be treated as ``torch.fx.Proxy`` when tracing the
  20. module ``forward()``. Passing ``concrete_args`` allows partially
  21. specializing the forward, e.g. to remove control flow or data
  22. structures. This ``concrete_args`` here is the same argument used
  23. in :meth:`~torch.fx.Tracer.trace`.
  24. """
  25. tracer: torch.fx.Tracer = torch.fx.Tracer()
  26. concrete_args: Optional[Dict[str, Any]] = None
  27. class _ParamUsageInfo(NamedTuple):
  28. """
  29. This is used for ``_ExecutionInfo.module_to_param_usage_infos`` to record
  30. execution information. The ``dict`` maps modules to a list of these
  31. ``_ParamUsageInfo`` instances, where each instance represents a group of
  32. parameters used together.
  33. Specifically, for each module key in the ``dict``, each instance of this
  34. class represents either:
  35. (1) the module and some sublist of its ``named_parameters()`` used
  36. together in execution (see ``_patched_create_proxy()``), or
  37. (2) a submodule and all of ``submodule.named_parameters()`` (see
  38. ``_patched_call_module()``).
  39. Type (1) corresponds to directly using parameters in ops without calling
  40. ``forward()``, and type (2) corresponds to calling ``forward()``. The
  41. mapped-to lists in the ``dict`` follow the execution order.
  42. """
  43. module: nn.Module
  44. named_params: List[Tuple[str, nn.Parameter]]
  45. class _ExecutionInfo:
  46. """
  47. This represents the execution order information from the forward pass.
  48. Attributes:
  49. curr_module (nn.Module): Current module being traced.
  50. module_forward_order (List[nn.Module]): The modules in (pre-)forward
  51. order, i.e. the order in which their ``forward()`` methods are
  52. called. Each call to a module's ``forward()`` corresponds to one
  53. element in the list.
  54. module_to_param_usage_infos (Dict[nn.Module, List[_ParamUsageInfo]]):
  55. Maps a module to a list of module execution infos. See
  56. :class:`_ParamUsageInfo` for details.
  57. param_forward_order (List[nn.Parameter]): The parameters in forward
  58. execution order, where only a parameter's first participation is
  59. included.
  60. visited_params (Set[nn.Parameter]): The parameters visited so far
  61. during the trace. This is only used during tracing for fast
  62. membership check. Invariant: The parameters in
  63. ``param_forward_order`` are exactly those in ``visited_params``.
  64. """
  65. def __init__(self, root_module: nn.Module) -> None:
  66. self.curr_module: nn.Module = root_module
  67. self.module_forward_order: List[nn.Module] = [root_module]
  68. self.module_to_param_usage_infos: Dict[nn.Module, List[_ParamUsageInfo]] = {
  69. root_module: []
  70. }
  71. self.param_forward_order: List[nn.Parameter] = []
  72. self.visited_params: Set[nn.Parameter] = set()
  73. class _ExecOrderTracer:
  74. def __init__(self) -> None:
  75. self.exec_info: Optional[_ExecutionInfo] = None
  76. @contextmanager
  77. def patch_tracer(self, tracer: torch.fx.Tracer, root_module: nn.Module):
  78. self.exec_info = _ExecutionInfo(root_module)
  79. orig_call_module = tracer.call_module
  80. orig_create_proxy = tracer.create_proxy
  81. tracer.call_module = functools.partial(
  82. self._patched_call_module, orig_call_module, self.exec_info
  83. )
  84. fqn_to_param = dict(root_module.named_parameters())
  85. tracer.create_proxy = functools.partial(
  86. self._patched_create_proxy,
  87. orig_create_proxy,
  88. self.exec_info,
  89. fqn_to_param,
  90. )
  91. try:
  92. yield
  93. finally:
  94. tracer.call_module = orig_call_module
  95. tracer.create_proxy = orig_create_proxy
  96. def _patched_call_module(
  97. self,
  98. call_module: Callable,
  99. exec_info: _ExecutionInfo,
  100. # Below are the expected arguments to `call_module()`
  101. module: nn.Module,
  102. forward: Callable,
  103. args: Tuple[Any, ...],
  104. kwargs: Dict[str, Any],
  105. ) -> Any:
  106. """
  107. Overrides ``call_module`` to save execution information to
  108. ``exec_info``. Note that ``call_module`` is called during symbolic
  109. tracing for each non-root module.
  110. Args:
  111. call_module (Callable): Original ``call_module`` to override.
  112. exec_info (_ExecutionInfo): Used to record execution information.
  113. module (nn.Module): Module corresponding to this ``call_module``.
  114. forward (Callable): ``forward()`` method of ``module`` to be called
  115. for this ``call_module``.
  116. args (Tuple[Any, ...]): Positional arguments for ``forward``.
  117. kwargs (Dict[str, Any]): Keyword arguments for ``forward``.
  118. Returns:
  119. Same return value as ``call_module``.
  120. """
  121. exec_info.module_forward_order.append(module)
  122. named_params = list(module.named_parameters())
  123. curr_module = exec_info.curr_module
  124. if named_params:
  125. assert (
  126. curr_module in exec_info.module_to_param_usage_infos
  127. ), "The current module should have already been processed by a patched `call_module`"
  128. exec_info.module_to_param_usage_infos[exec_info.curr_module].append(
  129. _ParamUsageInfo(module, named_params)
  130. )
  131. prev_curr_module = curr_module
  132. exec_info.curr_module = module
  133. exec_info.module_to_param_usage_infos[module] = []
  134. output = call_module(module, forward, args, kwargs)
  135. exec_info.curr_module = prev_curr_module
  136. return output
  137. def _patched_create_proxy(
  138. self,
  139. create_proxy: Callable,
  140. exec_info: _ExecutionInfo,
  141. fqn_to_param: Dict[str, nn.Parameter],
  142. # Below are the expected arguments to `create_proxy()`
  143. kind: str,
  144. target: torch.fx.node.Target,
  145. args: Tuple[Any, ...],
  146. kwargs: Dict[str, Any],
  147. name: Optional[str] = None,
  148. type_expr: Optional[Any] = None,
  149. proxy_factory_fn: Callable[[torch.fx.Node], torch.fx.Proxy] = None,
  150. ) -> torch.fx.Proxy:
  151. """
  152. Overrides ``create_proxy`` to save execution information to
  153. ``exec_info``. Note that ``create_proxy`` is called during symbolic
  154. tracing for each leaf function/method/module.
  155. Args:
  156. create_proxy (Callable): Original ``create_proxy`` to override.
  157. exec_info (_ExecutionInfo): Used to record execution information.
  158. fqn_to_param (Dict[str, nn.Parameter]): ``dict`` version of the
  159. root module's ``named_parameters()`` with FQN as key and
  160. parameter as value.
  161. kind (str): Kind of the target method ('call_function',
  162. 'call_method', 'get_attr', 'call_module', 'placeholder', or
  163. 'output'). See :class:`torch.fx.Graph` for details. This is
  164. passed to ``create_proxy``.
  165. target (torch.fx.node.Target): Contains the string name of the
  166. function/method/module. This is passed to ``create_proxy``.
  167. args (Tuple[Any, ...]): Positional arguments for the function/
  168. method/module. This is passed to ``create_proxy``.
  169. kwargs (Dict[str, Any]): Keyword arguments for the function/method/
  170. module. This is passed to ``create_proxy``
  171. name (Optional[str]): An optional string name for the ``Node``
  172. created in ``create_proxy``. This is passed to
  173. ``create_proxy``.
  174. type_expr (Optional[Any]): An optional type annotation representing
  175. the Python type that the output of the node has. This is passed
  176. to ``create_proxy``.
  177. proxy_factory_fn (Callable[[torch.fx.Node], torch.fx.Proxy]):
  178. An alternative proxy constructor used in ``create_proxy``. This
  179. is passed to ``create_proxy``.
  180. Returns:
  181. torch.fx.Proxy: Created ``Node`` wrapped in a ``Proxy`` object.
  182. """
  183. proxy = create_proxy(
  184. kind, target, args, kwargs, name, type_expr, proxy_factory_fn
  185. )
  186. curr_module = exec_info.curr_module
  187. if kind in ("call_function", "call_method"):
  188. if args is not None:
  189. named_params: List[Tuple[str, nn.Parameter]] = []
  190. for arg in args:
  191. if (
  192. isinstance(arg, torch.fx.Proxy)
  193. and arg.node.target in fqn_to_param
  194. ):
  195. param = fqn_to_param[arg.node.target]
  196. named_params.append((arg.node.target, param))
  197. if param not in exec_info.visited_params:
  198. exec_info.visited_params.add(param)
  199. exec_info.param_forward_order.append(param)
  200. if named_params:
  201. exec_info.module_to_param_usage_infos[curr_module].append(
  202. _ParamUsageInfo(curr_module, named_params)
  203. )
  204. elif kind == "call_module":
  205. named_params = list(curr_module.named_parameters())
  206. if named_params:
  207. exec_info.module_to_param_usage_infos[curr_module].append(
  208. _ParamUsageInfo(curr_module, named_params)
  209. )
  210. for _, param in named_params:
  211. if param not in exec_info.visited_params:
  212. exec_info.visited_params.add(param)
  213. exec_info.param_forward_order.append(param)
  214. return proxy