# _quantize_pt2e.py

from torch.fx import GraphModule

from .qconfig_mapping import QConfigMapping
from .backend_config import BackendConfig
from .fx import prepare
from .quantize_fx import _convert_to_reference_decomposed_fx
from ._pt2e.utils import (
    _get_renamed_nn_module_stack,
    _fuse_conv_bn_,
    _rearrange_weight_observer_for_addmm,
)

from typing import Tuple, Any, Dict
def prepare_pt2e(
    model: GraphModule,
    qconfig_mapping: QConfigMapping,
    example_inputs: Tuple[Any, ...],
    backend_config: BackendConfig,
):
    """Prepare an exported ``GraphModule`` for post-training quantization:
    fuse conv-bn patterns, then insert observers according to
    ``qconfig_mapping`` and ``backend_config`` via the FX ``prepare`` flow.
    """
    # Record the innermost module scope for every node, keyed by node name.
    # TODO: move this information to fx node itself
    node_name_to_scope: Dict[str, Tuple[str, type]] = {}
    for n in model.graph.nodes:
        renamed_stack = _get_renamed_nn_module_stack(n.meta.get("nn_module_stack", None))
        current_scope = list(renamed_stack.items())[-1]
        node_name_to_scope[n.name] = current_scope

    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    model = prepare(
        model,
        qconfig_mapping,
        False,  # is_qat
        node_name_to_scope,
        example_inputs,
        backend_config=backend_config,
    )

    # TODO: remove hack when we have better support for pattern matching
    # move around the observer for addmm
    _rearrange_weight_observer_for_addmm(model)
    return model
def convert_pt2e(
    model: GraphModule
):
    """Convert a prepared (calibrated) model into a reference quantized model
    with decomposed quantize/dequantize ops.
    """
    return _convert_to_reference_decomposed_fx(model)
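

# Example usage (a minimal sketch, not part of this module): the snippet below
# assumes the model has already been captured as an FX ``GraphModule`` (e.g. via
# ``torch._dynamo.export``), and that a ``QConfigMapping`` and ``BackendConfig``
# appropriate for the target backend are available. ``eager_model`` and
# ``calibration_data`` are hypothetical names used only for illustration.
#
#     import torch
#     from torch.ao.quantization import QConfigMapping, get_default_qconfig
#     from torch.ao.quantization.backend_config import get_native_backend_config
#
#     example_inputs = (torch.randn(1, 3, 224, 224),)
#     exported_model, _ = torch._dynamo.export(eager_model, *example_inputs)
#
#     qconfig_mapping = QConfigMapping().set_global(get_default_qconfig("qnnpack"))
#     prepared = prepare_pt2e(
#         exported_model, qconfig_mapping, example_inputs, get_native_backend_config()
#     )
#     for data in calibration_data:  # run calibration to collect observer statistics
#         prepared(data)
#     quantized = convert_pt2e(prepared)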