pattern_utils.py

from collections import OrderedDict
from typing import Dict, Any
from torch.ao.quantization.utils import Pattern
from ..fake_quantize import FixedQParamsFakeQuantize
from ..observer import ObserverBase
import copy

__all__ = [
    "get_default_fusion_patterns",
    "get_default_quant_patterns",
    "get_default_output_activation_post_process_map",
]

# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency)
QuantizeHandler = Any

# pattern for conv bn fusion
_DEFAULT_FUSION_PATTERNS = OrderedDict()

def _register_fusion_pattern(pattern):
    def insert(fn):
        _DEFAULT_FUSION_PATTERNS[pattern] = fn
        return fn
    return insert

def get_default_fusion_patterns() -> Dict[Pattern, QuantizeHandler]:
    return copy.copy(_DEFAULT_FUSION_PATTERNS)
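
# Illustrative sketch (not part of the library): how a fusion handler could be
# registered and then looked up through get_default_fusion_patterns().
# `MyConvBNFusionHandler` is a hypothetical class used only for illustration.
#
# @_register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
# class MyConvBNFusionHandler:
#     def __init__(self, node):
#         self.node = node
#
# handler_cls = get_default_fusion_patterns()[(torch.nn.BatchNorm2d, torch.nn.Conv2d)]
# assert handler_cls is MyConvBNFusionHandler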
_DEFAULT_QUANTIZATION_PATTERNS = OrderedDict()

# Mapping from pattern to activation_post_process (observer/fake_quant) constructor for output activation
# e.g. pattern: torch.sigmoid,
#      output_activation_post_process: default_fixed_qparams_range_0to1_fake_quant
_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP = {}
_DEFAULT_OUTPUT_OBSERVER_MAP = {}

# Register pattern for both static quantization and qat
def _register_quant_pattern(pattern, fixed_qparams_observer=None):
    def insert(fn):
        _DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
        if fixed_qparams_observer is not None:
            _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP[pattern] = FixedQParamsFakeQuantize.with_args(observer=fixed_qparams_observer)
            _DEFAULT_OUTPUT_OBSERVER_MAP[pattern] = fixed_qparams_observer
        return fn
    return insert

# Get patterns for both static quantization and qat
def get_default_quant_patterns() -> Dict[Pattern, QuantizeHandler]:
    return copy.copy(_DEFAULT_QUANTIZATION_PATTERNS)
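
# Illustrative sketch (not part of the library): registering a pattern together
# with a fixed-qparams observer. `SigmoidHandler` is a hypothetical handler class,
# and the observer import is an assumption made only for this sketch.
#
# from ..observer import default_fixed_qparams_range_0to1_observer
#
# @_register_quant_pattern(torch.sigmoid, default_fixed_qparams_range_0to1_observer)
# class SigmoidHandler:
#     ...
#
# get_default_quant_patterns()[torch.sigmoid] would then be SigmoidHandler, and
# torch.sigmoid would also gain entries in both output activation maps above.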
# a map from pattern to output activation post process constructor
# e.g. torch.sigmoid -> default_affine_fixed_qparam_fake_quant
def get_default_output_activation_post_process_map(is_training) -> Dict[Pattern, ObserverBase]:
    if is_training:
        return copy.copy(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP)
    else:
        return copy.copy(_DEFAULT_OUTPUT_OBSERVER_MAP)
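
# Illustrative sketch (not part of the library): continuing the hypothetical
# torch.sigmoid registration above, the map returned for training would hold the
# FixedQParamsFakeQuantize constructor, while the eval map would hold the raw observer.
#
# train_map = get_default_output_activation_post_process_map(is_training=True)
# fq_ctr = train_map.get(torch.sigmoid)  # FixedQParamsFakeQuantize.with_args(observer=...)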

# Example use of register pattern function:
# @_register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# class ConvOrLinearBNReLUFusion():
#     def __init__(...):
#         ...
#
def _sorted_patterns_dict(patterns_dict: Dict[Pattern, QuantizeHandler]) -> Dict[Pattern, QuantizeHandler]:
    """
    Return a sorted version of the patterns dictionary such that longer patterns are matched first,
    e.g. match (F.relu, F.linear) before F.relu.
    This works for current use cases, but we may need a smarter way to sort
    things to address more complex patterns.
    """

    def get_len(pattern):
        """Calculate the length of the pattern by counting all the entries in it.
        This makes sure (nn.ReLU, (nn.BatchNorm, nn.Conv2d)) comes before
        (nn.BatchNorm, nn.Conv2d) so that we can match the former first.
        """
        pattern_len = 0
        if isinstance(pattern, tuple):
            for item in pattern:
                pattern_len += get_len(item)
        else:
            pattern_len += 1
        return pattern_len

    return OrderedDict(sorted(patterns_dict.items(), key=lambda kv: -get_len(kv[0]) if isinstance(kv[0], tuple) else 1))
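
# Illustrative sketch (not part of the library): (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))
# has length 3, (nn.BatchNorm2d, nn.Conv2d) has length 2, and a bare callable such as
# torch.nn.ReLU sorts last with key 1, so longer compound patterns are tried first.
# `handler_a`/`handler_b`/`handler_c` are hypothetical handler classes.
#
# unsorted = OrderedDict([
#     (torch.nn.ReLU, handler_a),
#     ((torch.nn.BatchNorm2d, torch.nn.Conv2d), handler_b),
#     ((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)), handler_c),
# ])
# first_key = next(iter(_sorted_patterns_dict(unsorted)))
# assert first_key == (torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))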