fuse_handler.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import torch
  2. from torch.ao.quantization.backend_config import BackendConfig
  3. from torch.fx.graph import Node, Graph
  4. from ..utils import _parent_name, NodePattern, Pattern
  5. from ..fuser_method_mappings import get_fuser_method_new
  6. from abc import ABC, abstractmethod
  7. from typing import Any, Callable, Dict, List, Union
  8. from .custom_config import FuseCustomConfig
  9. from .match_utils import MatchAllNode
  10. from torch.nn.utils.parametrize import type_before_parametrizations
  11. __all__ = [
  12. "DefaultFuseHandler",
  13. "FuseHandler",
  14. ]
# ----------------------------
# Fusion Pattern Registrations
# ----------------------------
# Base Pattern Handler
  19. class FuseHandler(ABC):
  20. """ Base handler class for the fusion patterns
  21. """
  22. def __init__(self, node: Node):
  23. pass
  24. @abstractmethod
  25. def fuse(self,
  26. load_arg: Callable,
  27. named_modules: Dict[str, torch.nn.Module],
  28. fused_graph: Graph,
  29. root_node: Node,
  30. extra_inputs: List[Any],
  31. matched_node_pattern: NodePattern,
  32. fuse_custom_config: FuseCustomConfig,
  33. fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]],
  34. is_qat: bool) -> Node:
  35. pass
  36. class DefaultFuseHandler(FuseHandler):
  37. def __init__(
  38. self,
  39. node: Node):
  40. super().__init__(node)
  41. def fuse(self,
  42. load_arg: Callable,
  43. named_modules: Dict[str, torch.nn.Module],
  44. fused_graph: Graph,
  45. root_node: Node,
  46. extra_inputs: List[Any],
  47. matched_node_pattern: NodePattern,
  48. fuse_custom_config: FuseCustomConfig,
  49. fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]],
  50. is_qat: bool) -> Node:
  51. assert root_node.op == "call_module", "Expecting module node to be a call_module Node"
  52. root_module = named_modules[str(root_node.target)]
  53. def get_modules(pattern):
  54. """ Given a node pattern, extract the corresponding modules
  55. e.g. input: (relu_node, (bn_node, conv_node))
  56. output: (relu_module, (bn_module, conv_module))
  57. """
  58. if isinstance(pattern, (tuple, list)):
  59. n, *args = pattern
  60. modules: List[torch.nn.Module] = []
  61. modules.append(get_modules(n))
  62. for a in args:
  63. modules.append(get_modules(a))
  64. return tuple(modules)
  65. else:
  66. n = pattern
  67. if n.op == "call_module":
  68. return named_modules[n.target]
  69. elif n.op == "call_function" and n.target == torch.nn.functional.relu:
  70. relu = torch.nn.ReLU()
  71. relu.training = root_module.training
  72. return relu
  73. elif n.op == "call_function" or n.op == "call_method":
  74. return n.target
  75. else:
  76. return MatchAllNode
  77. # since relu can be used multiple times, we'll need to create a relu module for each match
  78. matched_modules = get_modules(matched_node_pattern)
  79. def get_matched_types(m):
  80. if isinstance(m, tuple):
  81. return tuple(map(get_matched_types, m))
  82. if isinstance(m, torch.nn.Module):
  83. return type_before_parametrizations(m)
  84. return m
  85. matched_module_types = get_matched_types(matched_modules)
  86. module_parent_name, module_name = _parent_name(root_node.target)
  87. fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
  88. # TODO: change the signature for fuser_method to take matched module patterns
  89. # as input
  90. fused_module = fuser_method(is_qat, *matched_modules)
  91. setattr(named_modules[module_parent_name], module_name, fused_module)
  92. extra_args = []
  93. for input in extra_inputs:
  94. extra_args.append(load_arg(input))
  95. node = fused_graph.node_copy(root_node, load_arg)
  96. args = list(node.args)
  97. args.extend(extra_args)
  98. node.args = tuple(args)
  99. return node
  100. def _get_fusion_pattern_to_fuse_handler_cls(
  101. backend_config: BackendConfig) -> Dict[Pattern, Callable]:
  102. fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {}
  103. for pattern, config in backend_config._pattern_complex_format_to_config.items():
  104. if config.fuser_method is not None:
  105. # TODO: is this logic right?
  106. fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler
  107. return fusion_pattern_to_fuse_handlers