fuse.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. from torch.fx import (
  2. GraphModule,
  3. Node,
  4. map_arg
  5. )
  6. from torch.fx.graph import Graph
  7. from .match_utils import (
  8. _is_match,
  9. MatchAllNode,
  10. )
  11. from .pattern_utils import (
  12. _sorted_patterns_dict,
  13. )
  14. from ..backend_config import (
  15. BackendConfig,
  16. get_native_backend_config,
  17. )
  18. from ..backend_config.utils import (
  19. get_fuser_method_mapping,
  20. get_fusion_pattern_to_root_node_getter,
  21. get_fusion_pattern_to_extra_inputs_getter,
  22. )
  23. from .custom_config import FuseCustomConfig
  24. from .fuse_handler import (
  25. _get_fusion_pattern_to_fuse_handler_cls,
  26. FuseHandler,
  27. )
  28. from typing import Any, Callable, Dict, List, Tuple, Union
  29. import warnings
  30. from torch.ao.quantization.utils import Pattern, NodePattern
  31. __all__ = [
  32. "fuse",
  33. # TODO: We should make this private in the future
  34. # This is currently needed for test_public_bindings for some reason
  35. "FuseHandler",
  36. ]
  37. def fuse(
  38. model: GraphModule,
  39. is_qat: bool,
  40. fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
  41. backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
  42. ) -> GraphModule:
  43. if fuse_custom_config is None:
  44. fuse_custom_config = FuseCustomConfig()
  45. if isinstance(fuse_custom_config, Dict):
  46. warnings.warn(
  47. "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
  48. "in a future version. Please pass in a FuseCustomConfig instead.")
  49. fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)
  50. if isinstance(backend_config, Dict):
  51. warnings.warn(
  52. "Passing a backend_config_dict to prepare is deprecated and will not be supported "
  53. "in a future version. Please pass in a BackendConfig instead.")
  54. backend_config = BackendConfig.from_dict(backend_config)
  55. named_modules = dict(model.named_modules())
  56. if backend_config is None:
  57. backend_config = get_native_backend_config()
  58. fusion_pattern_to_fuse_handler_cls = _sorted_patterns_dict(_get_fusion_pattern_to_fuse_handler_cls(backend_config))
  59. fuser_method_mapping = get_fuser_method_mapping(backend_config)
  60. fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
  61. fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter(backend_config)
  62. # find fusion
  63. fusion_pairs = _find_matches(
  64. model, model.graph, fusion_pattern_to_fuse_handler_cls)
  65. # TODO: change this to inplace changes to graph, since we no longer construct
  66. # new GraphModule anymore
  67. fused_graph = Graph()
  68. env: Dict[Any, Any] = {}
  69. def load_arg(a):
  70. return map_arg(a, lambda node: env[node.name])
  71. def default_root_node_getter(node_pattern):
  72. while not isinstance(node_pattern[-1], Node):
  73. node_pattern = node_pattern[-1]
  74. return node_pattern[-1]
  75. for node in model.graph.nodes:
  76. maybe_last_node, pattern, matched_node_pattern, obj, node_to_subpattern = \
  77. fusion_pairs.get(node.name, (None, None, None, None, None))
  78. # get the corresponding subpattern for the current node
  79. if node_to_subpattern is not None:
  80. node_subpattern = node_to_subpattern.get(node, None)
  81. else:
  82. node_subpattern = None
  83. if maybe_last_node is node:
  84. assert obj is not None
  85. root_node_getter = fusion_pattern_to_root_node_getter.get(pattern, default_root_node_getter)
  86. root_node = root_node_getter(matched_node_pattern) # type: ignore[index]
  87. extra_inputs_getter = fusion_pattern_to_extra_inputs_getter.get(pattern, None)
  88. extra_inputs = []
  89. if extra_inputs_getter is not None:
  90. extra_inputs = extra_inputs_getter(matched_node_pattern)
  91. # TODO: add validation that root_node is a module and has the same type
  92. # as the root_module in the configuration
  93. env[node.name] = obj.fuse(
  94. load_arg, named_modules, fused_graph, root_node, extra_inputs, matched_node_pattern, # type: ignore[arg-type]
  95. fuse_custom_config, fuser_method_mapping, is_qat)
  96. elif maybe_last_node is None or node_subpattern is MatchAllNode:
  97. env[node.name] = fused_graph.node_copy(node, load_arg)
  98. # node matched in patterns and is not root is removed here
  99. model = GraphModule(model, fused_graph)
  100. return model
  101. def _find_matches(
  102. root: GraphModule,
  103. graph: Graph,
  104. pattern_to_fuse_handler_cls: Dict[Pattern, Callable],
  105. ) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]]:
  106. modules = dict(root.named_modules())
  107. # node name -> (root_node, match_value)
  108. match_map : Dict[
  109. str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]] = {}
  110. # a map from node to the matched subpattern
  111. node_to_subpattern: Dict[Node, Any] = {}
  112. # TODO: dedup with quantization matching function in match_utils.py
  113. def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern):
  114. if isinstance(pattern, tuple):
  115. s, *args = pattern
  116. current_node_pattern: List[Node] = []
  117. apply_match(s, node, match, current_node_pattern, node_to_subpattern)
  118. for subpattern, arg in zip(args, node.args):
  119. apply_match(subpattern, arg, match, current_node_pattern, node_to_subpattern)
  120. matched_node_pattern.append(tuple(current_node_pattern))
  121. else:
  122. # the first pattern matches will take precedence
  123. if node.name not in match_map:
  124. matched_node_pattern.append(node)
  125. # MatchAllNode here is actually MatchAllInputNode which should not
  126. # be added to match_map
  127. if pattern is not MatchAllNode:
  128. node_to_subpattern[node] = pattern
  129. root_node, pattern, handler = match
  130. match_map[node.name] = (root_node, pattern, matched_node_pattern, handler, node_to_subpattern)
  131. for node in reversed(graph.nodes):
  132. if node.name not in match_map:
  133. for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items():
  134. matched_node_pattern: List[Node] = []
  135. if _is_match(modules, node, pattern):
  136. apply_match(pattern, node, (node, pattern, fuse_handler_cls(node)), matched_node_pattern, node_to_subpattern)
  137. break
  138. return match_map