# custom_config.py
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from typing import Any, Dict, List, Optional, Tuple, Type
  4. from torch.ao.quantization import QConfigMapping
  5. from torch.ao.quantization.backend_config import BackendConfig
  6. from torch.ao.quantization.quant_type import QuantType, _quant_type_from_str, _get_quant_type_to_str
  7. __all__ = [
  8. "ConvertCustomConfig",
  9. "FuseCustomConfig",
  10. "PrepareCustomConfig",
  11. "StandaloneModuleConfigEntry",
  12. ]
  13. # TODO: replace all usages with these constants
  14. STANDALONE_MODULE_NAME_DICT_KEY = "standalone_module_name"
  15. STANDALONE_MODULE_CLASS_DICT_KEY = "standalone_module_class"
  16. FLOAT_TO_OBSERVED_DICT_KEY = "float_to_observed_custom_module_class"
  17. OBSERVED_TO_QUANTIZED_DICT_KEY = "observed_to_quantized_custom_module_class"
  18. NON_TRACEABLE_MODULE_NAME_DICT_KEY = "non_traceable_module_name"
  19. NON_TRACEABLE_MODULE_CLASS_DICT_KEY = "non_traceable_module_class"
  20. INPUT_QUANTIZED_INDEXES_DICT_KEY = "input_quantized_idxs"
  21. OUTPUT_QUANTIZED_INDEXES_DICT_KEY = "output_quantized_idxs"
  22. PRESERVED_ATTRIBUTES_DICT_KEY = "preserved_attributes"
  23. @dataclass
  24. class StandaloneModuleConfigEntry:
  25. # qconfig_mapping for the prepare function called in the submodule,
  26. # None means use qconfig from parent qconfig_mapping
  27. qconfig_mapping: Optional[QConfigMapping]
  28. example_inputs: Tuple[Any, ...]
  29. prepare_custom_config: Optional[PrepareCustomConfig]
  30. backend_config: Optional[BackendConfig]
  31. class PrepareCustomConfig:
  32. """
  33. Custom configuration for :func:`~torch.ao.quantization.quantize_fx.prepare_fx` and
  34. :func:`~torch.ao.quantization.quantize_fx.prepare_qat_fx`.
  35. Example usage::
  36. prepare_custom_config = PrepareCustomConfig() \
  37. .set_standalone_module_name("module1", qconfig_mapping, example_inputs, \
  38. child_prepare_custom_config, backend_config) \
  39. .set_standalone_module_class(MyStandaloneModule, qconfig_mapping, example_inputs, \
  40. child_prepare_custom_config, backend_config) \
  41. .set_float_to_observed_mapping(FloatCustomModule, ObservedCustomModule) \
  42. .set_non_traceable_module_names(["module2", "module3"]) \
  43. .set_non_traceable_module_classes([NonTraceableModule1, NonTraceableModule2]) \
  44. .set_input_quantized_indexes([0]) \
  45. .set_output_quantized_indexes([0]) \
  46. .set_preserved_attributes(["attr1", "attr2"])
  47. """
  48. def __init__(self):
  49. self.standalone_module_names: Dict[str, StandaloneModuleConfigEntry] = {}
  50. self.standalone_module_classes: Dict[Type, StandaloneModuleConfigEntry] = {}
  51. self.float_to_observed_mapping: Dict[QuantType, Dict[Type, Type]] = {}
  52. self.non_traceable_module_names: List[str] = []
  53. self.non_traceable_module_classes: List[Type] = []
  54. self.input_quantized_indexes: List[int] = []
  55. self.output_quantized_indexes: List[int] = []
  56. self.preserved_attributes: List[str] = []
  57. def __repr__(self):
  58. dict_nonempty = {
  59. k: v for k, v in self.__dict__.items()
  60. if len(v) > 0
  61. }
  62. return f"PrepareCustomConfig({dict_nonempty})"
  63. def set_standalone_module_name(
  64. self,
  65. module_name: str,
  66. qconfig_mapping: Optional[QConfigMapping],
  67. example_inputs: Tuple[Any, ...],
  68. prepare_custom_config: Optional[PrepareCustomConfig],
  69. backend_config: Optional[BackendConfig]) -> PrepareCustomConfig:
  70. """
  71. Set the configuration for running a standalone module identified by ``module_name``.
  72. If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead.
  73. If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used.
  74. If ``backend_config`` is None, the parent ``backend_config`` will be used instead.
  75. """
  76. self.standalone_module_names[module_name] = \
  77. StandaloneModuleConfigEntry(qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
  78. return self
  79. def set_standalone_module_class(
  80. self,
  81. module_class: Type,
  82. qconfig_mapping: Optional[QConfigMapping],
  83. example_inputs: Tuple[Any, ...],
  84. prepare_custom_config: Optional[PrepareCustomConfig],
  85. backend_config: Optional[BackendConfig]) -> PrepareCustomConfig:
  86. """
  87. Set the configuration for running a standalone module identified by ``module_class``.
  88. If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead.
  89. If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used.
  90. If ``backend_config`` is None, the parent ``backend_config`` will be used instead.
  91. """
  92. self.standalone_module_classes[module_class] = \
  93. StandaloneModuleConfigEntry(qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
  94. return self
  95. def set_float_to_observed_mapping(
  96. self,
  97. float_class: Type,
  98. observed_class: Type,
  99. quant_type: QuantType = QuantType.STATIC) -> PrepareCustomConfig:
  100. """
  101. Set the mapping from a custom float module class to a custom observed module class.
  102. The observed module class must have a ``from_float`` class method that converts the float module class
  103. to the observed module class. This is currently only supported for static quantization.
  104. """
  105. if quant_type != QuantType.STATIC:
  106. raise ValueError("set_float_to_observed_mapping is currently only supported for static quantization")
  107. if quant_type not in self.float_to_observed_mapping:
  108. self.float_to_observed_mapping[quant_type] = {}
  109. self.float_to_observed_mapping[quant_type][float_class] = observed_class
  110. return self
  111. def set_non_traceable_module_names(self, module_names: List[str]) -> PrepareCustomConfig:
  112. """
  113. Set the modules that are not symbolically traceable, identified by name.
  114. """
  115. self.non_traceable_module_names = module_names
  116. return self
  117. def set_non_traceable_module_classes(self, module_classes: List[Type]) -> PrepareCustomConfig:
  118. """
  119. Set the modules that are not symbolically traceable, identified by class.
  120. """
  121. self.non_traceable_module_classes = module_classes
  122. return self
  123. def set_input_quantized_indexes(self, indexes: List[int]) -> PrepareCustomConfig:
  124. """
  125. Set the indexes of the inputs of the graph that should be quantized.
  126. Inputs are otherwise assumed to be in fp32 by default instead.
  127. """
  128. self.input_quantized_indexes = indexes
  129. return self
  130. def set_output_quantized_indexes(self, indexes: List[int]) -> PrepareCustomConfig:
  131. """
  132. Set the indexes of the outputs of the graph that should be quantized.
  133. Outputs are otherwise assumed to be in fp32 by default instead.
  134. """
  135. self.output_quantized_indexes = indexes
  136. return self
  137. def set_preserved_attributes(self, attributes: List[str]) -> PrepareCustomConfig:
  138. """
  139. Set the names of the attributes that will persist in the graph module even if they are not used in
  140. the model's ``forward`` method.
  141. """
  142. self.preserved_attributes = attributes
  143. return self
  144. # TODO: remove this
  145. @classmethod
  146. def from_dict(cls, prepare_custom_config_dict: Dict[str, Any]) -> PrepareCustomConfig:
  147. """
  148. Create a ``PrepareCustomConfig`` from a dictionary with the following items:
  149. "standalone_module_name": a list of (module_name, qconfig_mapping, example_inputs,
  150. child_prepare_custom_config, backend_config) tuples
  151. "standalone_module_class" a list of (module_class, qconfig_mapping, example_inputs,
  152. child_prepare_custom_config, backend_config) tuples
  153. "float_to_observed_custom_module_class": a nested dictionary mapping from quantization
  154. mode to an inner mapping from float module classes to observed module classes, e.g.
  155. {"static": {FloatCustomModule: ObservedCustomModule}}
  156. "non_traceable_module_name": a list of modules names that are not symbolically traceable
  157. "non_traceable_module_class": a list of module classes that are not symbolically traceable
  158. "input_quantized_idxs": a list of indexes of graph inputs that should be quantized
  159. "output_quantized_idxs": a list of indexes of graph outputs that should be quantized
  160. "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``
  161. This function is primarily for backward compatibility and may be removed in the future.
  162. """
  163. def _get_qconfig_mapping(obj: Any, dict_key: str) -> Optional[QConfigMapping]:
  164. """
  165. Convert the given object into a QConfigMapping if possible, else throw an exception.
  166. """
  167. if isinstance(obj, QConfigMapping) or obj is None:
  168. return obj
  169. if isinstance(obj, Dict):
  170. return QConfigMapping.from_dict(obj)
  171. raise ValueError("Expected QConfigMapping in prepare_custom_config_dict[\"%s\"], got '%s'" %
  172. (dict_key, type(obj)))
  173. def _get_prepare_custom_config(obj: Any, dict_key: str) -> Optional[PrepareCustomConfig]:
  174. """
  175. Convert the given object into a PrepareCustomConfig if possible, else throw an exception.
  176. """
  177. if isinstance(obj, PrepareCustomConfig) or obj is None:
  178. return obj
  179. if isinstance(obj, Dict):
  180. return PrepareCustomConfig.from_dict(obj)
  181. raise ValueError("Expected PrepareCustomConfig in prepare_custom_config_dict[\"%s\"], got '%s'" %
  182. (dict_key, type(obj)))
  183. def _get_backend_config(obj: Any, dict_key: str) -> Optional[BackendConfig]:
  184. """
  185. Convert the given object into a BackendConfig if possible, else throw an exception.
  186. """
  187. if isinstance(obj, BackendConfig) or obj is None:
  188. return obj
  189. if isinstance(obj, Dict):
  190. return BackendConfig.from_dict(obj)
  191. raise ValueError("Expected BackendConfig in prepare_custom_config_dict[\"%s\"], got '%s'" %
  192. (dict_key, type(obj)))
  193. conf = cls()
  194. for (module_name, qconfig_dict, example_inputs, _prepare_custom_config_dict, backend_config_dict) in\
  195. prepare_custom_config_dict.get(STANDALONE_MODULE_NAME_DICT_KEY, []):
  196. qconfig_mapping = _get_qconfig_mapping(qconfig_dict, STANDALONE_MODULE_NAME_DICT_KEY)
  197. prepare_custom_config = _get_prepare_custom_config(_prepare_custom_config_dict, STANDALONE_MODULE_NAME_DICT_KEY)
  198. backend_config = _get_backend_config(backend_config_dict, STANDALONE_MODULE_NAME_DICT_KEY)
  199. conf.set_standalone_module_name(
  200. module_name, qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
  201. for (module_class, qconfig_dict, example_inputs, _prepare_custom_config_dict, backend_config_dict) in\
  202. prepare_custom_config_dict.get(STANDALONE_MODULE_CLASS_DICT_KEY, []):
  203. qconfig_mapping = _get_qconfig_mapping(qconfig_dict, STANDALONE_MODULE_CLASS_DICT_KEY)
  204. prepare_custom_config = _get_prepare_custom_config(_prepare_custom_config_dict, STANDALONE_MODULE_CLASS_DICT_KEY)
  205. backend_config = _get_backend_config(backend_config_dict, STANDALONE_MODULE_CLASS_DICT_KEY)
  206. conf.set_standalone_module_class(
  207. module_class, qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
  208. for quant_type_name, custom_module_mapping in prepare_custom_config_dict.get(FLOAT_TO_OBSERVED_DICT_KEY, {}).items():
  209. quant_type = _quant_type_from_str(quant_type_name)
  210. for float_class, observed_class in custom_module_mapping.items():
  211. conf.set_float_to_observed_mapping(float_class, observed_class, quant_type)
  212. conf.set_non_traceable_module_names(prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_NAME_DICT_KEY, []))
  213. conf.set_non_traceable_module_classes(prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_CLASS_DICT_KEY, []))
  214. conf.set_input_quantized_indexes(prepare_custom_config_dict.get(INPUT_QUANTIZED_INDEXES_DICT_KEY, []))
  215. conf.set_output_quantized_indexes(prepare_custom_config_dict.get(OUTPUT_QUANTIZED_INDEXES_DICT_KEY, []))
  216. conf.set_preserved_attributes(prepare_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, []))
  217. return conf
  218. def to_dict(self) -> Dict[str, Any]:
  219. """
  220. Convert this ``PrepareCustomConfig`` to a dictionary with the items described in
  221. :func:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig.from_dict`.
  222. """
  223. def _make_tuple(key: Any, e: StandaloneModuleConfigEntry):
  224. qconfig_dict = e.qconfig_mapping.to_dict() if e.qconfig_mapping else None
  225. prepare_custom_config_dict = e.prepare_custom_config.to_dict() if e.prepare_custom_config else None
  226. return (key, qconfig_dict, e.example_inputs, prepare_custom_config_dict, e.backend_config)
  227. d: Dict[str, Any] = {}
  228. for module_name, sm_config_entry in self.standalone_module_names.items():
  229. if STANDALONE_MODULE_NAME_DICT_KEY not in d:
  230. d[STANDALONE_MODULE_NAME_DICT_KEY] = []
  231. d[STANDALONE_MODULE_NAME_DICT_KEY].append(_make_tuple(module_name, sm_config_entry))
  232. for module_class, sm_config_entry in self.standalone_module_classes.items():
  233. if STANDALONE_MODULE_CLASS_DICT_KEY not in d:
  234. d[STANDALONE_MODULE_CLASS_DICT_KEY] = []
  235. d[STANDALONE_MODULE_CLASS_DICT_KEY].append(_make_tuple(module_class, sm_config_entry))
  236. for quant_type, float_to_observed_mapping in self.float_to_observed_mapping.items():
  237. if FLOAT_TO_OBSERVED_DICT_KEY not in d:
  238. d[FLOAT_TO_OBSERVED_DICT_KEY] = {}
  239. d[FLOAT_TO_OBSERVED_DICT_KEY][_get_quant_type_to_str(quant_type)] = float_to_observed_mapping
  240. if len(self.non_traceable_module_names) > 0:
  241. d[NON_TRACEABLE_MODULE_NAME_DICT_KEY] = self.non_traceable_module_names
  242. if len(self.non_traceable_module_classes) > 0:
  243. d[NON_TRACEABLE_MODULE_CLASS_DICT_KEY] = self.non_traceable_module_classes
  244. if len(self.input_quantized_indexes) > 0:
  245. d[INPUT_QUANTIZED_INDEXES_DICT_KEY] = self.input_quantized_indexes
  246. if len(self.output_quantized_indexes) > 0:
  247. d[OUTPUT_QUANTIZED_INDEXES_DICT_KEY] = self.output_quantized_indexes
  248. if len(self.preserved_attributes) > 0:
  249. d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
  250. return d
  251. class ConvertCustomConfig:
  252. """
  253. Custom configuration for :func:`~torch.ao.quantization.quantize_fx.convert_fx`.
  254. Example usage::
  255. convert_custom_config = ConvertCustomConfig() \
  256. .set_observed_to_quantized_mapping(ObservedCustomModule, QuantizedCustomModule) \
  257. .set_preserved_attributes(["attr1", "attr2"])
  258. """
  259. def __init__(self):
  260. self.observed_to_quantized_mapping: Dict[QuantType, Dict[Type, Type]] = {}
  261. self.preserved_attributes: List[str] = []
  262. def __repr__(self):
  263. dict_nonempty = {
  264. k: v for k, v in self.__dict__.items()
  265. if len(v) > 0
  266. }
  267. return f"ConvertCustomConfig({dict_nonempty})"
  268. def set_observed_to_quantized_mapping(
  269. self,
  270. observed_class: Type,
  271. quantized_class: Type,
  272. quant_type: QuantType = QuantType.STATIC) -> ConvertCustomConfig:
  273. """
  274. Set the mapping from a custom observed module class to a custom quantized module class.
  275. The quantized module class must have a ``from_observed`` class method that converts the observed module class
  276. to the quantized module class.
  277. """
  278. if quant_type not in self.observed_to_quantized_mapping:
  279. self.observed_to_quantized_mapping[quant_type] = {}
  280. self.observed_to_quantized_mapping[quant_type][observed_class] = quantized_class
  281. return self
  282. def set_preserved_attributes(self, attributes: List[str]) -> ConvertCustomConfig:
  283. """
  284. Set the names of the attributes that will persist in the graph module even if they are not used in
  285. the model's ``forward`` method.
  286. """
  287. self.preserved_attributes = attributes
  288. return self
  289. # TODO: remove this
  290. @classmethod
  291. def from_dict(cls, convert_custom_config_dict: Dict[str, Any]) -> ConvertCustomConfig:
  292. """
  293. Create a ``ConvertCustomConfig`` from a dictionary with the following items:
  294. "observed_to_quantized_custom_module_class": a nested dictionary mapping from quantization
  295. mode to an inner mapping from observed module classes to quantized module classes, e.g.::
  296. {
  297. "static": {FloatCustomModule: ObservedCustomModule},
  298. "dynamic": {FloatCustomModule: ObservedCustomModule},
  299. "weight_only": {FloatCustomModule: ObservedCustomModule}
  300. }
  301. "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``
  302. This function is primarily for backward compatibility and may be removed in the future.
  303. """
  304. conf = cls()
  305. for quant_type_name, custom_module_mapping in convert_custom_config_dict.get(OBSERVED_TO_QUANTIZED_DICT_KEY, {}).items():
  306. quant_type = _quant_type_from_str(quant_type_name)
  307. for observed_class, quantized_class in custom_module_mapping.items():
  308. conf.set_observed_to_quantized_mapping(observed_class, quantized_class, quant_type)
  309. conf.set_preserved_attributes(convert_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, []))
  310. return conf
  311. def to_dict(self) -> Dict[str, Any]:
  312. """
  313. Convert this ``ConvertCustomConfig`` to a dictionary with the items described in
  314. :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`.
  315. """
  316. d: Dict[str, Any] = {}
  317. for quant_type, observed_to_quantized_mapping in self.observed_to_quantized_mapping.items():
  318. if OBSERVED_TO_QUANTIZED_DICT_KEY not in d:
  319. d[OBSERVED_TO_QUANTIZED_DICT_KEY] = {}
  320. d[OBSERVED_TO_QUANTIZED_DICT_KEY][_get_quant_type_to_str(quant_type)] = observed_to_quantized_mapping
  321. if len(self.preserved_attributes) > 0:
  322. d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
  323. return d
  324. class FuseCustomConfig:
  325. """
  326. Custom configuration for :func:`~torch.ao.quantization.quantize_fx.fuse_fx`.
  327. Example usage::
  328. fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["attr1", "attr2"])
  329. """
  330. def __init__(self):
  331. self.preserved_attributes: List[str] = []
  332. def __repr__(self):
  333. dict_nonempty = {
  334. k: v for k, v in self.__dict__.items()
  335. if len(v) > 0
  336. }
  337. return f"FuseCustomConfig({dict_nonempty})"
  338. def set_preserved_attributes(self, attributes: List[str]) -> FuseCustomConfig:
  339. """
  340. Set the names of the attributes that will persist in the graph module even if they are not used in
  341. the model's ``forward`` method.
  342. """
  343. self.preserved_attributes = attributes
  344. return self
  345. # TODO: remove this
  346. @classmethod
  347. def from_dict(cls, fuse_custom_config_dict: Dict[str, Any]) -> FuseCustomConfig:
  348. """
  349. Create a ``ConvertCustomConfig`` from a dictionary with the following items:
  350. "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``
  351. This function is primarily for backward compatibility and may be removed in the future.
  352. """
  353. conf = cls()
  354. conf.set_preserved_attributes(fuse_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, []))
  355. return conf
  356. def to_dict(self) -> Dict[str, Any]:
  357. """
  358. Convert this ``FuseCustomConfig`` to a dictionary with the items described in
  359. :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`.
  360. """
  361. d: Dict[str, Any] = {}
  362. if len(self.preserved_attributes) > 0:
  363. d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
  364. return d