__init__.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # flake8: noqa: F403
  2. from .fake_quantize import * # noqa: F403
  3. from .fuse_modules import fuse_modules # noqa: F403
  4. from .fuse_modules import fuse_modules_qat # noqa: F403
  5. from .fuser_method_mappings import * # noqa: F403
  6. from .observer import * # noqa: F403
  7. from .qconfig import * # noqa: F403
  8. from .qconfig_mapping import * # noqa: F403
  9. from .quant_type import * # noqa: F403
  10. from .quantization_mappings import * # type: ignore[no-redef]
  11. from .quantize import * # noqa: F403
  12. from .quantize_jit import * # noqa: F403
  13. from .stubs import * # noqa: F403
  14. __all__ = [
  15. "DeQuantStub",
  16. "FakeQuantize",
  17. "FakeQuantizeBase",
  18. "FixedQParamsFakeQuantize",
  19. "FixedQParamsObserver",
  20. "FusedMovingAvgObsFakeQuantize",
  21. "HistogramObserver",
  22. "MatchAllNode",
  23. "MinMaxObserver",
  24. "MovingAverageMinMaxObserver",
  25. "MovingAveragePerChannelMinMaxObserver",
  26. "NoopObserver",
  27. "ObserverBase",
  28. "Pattern",
  29. "PerChannelMinMaxObserver",
  30. "PlaceholderObserver",
  31. "QConfig",
  32. "QConfigAny",
  33. "QConfigDynamic",
  34. "QConfigMapping",
  35. "QuantStub",
  36. "QuantType",
  37. "QuantWrapper",
  38. "RecordingObserver",
  39. "ReuseInputObserver",
  40. "UniformQuantizationObserverBase",
  41. "add_quant_dequant",
  42. "convert",
  43. "convert_dynamic_jit",
  44. "convert_jit",
  45. "default_affine_fixed_qparams_fake_quant",
  46. "default_affine_fixed_qparams_observer",
  47. "default_debug_observer",
  48. "default_dynamic_fake_quant",
  49. "default_dynamic_quant_observer",
  50. "default_embedding_fake_quant",
  51. "default_embedding_fake_quant_4bit",
  52. "default_eval_fn",
  53. "default_fake_quant",
  54. "default_fixed_qparams_range_0to1_fake_quant",
  55. "default_fixed_qparams_range_0to1_observer",
  56. "default_fixed_qparams_range_neg1to1_fake_quant",
  57. "default_fixed_qparams_range_neg1to1_observer",
  58. "default_float_qparams_observer",
  59. "default_float_qparams_observer_4bit",
  60. "default_fused_act_fake_quant",
  61. "default_fused_per_channel_wt_fake_quant",
  62. "default_fused_wt_fake_quant",
  63. "default_histogram_fake_quant",
  64. "default_histogram_observer",
  65. "default_observer",
  66. "default_per_channel_weight_fake_quant",
  67. "default_per_channel_weight_observer",
  68. "default_placeholder_observer",
  69. "default_reuse_input_observer",
  70. "default_symmetric_fixed_qparams_fake_quant",
  71. "default_symmetric_fixed_qparams_observer",
  72. "default_weight_fake_quant",
  73. "default_weight_observer",
  74. "disable_fake_quant",
  75. "disable_observer",
  76. "enable_fake_quant",
  77. "enable_observer",
  78. "fuse_conv_bn",
  79. "fuse_conv_bn_jit",
  80. "fuse_conv_bn_relu",
  81. "fuse_convtranspose_bn",
  82. "fuse_linear_bn",
  83. "fuse_modules",
  84. "fuse_modules_qat",
  85. "fused_per_channel_wt_fake_quant_range_neg_127_to_127",
  86. "fused_wt_fake_quant_range_neg_127_to_127",
  87. "get_combined_dict",
  88. "get_default_compare_output_module_list",
  89. "get_default_custom_config_dict",
  90. "get_default_dynamic_quant_module_mappings",
  91. "get_default_dynamic_sparse_quant_module_mappings",
  92. "get_default_float_to_quantized_operator_mappings",
  93. "get_default_qat_module_mappings",
  94. "get_default_qat_qconfig",
  95. "get_default_qat_qconfig_dict",
  96. "get_default_qat_qconfig_mapping",
  97. "get_default_qconfig",
  98. "get_default_qconfig_dict",
  99. "get_default_qconfig_mapping",
  100. "get_default_qconfig_propagation_list",
  101. "get_default_static_quant_module_mappings",
  102. "get_default_static_quant_reference_module_mappings",
  103. "get_default_static_sparse_quant_module_mappings",
  104. "get_dynamic_quant_module_class",
  105. "get_embedding_qat_module_mappings",
  106. "get_embedding_static_quant_module_mappings",
  107. "get_fuser_method",
  108. "get_fuser_method_new",
  109. "get_observer_state_dict",
  110. "get_quantized_operator",
  111. "get_static_quant_module_class",
  112. "load_observer_state_dict",
  113. "no_observer_set",
  114. "per_channel_weight_observer_range_neg_127_to_127",
  115. "prepare",
  116. "prepare_dynamic_jit",
  117. "prepare_jit",
  118. "prepare_qat",
  119. "propagate_qconfig_",
  120. "qconfig_equals",
  121. "quantize",
  122. "quantize_dynamic",
  123. "quantize_dynamic_jit",
  124. "quantize_jit",
  125. "quantize_qat",
  126. "script_qconfig",
  127. "script_qconfig_dict",
  128. "swap_module",
  129. "weight_observer_range_neg_127_to_127",
  130. ]
  131. def default_eval_fn(model, calib_data):
  132. r"""
  133. Default evaluation function takes a torch.utils.data.Dataset or a list of
  134. input Tensors and run the model on the dataset
  135. """
  136. for data, target in calib_data:
  137. model(data)