quantize_fx.py

from typing import Any, Dict, Optional, Tuple, Union
import warnings
import torch
import copy
from torch.fx import GraphModule
from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY
from .fx.tracer import QuantizationTracer
from .fx.tracer import (  # noqa: F401
    Scope,
    ScopeContextManager
)
from .fx import fuse  # noqa: F401
from .fx import prepare  # noqa: F401
from .fx.convert import convert
from .backend_config import (  # noqa: F401
    BackendConfig,
    get_tensorrt_backend_config,
)
from .fx.graph_module import ObservedGraphModule  # noqa: F401
from .fx.custom_config import (
    ConvertCustomConfig,
    FuseCustomConfig,
    PrepareCustomConfig,
)
from .fx.utils import get_custom_module_class_keys  # noqa: F401
from .fx.utils import get_skipped_module_name_and_classes
from .qconfig_mapping import QConfigMapping


def attach_preserved_attrs_to_model(
    model: Union[GraphModule, torch.nn.Module], preserved_attrs: Dict[str, Any],
) -> None:
    """ Store preserved attributes in model.meta so that they can be preserved during deepcopy
    """
    model.meta[_USER_PRESERVED_ATTRIBUTES_KEY] = copy.copy(preserved_attrs)  # type: ignore[operator, index, assignment]
    # set the preserved attributes on the model so that users can call
    # model.attr as they did before calling fx graph mode quantization
    for attr_name, attr in model.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():  # type: ignore[index, union-attr]
        setattr(model, attr_name, attr)
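
# A minimal usage sketch (illustrative only, not part of the API): after attaching, a
# deepcopy of the GraphModule is expected to keep the user-visible attribute, because
# GraphModule's deepcopy re-applies what is stored under _USER_PRESERVED_ATTRIBUTES_KEY
# in model.meta. `float_model` and its `version` attribute are hypothetical.
#
#     gm = torch.fx.symbolic_trace(float_model)
#     attach_preserved_attrs_to_model(gm, {"version": float_model.version})
#     assert copy.deepcopy(gm).version == float_model.version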


def _check_is_graph_module(model: torch.nn.Module) -> None:
    if not isinstance(model, GraphModule):
        raise ValueError(
            "input model must be a GraphModule, "
            f"got type: {type(model)}. "
            "Please make sure to follow the tutorials."
        )


def _attach_meta_to_node_if_not_exist(model: GraphModule):
    """ Attach a meta field to all nodes of the graph if it does not exist.
    The meta field stores meta information about the node, such as the dtype and
    shape of the node's output. This field only exists if the program was captured
    by make_fx (used in the quantize_pt2e flow); if the program was captured by
    torch.fx symbolic tracing, the field may not exist, so we add it here to avoid
    checking for it all over the place.
    """
    for node in model.graph.nodes:
        if not hasattr(node, "meta"):
            node.meta = {}


def _swap_ff_with_fxff(model: torch.nn.Module) -> None:
    r""" Swap FloatFunctional with FXFloatFunctional
    """
    modules_to_swap = []
    for name, module in model.named_children():
        if isinstance(module, torch.ao.nn.quantized.FloatFunctional):
            modules_to_swap.append(name)
        else:
            _swap_ff_with_fxff(module)

    for name in modules_to_swap:
        del model._modules[name]
        model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional()
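
# Illustrative sketch (not part of the API): a module that routes its adds through
# FloatFunctional gets the functional swapped in place with FXFloatFunctional, which
# is friendly to symbolic tracing. The `Add` module below is hypothetical.
#
#     class Add(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.ff = torch.ao.nn.quantized.FloatFunctional()
#         def forward(self, x, y):
#             return self.ff.add(x, y)
#
#     m = Add()
#     _swap_ff_with_fxff(m)
#     assert isinstance(m.ff, torch.ao.nn.quantized.FXFloatFunctional)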


def _fuse_fx(
    model: GraphModule,
    is_qat: bool,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r""" Internal helper function to fuse modules in preparation for quantization

    Args:
        model: GraphModule object from symbolic tracing (torch.fx.symbolic_trace)
    """
    _check_is_graph_module(model)
    return fuse(
        model, is_qat, fuse_custom_config, backend_config)  # type: ignore[operator]


def _prepare_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    is_qat: bool,
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    _equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
    is_standalone_module: bool = False,
) -> GraphModule:
    r""" Internal helper function for prepare_fx

    Args:
      `model`, `qconfig_mapping`, `prepare_custom_config`, `_equalization_config`:
      see docs for :func:`~torch.ao.quantization.prepare_fx`
      `is_standalone_module`: a boolean flag that indicates whether we are
      quantizing a standalone module. A standalone module is a submodule of
      the parent module that is not inlined in the forward graph of the parent
      module; the way we quantize a standalone module is described in:
      :func:`~torch.ao.quantization._prepare_standalone_module_fx`
    """
    if prepare_custom_config is None:
        prepare_custom_config = PrepareCustomConfig()
    if _equalization_config is None:
        _equalization_config = QConfigMapping()

    if isinstance(prepare_custom_config, Dict):
        warnings.warn(
            "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a PrepareCustomConfig instead.")
        prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)

    # swap FloatFunctional with FXFloatFunctional
    _swap_ff_with_fxff(model)

    skipped_module_names, skipped_module_classes = \
        get_skipped_module_name_and_classes(prepare_custom_config, is_standalone_module)
    preserved_attr_names = prepare_custom_config.preserved_attributes
    preserved_attrs = {attr: getattr(model, attr) for attr in preserved_attr_names if hasattr(model, attr)}
    # symbolically trace the model
    tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)  # type: ignore[arg-type]
    graph_module = GraphModule(model, tracer.trace(model))
    _attach_meta_to_node_if_not_exist(graph_module)

    fuse_custom_config = FuseCustomConfig().set_preserved_attributes(prepare_custom_config.preserved_attributes)
    graph_module = _fuse_fx(
        graph_module,
        is_qat,
        fuse_custom_config,
        backend_config)
    prepared = prepare(
        graph_module,
        qconfig_mapping,
        is_qat,
        tracer.node_name_to_scope,
        example_inputs=example_inputs,
        prepare_custom_config=prepare_custom_config,
        _equalization_config=_equalization_config,
        backend_config=backend_config,
        is_standalone_module=is_standalone_module,
    )  # type: ignore[operator]

    attach_preserved_attrs_to_model(prepared, preserved_attrs)
    return prepared


def _prepare_standalone_module_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    is_qat: bool,
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the
    parent module.

    standalone_module means it is a submodule that is not inlined in the parent module,
    and will be quantized separately as one unit.

    How the standalone module is observed is specified by `input_quantized_idxs` and
    `output_quantized_idxs` in the prepare_custom_config for the standalone module

    Returns:

        * model(GraphModule): prepared standalone module. It has these attributes in
          model.meta:

            * `standalone_module_input_quantized_idxs(List[Int])`: a list of
              indexes for the graph inputs that are expected to be quantized,
              same as the input_quantized_idxs configuration provided
              for the standalone module
            * `standalone_module_output_quantized_idxs(List[Int])`: a list of
              indexes for the graph outputs that are quantized,
              same as the output_quantized_idxs configuration provided
              for the standalone module
    """
    return _prepare_fx(
        model,
        qconfig_mapping,
        is_qat,
        example_inputs,
        prepare_custom_config,
        backend_config=backend_config,
        is_standalone_module=True,
    )
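
# Sketch of how a standalone module is typically designated from the parent model's
# prepare_custom_config (illustrative only; the exact argument order of
# PrepareCustomConfig.set_standalone_module_name below is an assumption, see
# torch.ao.quantization.fx.custom_config for the authoritative API; `float_model`,
# `qconfig_mapping`, `example_inputs`, and the submodule name "sub" are hypothetical):
#
#     child_config = PrepareCustomConfig().set_input_quantized_indexes([0])
#     parent_config = PrepareCustomConfig().set_standalone_module_name(
#         "sub",                  # submodule to treat as a standalone unit
#         qconfig_mapping,        # qconfig_mapping for the standalone module
#         (torch.randn(1, 5),),   # example inputs for the standalone module
#         child_config,           # prepare_custom_config for the standalone module
#         None,                   # backend_config for the standalone module
#     )
#     prepared = prepare_fx(float_model, qconfig_mapping, example_inputs,
#                           prepare_custom_config=parent_config)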


def fuse_fx(
    model: torch.nn.Module,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r""" Fuse modules like conv+bn, conv+bn+relu etc., model must be in eval mode.
    Fusion rules are defined in torch.ao.quantization.fx.fusion_pattern.py

    Args:

        * `model` (torch.nn.Module): a torch.nn.Module model
        * `fuse_custom_config` (FuseCustomConfig): custom configurations for fuse_fx.
            See :class:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig` for more details

    Example::

        from torch.ao.quantization import fuse_fx
        m = Model().eval()
        m = fuse_fx(m)

    """
    if fuse_custom_config is None:
        fuse_custom_config = FuseCustomConfig()

    if isinstance(fuse_custom_config, Dict):
        warnings.warn(
            "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
            "in a future version. Please pass in a FuseCustomConfig instead.")
        fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)

    torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
    preserved_attr_names = fuse_custom_config.preserved_attributes
    preserved_attrs = {attr: getattr(model, attr) for attr in preserved_attr_names if hasattr(model, attr)}
    graph_module = torch.fx.symbolic_trace(model)
    _attach_meta_to_node_if_not_exist(graph_module)
    graph_module = _fuse_fx(graph_module, False, fuse_custom_config, backend_config)
    attach_preserved_attrs_to_model(graph_module, preserved_attrs)
    return graph_module
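
# A slightly fuller sketch than the docstring example (illustrative; `MyModel` and its
# `scale` attribute are hypothetical), showing how a user attribute can be carried
# through fusion via FuseCustomConfig:
#
#     from torch.ao.quantization import fuse_fx
#     from torch.ao.quantization.fx.custom_config import FuseCustomConfig
#
#     m = MyModel().eval()
#     fuse_config = FuseCustomConfig().set_preserved_attributes(["scale"])
#     fused = fuse_fx(m, fuse_custom_config=fuse_config)
#     assert fused.scale == m.scale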


def prepare_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    _equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r""" Prepare a model for post training static quantization

    Args:
      * `model` (torch.nn.Module): torch.nn.Module model

      * `qconfig_mapping` (QConfigMapping): QConfigMapping object to configure how a model is
        quantized, see :class:`~torch.ao.quantization.qconfig_mapping.QConfigMapping`
        for more details

      * `example_inputs` (Tuple[Any, ...]): Example inputs for the forward function of the model,
        Tuple of positional args (keyword args can be passed as positional args as well)

      * `prepare_custom_config` (PrepareCustomConfig): customization configuration for the quantization tool.
        See :class:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig` for more details

      * `_equalization_config`: config for specifying how to perform equalization on the model

      * `backend_config` (BackendConfig): config that specifies how operators are quantized
        in a backend, this includes how the operators are observed,
        supported fusion patterns, how quantize/dequantize ops are
        inserted, supported dtypes etc. See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details

    Return:
      A GraphModule with observers (configured by qconfig_mapping), ready for calibration

    Example::

        import torch
        from torch.ao.quantization import get_default_qconfig_mapping
        from torch.ao.quantization import prepare_fx

        class Submodule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                x = self.linear(x)
                return x

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.sub = Submodule()

            def forward(self, x):
                x = self.linear(x)
                x = self.sub(x) + x
                return x

        # initialize a floating point model
        float_model = M().eval()

        # define calibration function
        def calibrate(model, data_loader):
            model.eval()
            with torch.no_grad():
                for image, target in data_loader:
                    model(image)

        # qconfig is the configuration for how we insert observers for a particular
        # operator
        # qconfig = get_default_qconfig("fbgemm")
        # Example of customizing qconfig:
        # qconfig = torch.ao.quantization.QConfig(
        #    activation=MinMaxObserver.with_args(dtype=torch.qint8),
        #    weight=MinMaxObserver.with_args(dtype=torch.qint8))
        # `activation` and `weight` are constructors of observer modules

        # qconfig_mapping is a collection of quantization configurations, user can
        # set the qconfig for each operator (torch op calls, functional calls, module calls)
        # in the model through qconfig_mapping
        # the following call will get the qconfig_mapping that works best for models
        # that target the "fbgemm" backend
        qconfig_mapping = get_default_qconfig_mapping("fbgemm")

        # We can customize qconfig_mapping in different ways.
        # e.g. set the global qconfig, which means we will use the same qconfig for
        # all operators in the model, this can be overwritten by other settings
        # qconfig_mapping = QConfigMapping().set_global(qconfig)
        # e.g. quantize the linear submodule with a specific qconfig
        # qconfig_mapping = QConfigMapping().set_module_name("linear", qconfig)
        # e.g. quantize all nn.Linear modules with a specific qconfig
        # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
        # for a more complete list, please see the docstring for the
        # :class:`torch.ao.quantization.QConfigMapping` argument

        # example_inputs is a tuple of inputs that is used to infer the type of the
        # outputs in the model
        # currently it's not used, but please make sure model(*example_inputs) runs
        example_inputs = (torch.randn(1, 3, 224, 224),)

        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        # `prepare_fx` inserts observers in the model based on qconfig_mapping and
        # backend_config. If the configuration for an operator in qconfig_mapping
        # is supported in the backend_config (meaning it's supported by the target
        # hardware), we'll insert observer modules according to the qconfig_mapping,
        # otherwise the configuration in qconfig_mapping will be ignored
        #
        # Example:
        # in qconfig_mapping, user sets the linear module to be quantized with quint8 for
        # activation and qint8 for weight:
        # qconfig = torch.ao.quantization.QConfig(
        #     activation=MinMaxObserver.with_args(dtype=torch.quint8),
        #     weight=MinMaxObserver.with_args(dtype=torch.qint8))
        # Note: the current qconfig api does not support setting an output observer, but
        # we may extend this to support more fine grained control in the future
        #
        # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
        # in backend_config, the linear module also supports this configuration:
        # weighted_int8_dtype_config = DTypeConfig(
        #     input_dtype=torch.quint8,
        #     output_dtype=torch.quint8,
        #     weight_dtype=torch.qint8,
        #     bias_dtype=torch.float)
        # linear_pattern_config = BackendPatternConfig(torch.nn.Linear) \
        #     .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
        #     .add_dtype_config(weighted_int8_dtype_config) \
        #     ...
        # backend_config = BackendConfig().set_backend_pattern_config(linear_pattern_config)
        # `prepare_fx` will check that the setting requested by the user in qconfig_mapping
        # is supported by the backend_config and insert observers and fake quant modules
        # in the model
        prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)
        # Run calibration
        calibrate(prepared_model, sample_inference_data)
    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
    return _prepare_fx(
        model,
        qconfig_mapping,
        False,  # is_qat
        example_inputs,
        prepare_custom_config,
        _equalization_config,
        backend_config,
    )
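
# Compact, self-contained sketch of the post training static quantization flow
# (illustrative only; random tensors stand in for a real calibration set):
#
#     import torch
#     from torch.ao.quantization import get_default_qconfig_mapping
#     from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
#
#     model = torch.nn.Sequential(torch.nn.Linear(5, 5), torch.nn.ReLU()).eval()
#     example_inputs = (torch.randn(1, 5),)
#     qconfig_mapping = get_default_qconfig_mapping("fbgemm")
#     prepared = prepare_fx(model, qconfig_mapping, example_inputs)
#     for _ in range(10):          # calibration: run representative data through the observers
#         prepared(torch.randn(1, 5))
#     quantized = convert_fx(prepared)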


def prepare_qat_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r""" Prepare a model for quantization aware training

    Args:
      * `model` (torch.nn.Module): torch.nn.Module model
      * `qconfig_mapping` (QConfigMapping): see :func:`~torch.ao.quantization.prepare_fx`
      * `example_inputs` (Tuple[Any, ...]): see :func:`~torch.ao.quantization.prepare_fx`
      * `prepare_custom_config` (PrepareCustomConfig): see :func:`~torch.ao.quantization.prepare_fx`
      * `backend_config` (BackendConfig): see :func:`~torch.ao.quantization.prepare_fx`

    Return:
      A GraphModule with fake quant modules (configured by qconfig_mapping and backend_config), ready for
      quantization aware training

    Example::

        import torch
        from torch.ao.quantization import get_default_qat_qconfig_mapping
        from torch.ao.quantization import prepare_qat_fx

        class Submodule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                x = self.linear(x)
                return x

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.sub = Submodule()

            def forward(self, x):
                x = self.linear(x)
                x = self.sub(x) + x
                return x

        # initialize a floating point model
        float_model = M().train()

        # (optional, but preferred) load the weights from a pretrained model
        # float_model.load_weights(...)

        # define the training loop for quantization aware training
        def train_loop(model, train_data):
            model.train()
            for image, target in train_data:
                ...

        # qconfig is the configuration for how we insert observers for a particular
        # operator
        # qconfig = get_default_qconfig("fbgemm")
        # Example of customizing qconfig:
        # qconfig = torch.ao.quantization.QConfig(
        #    activation=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)),
        #    weight=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)))
        # `activation` and `weight` are constructors of observer modules

        # qconfig_mapping is a collection of quantization configurations, user can
        # set the qconfig for each operator (torch op calls, functional calls, module calls)
        # in the model through qconfig_mapping
        # the following call will get the qconfig_mapping that works best for models
        # that target the "fbgemm" backend
        qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")

        # We can customize qconfig_mapping in different ways, please take a look at
        # the docstring for :func:`~torch.ao.quantization.prepare_fx` for different ways
        # to configure this

        # example_inputs is a tuple of inputs that is used to infer the type of the
        # outputs in the model
        # currently it's not used, but please make sure model(*example_inputs) runs
        example_inputs = (torch.randn(1, 3, 224, 224),)

        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        # `prepare_qat_fx` inserts observers in the model based on qconfig_mapping and
        # backend_config, if the configuration for an operator in qconfig_mapping
        # is supported in the backend_config (meaning it's supported by the target
        # hardware), we'll insert fake_quantize modules according to the qconfig_mapping,
        # otherwise the configuration in qconfig_mapping will be ignored
        # see :func:`~torch.ao.quantization.prepare_fx` for a detailed explanation of
        # how qconfig_mapping interacts with backend_config
        prepared_model = prepare_qat_fx(float_model, qconfig_mapping, example_inputs)
        # Run training
        train_loop(prepared_model, train_data)
    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx")
    return _prepare_fx(
        model,
        qconfig_mapping,
        True,  # is_qat
        example_inputs,
        prepare_custom_config,
        backend_config=backend_config,
    )
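
# Compact sketch of the quantization aware training flow (illustrative only; a real
# setup would use a proper dataset, loss function, and training schedule):
#
#     import torch
#     from torch.ao.quantization import get_default_qat_qconfig_mapping
#     from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
#
#     model = torch.nn.Sequential(torch.nn.Linear(5, 5), torch.nn.ReLU()).train()
#     example_inputs = (torch.randn(8, 5),)
#     qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")
#     prepared = prepare_qat_fx(model, qconfig_mapping, example_inputs)
#     optimizer = torch.optim.SGD(prepared.parameters(), lr=1e-3)
#     for _ in range(10):          # toy training loop with fake-quant modules inserted
#         loss = prepared(torch.randn(8, 5)).sum()
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#     quantized = convert_fx(prepared.eval())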


def _convert_fx(
    graph_module: GraphModule,
    is_reference: bool,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    is_standalone_module: bool = False,
    _remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
    is_decomposed: bool = False,
) -> torch.nn.Module:
    """ `is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`
    """
    if convert_custom_config is None:
        convert_custom_config = ConvertCustomConfig()

    if isinstance(convert_custom_config, Dict):
        warnings.warn(
            "Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
            "in a future version. Please pass in a ConvertCustomConfig instead.")
        convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)

    _check_is_graph_module(graph_module)
    preserved_attr_names = convert_custom_config.preserved_attributes
    preserved_attrs = {attr: getattr(graph_module, attr) for attr in preserved_attr_names if hasattr(graph_module, attr)}

    quantized = convert(
        graph_module,
        is_reference,
        convert_custom_config,
        is_standalone_module,
        _remove_qconfig_flag=_remove_qconfig,
        qconfig_mapping=qconfig_mapping,
        backend_config=backend_config,
        is_decomposed=is_decomposed,
    )

    attach_preserved_attrs_to_model(quantized, preserved_attrs)
    return quantized


def convert_fx(
    graph_module: GraphModule,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    _remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> torch.nn.Module:
    r""" Convert a calibrated or trained model to a quantized model

    Args:
        * `graph_module` (torch.fx.GraphModule): A prepared and calibrated/trained model (GraphModule)

        * `convert_custom_config` (ConvertCustomConfig): custom configurations for the convert function.
            See :class:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig` for more details

        * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.

        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
            The keys must include the ones in the qconfig_mapping passed to `prepare_fx` or `prepare_qat_fx`,
            with the same values or `None`. Additional keys can be specified with values set to `None`.
            For each entry whose value is set to None, we skip quantizing that entry in the model::

                qconfig_mapping = (
                    QConfigMapping()
                    .set_global(qconfig_from_prepare)
                    .set_object_type(torch.nn.functional.add, None)  # skip quantizing torch.nn.functional.add
                    .set_object_type(torch.nn.functional.linear, qconfig_from_prepare)
                    .set_module_name("foo.bar", None)  # skip quantizing module "foo.bar"
                )

        * `backend_config` (BackendConfig): A configuration for the backend which describes how
            operators should be quantized in the backend, this includes quantization
            mode support (static/dynamic/weight_only), dtype support (quint8/qint8 etc.),
            observer placement for each operator and fused operators.
            See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details

    Return:
        A quantized model (torch.nn.Module)

    Example::

        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
        # convert_fx converts a calibrated/trained model to a quantized model for the
        # target hardware, this includes converting the model first to a reference
        # quantized model, and then lowering the reference quantized model to a backend
        # Currently, the supported backends are fbgemm (onednn), qnnpack (xnnpack) and
        # they share the same set of quantized operators, so we are using the same
        # lowering procedure
        #
        # backend_config defines the corresponding reference quantized module for
        # the weighted modules in the model, e.g. nn.Linear
        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        quantized_model = convert_fx(prepared_model)
    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx")
    return _convert_fx(
        graph_module,
        is_reference=False,
        convert_custom_config=convert_custom_config,
        _remove_qconfig=_remove_qconfig,
        qconfig_mapping=qconfig_mapping,
        backend_config=backend_config,
    )
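
# Illustrative sketch: reuse the qconfig_mapping from prepare but skip quantizing one
# submodule at convert time (`prepare_qconfig_mapping`, `prepared_model`, and the module
# name "sub" are hypothetical):
#
#     convert_qconfig_mapping = copy.deepcopy(prepare_qconfig_mapping)
#     convert_qconfig_mapping.set_module_name("sub", None)   # keep `sub` in fp32
#     quantized = convert_fx(prepared_model, qconfig_mapping=convert_qconfig_mapping)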


def convert_to_reference_fx(
    graph_module: GraphModule,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    _remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> torch.nn.Module:
    r""" Convert a calibrated or trained model to a reference quantized model,
    see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details.
    A reference quantized model is a standard representation of a quantized model provided
    by FX Graph Mode Quantization; it can be further lowered to run on the target
    hardware, like accelerators

    Args:
        * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule)

        * `convert_custom_config` (ConvertCustomConfig): custom configurations for the convert function.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

        * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.

        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

        * `backend_config` (BackendConfig): A configuration for the backend which describes how
            operators should be quantized in the backend. See
            :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

    Return:
        A reference quantized model (GraphModule)

    Example::

        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        reference_quantized_model = convert_to_reference_fx(prepared_model)
    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_to_reference_fx")
    return _convert_fx(
        graph_module,
        is_reference=True,
        convert_custom_config=convert_custom_config,
        _remove_qconfig=_remove_qconfig,
        qconfig_mapping=qconfig_mapping,
        backend_config=backend_config,
    )


def _convert_to_reference_decomposed_fx(
    graph_module: GraphModule,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    _remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> torch.nn.Module:
    r""" Convert a calibrated or trained model to a reference quantized model, with
    a decomposed representation for quantized Tensors,
    see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details.
    A reference quantized model is a standard representation of a quantized model provided
    by FX Graph Mode Quantization; it can be further lowered to run on the target
    hardware, like accelerators

    Note: this is not a public API

    Args:
        * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule)

        * `convert_custom_config` (ConvertCustomConfig): custom configurations for the convert function.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

        * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.

        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

        * `backend_config` (BackendConfig): A configuration for the backend which describes how
            operators should be quantized in the backend. See
            :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

    Return:
        A reference quantized model (GraphModule) with operators working with decomposed quantized Tensors

    Example::

        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model)
    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx._convert_to_reference_decomposed_fx")
    return _convert_fx(
        graph_module,
        is_reference=True,
        convert_custom_config=convert_custom_config,
        _remove_qconfig=_remove_qconfig,
        qconfig_mapping=qconfig_mapping,
        backend_config=backend_config,
        is_decomposed=True,
    )


def _convert_standalone_module_fx(
    graph_module: GraphModule,
    is_reference: bool = False,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
) -> torch.nn.Module:
    r""" [Internal use only] Convert a model produced by :func:`~torch.ao.quantization.prepare_standalone_module_fx`
    to a quantized model

    Returns a quantized standalone module. Whether the input/output is quantized is
    specified by the prepare_custom_config, through
    input_quantized_idxs and output_quantized_idxs; please
    see the docs for :func:`~torch.ao.quantization.prepare_fx` for details
    """
    return _convert_fx(
        graph_module,
        is_reference,
        convert_custom_config,
        is_standalone_module=True,
    )