__init__.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. from .quantize import * # noqa: F403
  2. from .observer import * # noqa: F403
  3. from .qconfig import * # noqa: F403
  4. from .fake_quantize import * # noqa: F403
  5. from .fuse_modules import fuse_modules
  6. from .stubs import * # noqa: F403
  7. from .quant_type import * # noqa: F403
  8. from .quantize_jit import * # noqa: F403
  9. # from .quantize_fx import *
  10. from .quantization_mappings import * # noqa: F403
  11. from .fuser_method_mappings import * # noqa: F403
  12. def default_eval_fn(model, calib_data):
  13. r"""
  14. Default evaluation function takes a torch.utils.data.Dataset or a list of
  15. input Tensors and run the model on the dataset
  16. """
  17. for data, target in calib_data:
  18. model(data)
  19. __all__ = [
  20. 'QuantWrapper', 'QuantStub', 'DeQuantStub',
  21. # Top level API for eager mode quantization
  22. 'quantize', 'quantize_dynamic', 'quantize_qat',
  23. 'prepare', 'convert', 'prepare_qat',
  24. # Top level API for graph mode quantization on TorchScript
  25. 'quantize_jit', 'quantize_dynamic_jit', '_prepare_ondevice_dynamic_jit',
  26. '_convert_ondevice_dynamic_jit', '_quantize_ondevice_dynamic_jit',
  27. # Top level API for graph mode quantization on GraphModule(torch.fx)
  28. # 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx
  29. # 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',
  30. 'QuantType', # quantization type
  31. # custom module APIs
  32. 'get_default_static_quant_module_mappings', 'get_static_quant_module_class',
  33. 'get_default_dynamic_quant_module_mappings',
  34. 'get_default_qat_module_mappings',
  35. 'get_default_qconfig_propagation_list',
  36. 'get_default_compare_output_module_list',
  37. 'get_quantized_operator',
  38. 'get_fuser_method',
  39. # Sub functions for `prepare` and `swap_module`
  40. 'propagate_qconfig_', 'add_quant_dequant', 'swap_module',
  41. 'default_eval_fn',
  42. # Observers
  43. 'ObserverBase', 'WeightObserver', 'HistogramObserver',
  44. 'observer', 'default_observer',
  45. 'default_weight_observer', 'default_placeholder_observer',
  46. 'default_per_channel_weight_observer',
  47. # FakeQuantize (for qat)
  48. 'default_fake_quant', 'default_weight_fake_quant',
  49. 'default_fixed_qparams_range_neg1to1_fake_quant',
  50. 'default_fixed_qparams_range_0to1_fake_quant',
  51. 'default_per_channel_weight_fake_quant',
  52. 'default_histogram_fake_quant',
  53. # QConfig
  54. 'QConfig', 'default_qconfig', 'default_dynamic_qconfig', 'float16_dynamic_qconfig',
  55. 'float_qparams_weight_only_qconfig',
  56. # QAT utilities
  57. 'default_qat_qconfig', 'prepare_qat', 'quantize_qat',
  58. # module transformations
  59. 'fuse_modules',
  60. ]