# param_fetch.py

from torch.fx.graph_module import GraphModule
from typing import Any, Callable, Dict, List, Tuple, Type
import torch
import torch.nn as nn

from torch.fx._compatibility import compatibility

__all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes']


# A matching method maps an attribute name in the current version of a module
# to the corresponding attribute name in `target_version`.
@compatibility(is_backward_compatible=False)
def default_matching(name: str, target_version: int) -> str:
    """Default matching method: the attribute name is assumed to be unchanged
    across versions, so it is returned as-is.
    """
    return name
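
# Example (illustrative): `default_matching` is the identity on names, so the
# name recorded in the fetch book is used unchanged for any module version:
#
#     default_matching("weight", 2)  # -> "weight"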

# This dict maps an nn.Module class to the list of attribute names we want to
# fetch for lowering, together with the matching method used to resolve them.
# The first integer in the tuple is the version of the nn.Module class at the
# time the attribute list was created. If it is older than the module's current
# `_version`, the attribute names in the book may no longer match the nn.Module.
module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = {
    torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching),
    torch.nn.modules.conv.Conv2d: (
        1, ["weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode"], default_matching
    ),
    torch.nn.modules.batchnorm.BatchNorm2d: (2, ["weight", "bias", "running_mean", "running_var", "eps"], default_matching),
    torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching),
    torch.nn.modules.pooling.MaxPool2d: (
        1, ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], default_matching
    ),
    torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching),
}
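
# Illustrative sketch (not part of the shipped fetch book): a module whose
# attribute was renamed across versions would pair its entry with a custom
# matching method instead of `default_matching`. `LegacyConv` and the rename
# below are hypothetical.
#
#     def legacy_conv_matching(name: str, target_version: int) -> str:
#         # Hypothetical: versions before 2 stored the weight under "filters".
#         if name == "weight" and target_version < 2:
#             return "filters"
#         return name
#
#     module_fetch_book[LegacyConv] = (2, ["weight", "bias"], legacy_conv_matching)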

@compatibility(is_backward_compatible=False)
def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]:
    """If `mod` is in `module_fetch_book`, fetch the attributes of `mod` that are
    listed there, after checking that the module's version is compatible with
    the `module_fetch_book` entry.
    """
    attrs_for_lowering: Dict[str, Any] = {}
    attrs_for_lowering["name"] = torch.typename(mod)

    if type(mod) in module_fetch_book:
        version, param_to_fetch, matching_method = module_fetch_book[type(mod)]
        if version < mod._version:
            raise RuntimeError(f"Fetcher version {version} tries to fetch {torch.typename(mod)} version {mod._version}, "
                               "please upgrade the module_fetch_book, open an issue and @842974287 "
                               "or report a bug to AIACC team directly.")
        for attr in param_to_fetch:
            attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version))
    else:
        raise RuntimeError(f"{torch.typename(mod)} is not in the module_fetch_book yet, "
                           "please add it to the module_fetch_book, open an issue and @842974287 "
                           "or report a bug to AIACC team directly.")
    return attrs_for_lowering
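
# Minimal usage sketch (illustrative, not part of the module): fetching the
# lowering attributes of a stock Linear layer.
#
#     attrs = extract_attrs_for_lowering(nn.Linear(4, 2))
#     attrs["name"]    # 'torch.nn.modules.linear.Linear'
#     attrs["weight"]  # the layer's weight Parameter
#     attrs["bias"]    # the layer's bias Parameter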

@compatibility(is_backward_compatible=False)
def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None:
    """Recursively traverse all of `fx_module`'s nodes and, for each `call_module`
    node that targets a leaf module, attach that module's attributes to the node.
    """
    submodules = dict(fx_module.named_modules())

    for node in fx_module.graph.nodes:
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                lift_lowering_attrs_to_nodes(submodules[node.target])
            else:
                node.attrs_for_lowering = extract_attrs_for_lowering(submodules[node.target])
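

# Minimal usage sketch (illustrative, not part of the module): trace a small
# model whose submodules are all covered by `module_fetch_book`, then attach
# lowering attributes to its call_module nodes. `Net` is a hypothetical model.
if __name__ == "__main__":
    from torch.fx import symbolic_trace

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 2)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.linear(x))

    gm = symbolic_trace(Net())
    lift_lowering_attrs_to_nodes(gm)

    # Each call_module node now carries the fetched attributes.
    for node in gm.graph.nodes:
        if node.op == "call_module":
            print(node.target, sorted(node.attrs_for_lowering))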