fuse_modules.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import copy
  2. import torch.nn as nn
  3. from torch.ao.quantization.fuser_method_mappings import get_fuser_method
  4. # for backward compatibility
  5. from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn # noqa: F401
  6. from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn_relu # noqa: F401
  7. from torch.nn.utils.parametrize import type_before_parametrizations
  8. from typing import List, Optional
  9. __all__ = [
  10. "fuse_known_modules",
  11. "fuse_modules",
  12. "fuse_modules_qat",
  13. ]
  14. # Generalization of getattr
  15. def _get_module(model, submodule_key):
  16. tokens = submodule_key.split('.')
  17. cur_mod = model
  18. for s in tokens:
  19. cur_mod = getattr(cur_mod, s)
  20. return cur_mod
  21. # Generalization of setattr
  22. def _set_module(model, submodule_key, module):
  23. tokens = submodule_key.split('.')
  24. sub_tokens = tokens[:-1]
  25. cur_mod = model
  26. for s in sub_tokens:
  27. cur_mod = getattr(cur_mod, s)
  28. setattr(cur_mod, tokens[-1], module)
  29. def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
  30. r"""Returns a list of modules that fuses the operations specified
  31. in the input module list.
  32. Fuses only the following sequence of modules:
  33. conv, bn
  34. conv, bn, relu
  35. conv, relu
  36. linear, bn
  37. linear, relu
  38. For these sequences, the first element in the output module list performs
  39. the fused operation. The rest of the elements are set to nn.Identity()
  40. """
  41. types = tuple(type_before_parametrizations(m) for m in mod_list)
  42. fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
  43. if fuser_method is None:
  44. raise NotImplementedError("Cannot fuse modules: {}".format(types))
  45. new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
  46. fused = fuser_method(is_qat, *mod_list)
  47. # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
  48. # Move pre forward hooks of the base module to resulting fused module
  49. for handle_id, pre_hook_fn in mod_list[0]._forward_pre_hooks.items():
  50. fused.register_forward_pre_hook(pre_hook_fn)
  51. del mod_list[0]._forward_pre_hooks[handle_id]
  52. # Move post forward hooks of the last module to resulting fused module
  53. for handle_id, hook_fn in mod_list[-1]._forward_hooks.items():
  54. fused.register_forward_hook(hook_fn)
  55. del mod_list[-1]._forward_hooks[handle_id]
  56. new_mod[0] = fused
  57. for i in range(1, len(mod_list)):
  58. identity = nn.Identity()
  59. identity.training = mod_list[0].training
  60. new_mod[i] = identity
  61. return new_mod
  62. def _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
  63. if fuse_custom_config_dict is None:
  64. fuse_custom_config_dict = {}
  65. additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
  66. mod_list = []
  67. for item in modules_to_fuse:
  68. mod_list.append(_get_module(model, item))
  69. # Fuse list of modules
  70. new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping)
  71. # Replace original module list with fused module list
  72. for i, item in enumerate(modules_to_fuse):
  73. _set_module(model, item, new_mod_list[i])
  74. def _fuse_modules(model, modules_to_fuse, is_qat, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
  75. if not inplace:
  76. model = copy.deepcopy(model)
  77. if all(isinstance(module_element, str) for module_element in modules_to_fuse):
  78. # Handle case of modules_to_fuse being a list
  79. _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func, fuse_custom_config_dict)
  80. else:
  81. # Handle case of modules_to_fuse being a list of lists
  82. for module_list in modules_to_fuse:
  83. _fuse_modules_helper(model, module_list, is_qat, fuser_func, fuse_custom_config_dict)
  84. return model
  85. def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
  86. r"""Fuses a list of modules into a single module
  87. Fuses only the following sequence of modules:
  88. conv, bn
  89. conv, bn, relu
  90. conv, relu
  91. linear, relu
  92. bn, relu
  93. All other sequences are left unchanged.
  94. For these sequences, replaces the first item in the list
  95. with the fused module, replacing the rest of the modules
  96. with identity.
  97. Args:
  98. model: Model containing the modules to be fused
  99. modules_to_fuse: list of list of module names to fuse. Can also be a list
  100. of strings if there is only a single list of modules to fuse.
  101. inplace: bool specifying if fusion happens in place on the model, by default
  102. a new model is returned
  103. fuser_func: Function that takes in a list of modules and outputs a list of fused modules
  104. of the same length. For example,
  105. fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
  106. Defaults to torch.ao.quantization.fuse_known_modules
  107. `fuse_custom_config_dict`: custom configuration for fusion
  108. .. code-block:: python
  109. # Example of fuse_custom_config_dict
  110. fuse_custom_config_dict = {
  111. # Additional fuser_method mapping
  112. "additional_fuser_method_mapping": {
  113. (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
  114. },
  115. }
  116. Returns:
  117. model with fused modules. A new copy is created if inplace=True.
  118. Examples::
  119. >>> # xdoctest: +SKIP
  120. >>> m = M().eval()
  121. >>> # m is a module containing the sub-modules below
  122. >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
  123. >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
  124. >>> output = fused_m(input)
  125. >>> m = M().eval()
  126. >>> # Alternately provide a single list of modules to fuse
  127. >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
  128. >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
  129. >>> output = fused_m(input)
  130. """
  131. return _fuse_modules(
  132. model,
  133. modules_to_fuse,
  134. is_qat=False,
  135. inplace=inplace,
  136. fuser_func=fuser_func,
  137. fuse_custom_config_dict=None)
  138. def fuse_modules_qat(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
  139. """ QAT version for `fuse_modules`
  140. """
  141. return _fuse_modules(
  142. model,
  143. modules_to_fuse,
  144. is_qat=True,
  145. inplace=inplace,
  146. fuser_func=fuser_func,
  147. fuse_custom_config_dict=None)