graph_module.py

import copy
from typing import Any, Dict, Set, Union

import torch
from torch.fx import GraphModule
from torch.fx.graph import Graph

__all__ = [
    "FusedGraphModule",
    "ObservedGraphModule",
    "ObservedStandaloneGraphModule",
    "QuantizedGraphModule",
]

class FusedGraphModule(GraphModule):
    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
        self.preserved_attr_names = preserved_attr_names
        preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)}
        super().__init__(root, graph)
        for attr in preserved_attrs:
            setattr(self, attr, preserved_attrs[attr])

    # GraphModule does not copy attributes which are not in the __dict__
    # of vanilla nn.Module. So, we override __deepcopy__ in order
    # to copy the quantization specific attributes correctly.
    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return FusedGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))

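
# Illustrative sketch, not part of the original module: without the
# __deepcopy__ override above, copy.deepcopy would rebuild the module through
# GraphModule's default path and silently drop the attributes injected via
# setattr in __init__. The helper name below is hypothetical.
def _example_deepcopy_preserves_attrs(gm: FusedGraphModule) -> FusedGraphModule:
    copied = copy.deepcopy(gm)
    # Every preserved attribute present on the source survives the copy.
    assert all(hasattr(copied, attr) for attr in gm.preserved_attr_names if hasattr(gm, attr))
    return copied
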
class ObservedGraphModule(GraphModule):
    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
        self.preserved_attr_names = {
            '_activation_post_process_map',
            '_activation_post_process_indexes',
            '_patterns',
            '_node_name_to_qconfig',
            '_prepare_custom_config',
            '_equalization_node_name_to_qconfig',
            '_node_name_to_scope',
            '_qconfig_mapping',
            '_is_qat',
            '_observed_node_names',
        }.union(preserved_attr_names)
        preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)}
        super().__init__(root, graph)
        for attr in preserved_attrs:
            setattr(self, attr, preserved_attrs[attr])

    # GraphModule does not copy attributes which are not in the __dict__
    # of vanilla nn.Module. So, we override __deepcopy__ in order
    # to copy the quantization specific attributes correctly.
    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return ObservedGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))

def _is_observed_module(module: Any) -> bool:
    return hasattr(module, "meta") and "_observed_graph_module_attrs" in module.meta


def _get_observed_graph_module_attr(model: Union[torch.nn.Module, GraphModule], attr_name: str) -> Any:
    if hasattr(model, "meta") and "_observed_graph_module_attrs" in model.meta:  # type: ignore[operator, index]
        return getattr(model.meta["_observed_graph_module_attrs"], attr_name)  # type: ignore[index]
    return None

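
# Illustrative sketch, not part of the original module: the two helpers above
# query observation metadata stashed on model.meta["_observed_graph_module_attrs"]
# rather than direct attributes on the module. The field name
# "node_name_to_scope" is an assumption about that metadata object; an
# unobserved model simply yields None.
def _example_lookup_scope_map(model: torch.nn.Module) -> Any:
    if _is_observed_module(model):
        return _get_observed_graph_module_attr(model, "node_name_to_scope")
    return None
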
class ObservedStandaloneGraphModule(ObservedGraphModule):
    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
        preserved_attr_names = preserved_attr_names.union({
            "_standalone_module_input_quantized_idxs",
            "_standalone_module_output_quantized_idxs"})
        super().__init__(root, graph, preserved_attr_names)

    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return ObservedStandaloneGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))

def _is_observed_standalone_module(module: Any) -> bool:
    return _is_observed_module(module) and module.meta["_observed_graph_module_attrs"].is_observed_standalone_module

def _save_packed_weight(self, destination, prefix, keep_vars):
    for attr_name in dir(self):
        if "_packed_weight" in attr_name and \
                isinstance(getattr(self, attr_name), torch._C.ScriptObject):  # type: ignore[attr-defined]
            packed_weight = getattr(self, attr_name)
            destination[prefix + attr_name] = packed_weight

class QuantizedGraphModule(GraphModule):
    """This class makes sure PackedParams
    (e.g. LinearPackedParams, Conv2dPackedParams) appear in the state_dict,
    so that we can serialize and deserialize the quantized graph module with
    torch.save(m.state_dict()) and m.load_state_dict(state_dict).
    """
    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
        self.preserved_attr_names = preserved_attr_names
        preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)}
        super().__init__(root, graph)
        for attr in preserved_attrs:
            setattr(self, attr, preserved_attrs[attr])
        self._register_state_dict_hook(_save_packed_weight)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        attrs_to_pop = []
        for attr_name in state_dict:
            if attr_name.startswith("_packed_weight") and isinstance(state_dict[attr_name], torch._C.ScriptObject):  # type: ignore[attr-defined] # noqa: B950
                setattr(self, attr_name, state_dict[attr_name])
                attrs_to_pop.append(attr_name)

        # pop the packed param attributes
        for attr_name in attrs_to_pop:
            state_dict.pop(attr_name)

        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return QuantizedGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))
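

# Illustrative sketch, not part of the original module: the state_dict hook
# registered in __init__ and the _load_from_state_dict override above let
# packed params (ScriptObjects) round-trip through a plain save/load of the
# state_dict. The helper name and the path "quantized.pt" are placeholders.
def _example_state_dict_round_trip(src: QuantizedGraphModule, dst: QuantizedGraphModule) -> None:
    torch.save(src.state_dict(), "quantized.pt")
    # On newer PyTorch releases, torch.load may need weights_only=False to
    # deserialize the packed-param ScriptObjects.
    state_dict = torch.load("quantized.pt")
    dst.load_state_dict(state_dict)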