import functools from contextlib import contextmanager from dataclasses import dataclass from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple import torch import torch.nn as nn @dataclass class TracingConfig: """ This represents a symbolic tracing configuration. Args: tracer (torch.fx.Tracer): An instance of :class:`torch.fx.Tracer` to use for symbolic tracing. The default value is the native :class:`torch.fx.Tracer` constructed with default arguments. However, the user may want to pass a different value such as the ``HFTracer`` for models in the HuggingFace Transformers_ library. .. _Transformers: https://huggingface.co/docs/transformers/index concrete_args (Optional[Dict[str, Any]]): Concrete arguments that should not be treated as ``torch.fx.Proxy`` when tracing the module ``forward()``. Passing ``concrete_args`` allows partially specializing the forward, e.g. to remove control flow or data structures. This ``concrete_args`` here is the same argument used in :meth:`~torch.fx.Tracer.trace`. """ tracer: torch.fx.Tracer = torch.fx.Tracer() concrete_args: Optional[Dict[str, Any]] = None class _ParamUsageInfo(NamedTuple): """ This is used for ``_ExecutionInfo.module_to_param_usage_infos`` to record execution information. The ``dict`` maps modules to a list of these ``_ParamUsageInfo`` instances, where each instance represents a group of parameters used together. Specifically, for each module key in the ``dict``, each instance of this class represents either: (1) the module and some sublist of its ``named_parameters()`` used together in execution (see ``_patched_create_proxy()``), or (2) a submodule and all of ``submodule.named_parameters()`` (see ``_patched_call_module()``). Type (1) corresponds to directly using parameters in ops without calling ``forward()``, and type (2) corresponds to calling ``forward()``. The mapped-to lists in the ``dict`` follow the execution order. """ module: nn.Module named_params: List[Tuple[str, nn.Parameter]] class _ExecutionInfo: """ This represents the execution order information from the forward pass. Attributes: curr_module (nn.Module): Current module being traced. module_forward_order (List[nn.Module]): The modules in (pre-)forward order, i.e. the order in which their ``forward()`` methods are called. Each call to a module's ``forward()`` corresponds to one element in the list. module_to_param_usage_infos (Dict[nn.Module, List[_ParamUsageInfo]]): Maps a module to a list of module execution infos. See :class:`_ParamUsageInfo` for details. param_forward_order (List[nn.Parameter]): The parameters in forward execution order, where only a parameter's first participation is included. visited_params (Set[nn.Parameter]): The parameters visited so far during the trace. This is only used during tracing for fast membership check. Invariant: The parameters in ``param_forward_order`` are exactly those in ``visited_params``. """ def __init__(self, root_module: nn.Module) -> None: self.curr_module: nn.Module = root_module self.module_forward_order: List[nn.Module] = [root_module] self.module_to_param_usage_infos: Dict[nn.Module, List[_ParamUsageInfo]] = { root_module: [] } self.param_forward_order: List[nn.Parameter] = [] self.visited_params: Set[nn.Parameter] = set() class _ExecOrderTracer: def __init__(self) -> None: self.exec_info: Optional[_ExecutionInfo] = None @contextmanager def patch_tracer(self, tracer: torch.fx.Tracer, root_module: nn.Module): self.exec_info = _ExecutionInfo(root_module) orig_call_module = tracer.call_module orig_create_proxy = tracer.create_proxy tracer.call_module = functools.partial( self._patched_call_module, orig_call_module, self.exec_info ) fqn_to_param = dict(root_module.named_parameters()) tracer.create_proxy = functools.partial( self._patched_create_proxy, orig_create_proxy, self.exec_info, fqn_to_param, ) try: yield finally: tracer.call_module = orig_call_module tracer.create_proxy = orig_create_proxy def _patched_call_module( self, call_module: Callable, exec_info: _ExecutionInfo, # Below are the expected arguments to `call_module()` module: nn.Module, forward: Callable, args: Tuple[Any, ...], kwargs: Dict[str, Any], ) -> Any: """ Overrides ``call_module`` to save execution information to ``exec_info``. Note that ``call_module`` is called during symbolic tracing for each non-root module. Args: call_module (Callable): Original ``call_module`` to override. exec_info (_ExecutionInfo): Used to record execution information. module (nn.Module): Module corresponding to this ``call_module``. forward (Callable): ``forward()`` method of ``module`` to be called for this ``call_module``. args (Tuple[Any, ...]): Positional arguments for ``forward``. kwargs (Dict[str, Any]): Keyword arguments for ``forward``. Returns: Same return value as ``call_module``. """ exec_info.module_forward_order.append(module) named_params = list(module.named_parameters()) curr_module = exec_info.curr_module if named_params: assert ( curr_module in exec_info.module_to_param_usage_infos ), "The current module should have already been processed by a patched `call_module`" exec_info.module_to_param_usage_infos[exec_info.curr_module].append( _ParamUsageInfo(module, named_params) ) prev_curr_module = curr_module exec_info.curr_module = module exec_info.module_to_param_usage_infos[module] = [] output = call_module(module, forward, args, kwargs) exec_info.curr_module = prev_curr_module return output def _patched_create_proxy( self, create_proxy: Callable, exec_info: _ExecutionInfo, fqn_to_param: Dict[str, nn.Parameter], # Below are the expected arguments to `create_proxy()` kind: str, target: torch.fx.node.Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], name: Optional[str] = None, type_expr: Optional[Any] = None, proxy_factory_fn: Callable[[torch.fx.Node], torch.fx.Proxy] = None, ) -> torch.fx.Proxy: """ Overrides ``create_proxy`` to save execution information to ``exec_info``. Note that ``create_proxy`` is called during symbolic tracing for each leaf function/method/module. Args: create_proxy (Callable): Original ``create_proxy`` to override. exec_info (_ExecutionInfo): Used to record execution information. fqn_to_param (Dict[str, nn.Parameter]): ``dict`` version of the root module's ``named_parameters()`` with FQN as key and parameter as value. kind (str): Kind of the target method ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', or 'output'). See :class:`torch.fx.Graph` for details. This is passed to ``create_proxy``. target (torch.fx.node.Target): Contains the string name of the function/method/module. This is passed to ``create_proxy``. args (Tuple[Any, ...]): Positional arguments for the function/ method/module. This is passed to ``create_proxy``. kwargs (Dict[str, Any]): Keyword arguments for the function/method/ module. This is passed to ``create_proxy`` name (Optional[str]): An optional string name for the ``Node`` created in ``create_proxy``. This is passed to ``create_proxy``. type_expr (Optional[Any]): An optional type annotation representing the Python type that the output of the node has. This is passed to ``create_proxy``. proxy_factory_fn (Callable[[torch.fx.Node], torch.fx.Proxy]): An alternative proxy constructor used in ``create_proxy``. This is passed to ``create_proxy``. Returns: torch.fx.Proxy: Created ``Node`` wrapped in a ``Proxy`` object. """ proxy = create_proxy( kind, target, args, kwargs, name, type_expr, proxy_factory_fn ) curr_module = exec_info.curr_module if kind in ("call_function", "call_method"): if args is not None: named_params: List[Tuple[str, nn.Parameter]] = [] for arg in args: if ( isinstance(arg, torch.fx.Proxy) and arg.node.target in fqn_to_param ): param = fqn_to_param[arg.node.target] named_params.append((arg.node.target, param)) if param not in exec_info.visited_params: exec_info.visited_params.add(param) exec_info.param_forward_order.append(param) if named_params: exec_info.module_to_param_usage_infos[curr_module].append( _ParamUsageInfo(curr_module, named_params) ) elif kind == "call_module": named_params = list(curr_module.named_parameters()) if named_params: exec_info.module_to_param_usage_infos[curr_module].append( _ParamUsageInfo(curr_module, named_params) ) for _, param in named_params: if param not in exec_info.visited_params: exec_info.visited_params.add(param) exec_info.param_forward_order.append(param) return proxy