# backend_config.py

from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Type, Union

import torch
from torch.ao.quantization.utils import Pattern
from enum import Enum

__all__ = [
    "BackendConfig",
    "BackendPatternConfig",
    "DTypeConfig",
    "DTypeWithConstraints",
    "ObservationType",
]

# DTypeConfig dict keys
INPUT_DTYPE_DICT_KEY = "input_dtype"
OUTPUT_DTYPE_DICT_KEY = "output_dtype"
WEIGHT_DTYPE_DICT_KEY = "weight_dtype"
BIAS_DTYPE_DICT_KEY = "bias_dtype"
IS_DYNAMIC_DICT_KEY = "is_dynamic"

# BackendConfig dict keys
NAME_DICT_KEY = "name"
CONFIGS_DICT_KEY = "configs"

# BackendPatternConfig dict keys
PATTERN_DICT_KEY = "pattern"
PATTERN_COMPLEX_FORMAT_DICT_KEY = "pattern_complex_format"
OBSERVATION_TYPE_DICT_KEY = "observation_type"
DTYPE_CONFIGS_DICT_KEY = "dtype_configs"
ROOT_MODULE_DICT_KEY = "root_module"
QAT_MODULE_DICT_KEY = "qat_module"
REFERENCE_QUANTIZED_MODULE_DICT_KEY = "reference_quantized_module_for_root"
FUSED_MODULE_DICT_KEY = "fused_module"
FUSER_METHOD_DICT_KEY = "fuser_method"
ROOT_NODE_GETTER_DICT_KEY = "root_node_getter"
EXTRA_INPUTS_GETTER_DICT_KEY = "extra_inputs_getter"
NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY = "num_tensor_args_to_observation_type"
INPUT_TYPE_TO_INDEX_DICT_KEY = "input_type_to_index"

# TODO: maybe rename this to something that's not related to observer
# e.g. QParamsType
class ObservationType(Enum):
    """An enum that represents the different ways an operator/operator pattern
    can be observed.
    """

    OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT = 0
    """This means input and output are observed with different observers, based
    on qconfig.activation.
    Example: conv, linear, softmax
    """

    OUTPUT_SHARE_OBSERVER_WITH_INPUT = 1
    """This means the output will use the same observer instance as the input, based
    on qconfig.activation.
    Example: torch.cat, maxpool
    """

    INPUT_OUTPUT_NOT_OBSERVED = 2
    """This means the input and output are never observed.
    Example: x.shape, x.size
    """

@dataclass
class DTypeWithConstraints:
    """
    Config for specifying additional constraints for a given dtype, such as quantization
    value ranges, scale value ranges, and fixed quantization params, to be used in
    :class:`~torch.ao.quantization.backend_config.DTypeConfig`.

    The constraints currently supported are:

    * `quant_min_lower_bound` and `quant_max_upper_bound`: Lower and upper
      bounds for the minimum and maximum quantized values respectively. If
      the QConfig's `quant_min` and `quant_max` fall outside this range,
      then the QConfig will be ignored.

    * `scale_min_lower_bound` and `scale_max_upper_bound`: Lower and upper
      bounds for the minimum and maximum scale values respectively. If the
      QConfig's minimum scale value (currently exposed as `eps`) falls below
      the lower bound, then the QConfig will be ignored. Note that the upper
      bound is currently not enforced.

    * `scale_exact_match` and `zero_point_exact_match`: Exact match requirements
      for scale and zero point, to be used for operators with fixed quantization
      parameters such as sigmoid and tanh. If the observer specified in the QConfig
      is neither `FixedQParamsObserver` nor `FixedQParamsFakeQuantize`, or if
      the quantization parameters don't match, then the QConfig will be ignored.
    """
    dtype: Optional[torch.dtype] = None
    quant_min_lower_bound: Union[int, float, None] = None
    quant_max_upper_bound: Union[int, float, None] = None
    scale_min_lower_bound: Union[int, float, None] = None
    scale_max_upper_bound: Union[int, float, None] = None
    scale_exact_match: Optional[float] = None
    zero_point_exact_match: Optional[int] = None

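# Illustrative sketch (assumption, not part of the original file): an op with
# fixed quantization parameters such as sigmoid could constrain its output dtype
# as follows; the exact scale/zero_point values here are backend-specific:
#
#   fixed_qparams_uint8 = DTypeWithConstraints(
#       dtype=torch.quint8,
#       quant_min_lower_bound=0,
#       quant_max_upper_bound=255,
#       scale_exact_match=1.0 / 256.0,
#       zero_point_exact_match=0,
#   )
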
@dataclass
class DTypeConfig:
    """
    Config object that specifies the supported data types passed as arguments to
    quantize ops in the reference model spec, for input and output activations,
    weights, and biases.

    For example, consider the following reference model:

      quant1 - [dequant1 - fp32_linear - quant2] - dequant2

    The pattern in the square brackets refers to the reference pattern of
    statically quantized linear. Setting the input dtype as `torch.quint8`
    in the DTypeConfig means we pass in `torch.quint8` as the dtype argument
    to the first quantize op (quant1). Similarly, setting the output dtype as
    `torch.quint8` means we pass in `torch.quint8` as the dtype argument to
    the second quantize op (quant2).

    Note that the dtype here does not refer to the interface dtypes of the
    op. For example, the "input dtype" here is not the dtype of the input
    tensor passed to the quantized linear op. Though it can still be the
    same as the interface dtype, this is not always the case, e.g. the
    interface dtype is fp32 in dynamic quantization but the "input dtype"
    specified in the DTypeConfig would still be quint8. The semantics of
    dtypes here are the same as the semantics of the dtypes specified in
    the observers.

    These dtypes are matched against the ones specified in the user's
    QConfig. If there is a match, and the QConfig satisfies the constraints
    specified in the DTypeConfig (if any), then we will quantize the given
    pattern using this DTypeConfig. Otherwise, the QConfig is ignored and
    the pattern will not be quantized.

    Example usage::

        >>> # xdoctest: +SKIP(failing)
        >>> dtype_config1 = DTypeConfig(
        ...     input_dtype=torch.quint8,
        ...     output_dtype=torch.quint8,
        ...     weight_dtype=torch.qint8,
        ...     bias_dtype=torch.float)

        >>> dtype_config2 = DTypeConfig(
        ...     input_dtype=DTypeWithConstraints(
        ...         dtype=torch.quint8,
        ...         quant_min_lower_bound=0,
        ...         quant_max_upper_bound=255,
        ...     ),
        ...     output_dtype=DTypeWithConstraints(
        ...         dtype=torch.quint8,
        ...         quant_min_lower_bound=0,
        ...         quant_max_upper_bound=255,
        ...     ),
        ...     weight_dtype=DTypeWithConstraints(
        ...         dtype=torch.qint8,
        ...         quant_min_lower_bound=-128,
        ...         quant_max_upper_bound=127,
        ...     ),
        ...     bias_dtype=torch.float)

        >>> dtype_config1.input_dtype
        torch.quint8

        >>> dtype_config2.input_dtype
        torch.quint8

        >>> dtype_config2.input_dtype_with_constraints
        DTypeWithConstraints(dtype=torch.quint8, quant_min_lower_bound=0, quant_max_upper_bound=255, \
scale_min_lower_bound=None, scale_max_upper_bound=None)
    """
    input_dtype_with_constraints: DTypeWithConstraints
    output_dtype_with_constraints: DTypeWithConstraints
    weight_dtype_with_constraints: DTypeWithConstraints
    bias_dtype: Optional[torch.dtype]
    is_dynamic: Optional[bool]

    def __init__(
        self,
        input_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
        output_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
        weight_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
        bias_dtype: Optional[torch.dtype] = None,
        is_dynamic: Optional[bool] = None,
    ):
        if isinstance(input_dtype, DTypeWithConstraints):
            self.input_dtype_with_constraints = input_dtype
        else:
            self.input_dtype_with_constraints = DTypeWithConstraints(dtype=input_dtype)

        if isinstance(output_dtype, DTypeWithConstraints):
            self.output_dtype_with_constraints = output_dtype
        else:
            self.output_dtype_with_constraints = DTypeWithConstraints(dtype=output_dtype)

        if isinstance(weight_dtype, DTypeWithConstraints):
            self.weight_dtype_with_constraints = weight_dtype
        else:
            self.weight_dtype_with_constraints = DTypeWithConstraints(dtype=weight_dtype)

        self.bias_dtype = bias_dtype
        self.is_dynamic = is_dynamic

    @property
    def input_dtype(self) -> Optional[torch.dtype]:
        return self.input_dtype_with_constraints.dtype

    @property
    def output_dtype(self) -> Optional[torch.dtype]:
        return self.output_dtype_with_constraints.dtype

    @property
    def weight_dtype(self) -> Optional[torch.dtype]:
        return self.weight_dtype_with_constraints.dtype

    @classmethod
    def from_dict(cls, dtype_config_dict: Dict[str, Any]) -> DTypeConfig:
        """
        Create a ``DTypeConfig`` from a dictionary with the following items (all optional):

            "input_dtype": torch.dtype or ``DTypeWithConstraints``
            "output_dtype": torch.dtype or ``DTypeWithConstraints``
            "weight_dtype": torch.dtype or ``DTypeWithConstraints``
            "bias_dtype": torch.dtype
            "is_dynamic": bool
        """
        input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY, None)
        if input_dtype is not None and not isinstance(input_dtype, (torch.dtype, DTypeWithConstraints)):
            raise ValueError("Expected input_dtype to be a torch.dtype or DTypeWithConstraints")
        output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY, None)
        if output_dtype is not None and not isinstance(output_dtype, (torch.dtype, DTypeWithConstraints)):
            raise ValueError("Expected output_dtype to be a torch.dtype or DTypeWithConstraints")
        weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY, None)
        if weight_dtype is not None and not isinstance(weight_dtype, (torch.dtype, DTypeWithConstraints)):
            raise ValueError("Expected weight_dtype to be a torch.dtype or DTypeWithConstraints")
        bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY, None)
        is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY, None)
        return cls(input_dtype, output_dtype, weight_dtype, bias_dtype, is_dynamic)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``DTypeConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.backend_config.DTypeConfig.from_dict`.
        """
        dtype_config_dict: Dict[str, Any] = {}
        if self.input_dtype is not None:
            dtype_config_dict[INPUT_DTYPE_DICT_KEY] = self.input_dtype_with_constraints
        if self.output_dtype is not None:
            dtype_config_dict[OUTPUT_DTYPE_DICT_KEY] = self.output_dtype_with_constraints
        if self.weight_dtype is not None:
            dtype_config_dict[WEIGHT_DTYPE_DICT_KEY] = self.weight_dtype_with_constraints
        if self.bias_dtype is not None:
            dtype_config_dict[BIAS_DTYPE_DICT_KEY] = self.bias_dtype
        if self.is_dynamic is not None:
            dtype_config_dict[IS_DYNAMIC_DICT_KEY] = self.is_dynamic
        return dtype_config_dict

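# Illustrative sketch (assumption, not part of the original file): DTypeConfig
# can be round-tripped through its dict form, e.g.
#
#   cfg = DTypeConfig.from_dict({
#       "input_dtype": torch.quint8,
#       "output_dtype": torch.quint8,
#       "weight_dtype": torch.qint8,
#       "bias_dtype": torch.float,
#   })
#   assert cfg.input_dtype == torch.quint8
#   cfg_dict = cfg.to_dict()  # dtype entries come back as DTypeWithConstraints
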
class BackendConfig:
    # TODO: refer to NativeBackendConfig once that is implemented
    """Config that defines the set of patterns that can be quantized on a given backend, and how reference
    quantized models can be produced from these patterns.

    A pattern in this context refers to a module, a functional, an operator, or a directed acyclic graph
    of the above. Each pattern supported on the target backend can be individually configured through
    :class:`~torch.ao.quantization.backend_config.BackendPatternConfig` in terms of:

    (1) The supported input/output activation, weight, and bias data types

    (2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and

    (3) (Optionally) Fusion, QAT, and reference module mappings.

    The format of the patterns is described in:
    https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md

    Example usage::

        import torch
        from torch.ao.quantization.backend_config import (
            BackendConfig,
            BackendPatternConfig,
            DTypeConfig,
            ObservationType,
        )

        weighted_int8_dtype_config = DTypeConfig(
            input_dtype=torch.quint8,
            output_dtype=torch.quint8,
            weight_dtype=torch.qint8,
            bias_dtype=torch.float)

        def fuse_conv2d_relu(is_qat, conv, relu):
            return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu)

        # For quantizing Linear
        linear_config = BackendPatternConfig(torch.nn.Linear) \
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
            .add_dtype_config(weighted_int8_dtype_config) \
            .set_root_module(torch.nn.Linear) \
            .set_qat_module(torch.ao.nn.qat.Linear) \
            .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)

        # For fusing Conv2d + ReLU into ConvReLU2d
        conv_relu_config = BackendPatternConfig((torch.nn.Conv2d, torch.nn.ReLU)) \
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
            .add_dtype_config(weighted_int8_dtype_config) \
            .set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \
            .set_fuser_method(fuse_conv2d_relu)

        # For quantizing ConvReLU2d
        fused_conv_relu_config = BackendPatternConfig(torch.ao.nn.intrinsic.ConvReLU2d) \
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
            .add_dtype_config(weighted_int8_dtype_config) \
            .set_root_module(torch.nn.Conv2d) \
            .set_qat_module(torch.ao.nn.intrinsic.qat.ConvReLU2d) \
            .set_reference_quantized_module(torch.ao.nn.quantized.reference.Conv2d)

        backend_config = BackendConfig("my_backend") \
            .set_backend_pattern_config(linear_config) \
            .set_backend_pattern_config(conv_relu_config) \
            .set_backend_pattern_config(fused_conv_relu_config)

    """

    def __init__(self, name: str = ""):
        self.name = name
        # Store all BackendPatternConfigs in a map to handle duplicates
        # Note: the key in this map uses the complex reversed tuple format.
        # This is intended only for internal use; users who wish to access
        # the original patterns should go through `self.configs` instead.
        self._pattern_complex_format_to_config: Dict[Pattern, BackendPatternConfig] = {}

    def __repr__(self):
        return f"BackendConfig({self.__dict__})"

    def set_name(self, name: str) -> BackendConfig:
        """
        Set the name of the target backend.
        """
        self.name = name
        return self

    def set_backend_pattern_config(self, config: BackendPatternConfig) -> BackendConfig:
        """
        Set the config for a pattern that can be run on the target backend.
        This overrides any existing config for the given pattern.
        """
        # Avoid circular dependencies
        pattern_complex_format = torch.ao.quantization.backend_config.utils \
            ._get_pattern_in_reversed_nested_tuple_format(config)  # type: ignore[attr-defined]
        self._pattern_complex_format_to_config[pattern_complex_format] = config
        return self

    def set_backend_pattern_configs(self, configs: List[BackendPatternConfig]) -> BackendConfig:
        """
        Set the configs for patterns that can be run on the target backend.
        This overrides any existing config for a pattern that was previously registered.
        """
        for conf in configs:
            self.set_backend_pattern_config(conf)
        return self

    @property
    def configs(self) -> List[BackendPatternConfig]:
        """
        Return a copy of the list of configs set in this `BackendConfig`.
        """
        return list(self._pattern_complex_format_to_config.values())

    @classmethod
    def from_dict(cls, backend_config_dict: Dict[str, Any]) -> BackendConfig:
        """
        Create a ``BackendConfig`` from a dictionary with the following items:

            "name": the name of the target backend
            "configs": a list of dictionaries that each represents a `BackendPatternConfig`
        """
        conf = cls(backend_config_dict.get(NAME_DICT_KEY, ""))
        for d in backend_config_dict.get(CONFIGS_DICT_KEY, []):
            if isinstance(d, BackendPatternConfig):
                conf.set_backend_pattern_config(d)
            elif isinstance(d, Dict):
                conf.set_backend_pattern_config(BackendPatternConfig.from_dict(d))
            else:
                raise ValueError("Expected backend_config_dict['%s'] to be a dictionary" % CONFIGS_DICT_KEY)
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``BackendConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.backend_config.BackendConfig.from_dict`.
        """
        return {
            NAME_DICT_KEY: self.name,
            CONFIGS_DICT_KEY: [c.to_dict() for c in self.configs],
        }

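# Illustrative sketch (assumption, not part of the original file): the dict form
# of a BackendConfig mirrors from_dict/to_dict above, e.g.
#
#   backend_config = BackendConfig.from_dict({
#       "name": "my_backend",
#       "configs": [
#           {"pattern": torch.nn.Linear,
#            "dtype_configs": [{"input_dtype": torch.quint8}]},
#       ],
#   })
#   assert backend_config.name == "my_backend"
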
class BackendPatternConfig:
    """
    Config object that specifies quantization behavior for a given operator pattern.
    For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`.
    """

    def __init__(self, pattern: Optional[Pattern] = None):
        self.pattern: Optional[Pattern] = pattern
        self.observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
        self.dtype_configs: List[DTypeConfig] = []
        self.root_module: Optional[Type[torch.nn.Module]] = None
        self.qat_module: Optional[Type[torch.nn.Module]] = None
        self.reference_quantized_module: Optional[Type[torch.nn.Module]] = None
        self.fused_module: Optional[Type[torch.nn.Module]] = None
        self.fuser_method: Optional[Callable] = None

        # Temporary/internal configs
        self._root_node_getter: Optional[Callable] = None
        self._extra_inputs_getter: Optional[Callable] = None
        self._num_tensor_args_to_observation_type: Dict[int, ObservationType] = {}
        self._input_type_to_index: Dict[str, int] = {}
        self._pattern_complex_format: Optional[Pattern] = None

    def __repr__(self):
        dict_nonempty = {
            k: v for k, v in self.__dict__.items()
            if (
                (not isinstance(v, (list, dict)) and v is not None)
                or (isinstance(v, (list, dict)) and len(v) > 0)
            )
        }
        return f"BackendPatternConfig({dict_nonempty})"

    def set_pattern(self, pattern: Pattern) -> BackendPatternConfig:
        """
        Set the pattern to configure.

        The pattern can be a float module, functional operator, pytorch operator, or a tuple
        combination of the above. Tuple patterns are treated as sequential patterns, and
        currently only tuples of 2 or 3 elements are supported.
        """
        if self._pattern_complex_format is not None:
            raise ValueError("Only one of 'pattern' or 'pattern_complex_format' can be set")
        self.pattern = pattern
        return self

    def set_observation_type(self, observation_type: ObservationType) -> BackendPatternConfig:
        """
        Set how observers should be inserted in the graph for this pattern.

        Observation type here refers to how observers (or quant-dequant ops) will be placed
        in the graph. This is used to produce the desired reference patterns understood by
        the backend. Weighted ops such as linear and conv require different observers
        (or quantization parameters passed to quantize ops in the reference model) for the
        input and the output.

        There are two observation types:

            `OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` (default): the output observer instance
            will be different from the input. This is the most common observation type.

            `OUTPUT_SHARE_OBSERVER_WITH_INPUT`: the output observer instance will be the
            same as the input. This is useful for operators like `cat`.

        Note: This will be renamed in the near future, since we will soon insert QuantDeQuantStubs
        with observers (and fake quantizes) attached instead of observers themselves.
        """
        self.observation_type = observation_type
        return self

    def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig:
        """
        Add a set of supported data types passed as arguments to quantize ops in the
        reference model spec.
        """
        self.dtype_configs.append(dtype_config)
        return self

    def set_dtype_configs(self, dtype_configs: List[DTypeConfig]) -> BackendPatternConfig:
        """
        Set the supported data types passed as arguments to quantize ops in the
        reference model spec, overriding all previously registered data types.
        """
        self.dtype_configs = dtype_configs
        return self

    def set_root_module(self, root_module: Type[torch.nn.Module]) -> BackendPatternConfig:
        """
        Set the module that represents the root for this pattern.

        When we construct the reference quantized model during the convert phase,
        the root modules (e.g. torch.nn.Linear for torch.ao.nn.intrinsic.LinearReLU)
        will be swapped to the corresponding reference quantized modules (e.g.
        torch.ao.nn.quantized.reference.Linear). This allows custom backends to
        specify custom reference quantized module implementations to match the
        numerics of their lowered operators. Since this is a one-to-one mapping,
        both the root module and the reference quantized module must be specified
        in the same BackendPatternConfig in order for the conversion to take place.
        """
        self.root_module = root_module
        return self

    def set_qat_module(self, qat_module: Type[torch.nn.Module]) -> BackendPatternConfig:
        """
        Set the module that represents the QAT implementation for this pattern.
        """
        self.qat_module = qat_module
        return self

    def set_reference_quantized_module(self, reference_quantized_module: Type[torch.nn.Module]) -> BackendPatternConfig:
        """
        Set the module that represents the reference quantized implementation for
        this pattern's root module.

        For more detail, see :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.set_root_module`.
        """
        self.reference_quantized_module = reference_quantized_module
        return self

    def set_fused_module(self, fused_module: Type[torch.nn.Module]) -> BackendPatternConfig:
        """
        Set the module that represents the fused implementation for this pattern.
        """
        self.fused_module = fused_module
        return self

    def set_fuser_method(self, fuser_method: Callable) -> BackendPatternConfig:
        """
        Set the function that specifies how to fuse this BackendPatternConfig's pattern.

        The first argument of this function should be `is_qat`, and the rest of the arguments
        should be the items in the tuple pattern. The return value of this function should be
        the resulting fused module.

        For example, the fuser method for the pattern `(torch.nn.Linear, torch.nn.ReLU)` can be::

            def fuse_linear_relu(is_qat, linear, relu):
                return torch.ao.nn.intrinsic.LinearReLU(linear, relu)

        For a more complicated example, see https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6.
        """
        self.fuser_method = fuser_method
        return self

    def _set_root_node_getter(self, root_node_getter: Callable) -> BackendPatternConfig:
        self._root_node_getter = root_node_getter
        return self

    def _set_extra_inputs_getter(self, extra_inputs_getter: Callable) -> BackendPatternConfig:
        self._extra_inputs_getter = extra_inputs_getter
        return self

    def _set_num_tensor_args_to_observation_type(
            self, num_tensor_args_to_observation_type: Dict[int, ObservationType]) -> BackendPatternConfig:
        self._num_tensor_args_to_observation_type = num_tensor_args_to_observation_type
        return self

    def _set_input_type_to_index(self, input_type_to_index: Dict[str, int]) -> BackendPatternConfig:
        self._input_type_to_index = input_type_to_index
        return self

    def _set_pattern_complex_format(self, pattern: Pattern) -> BackendPatternConfig:
        """
        Set the pattern to configure, using the reversed nested tuple format.

        See the BackendConfig README for more detail:
        https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md#advanced-pattern-specification
        """
        if self.pattern is not None:
            raise ValueError("Only one of 'pattern' or 'pattern_complex_format' can be set")
        self._pattern_complex_format = pattern
        return self

    @classmethod
    def from_dict(cls, backend_pattern_config_dict: Dict[str, Any]) -> BackendPatternConfig:
        """
        Create a ``BackendPatternConfig`` from a dictionary with the following items:

            "pattern": the pattern being configured
            "observation_type": the :class:`~torch.ao.quantization.backend_config.ObservationType` that specifies how
            observers should be inserted for this pattern
            "dtype_configs": a list of dictionaries that each represents a :class:`~torch.ao.quantization.backend_config.DTypeConfig`
            "root_module": a :class:`torch.nn.Module` that represents the root for this pattern
            "qat_module": a :class:`torch.nn.Module` that represents the QAT implementation for this pattern
            "reference_quantized_module": a :class:`torch.nn.Module` that represents the reference quantized
            implementation for this pattern's root module
            "fused_module": a :class:`torch.nn.Module` that represents the fused implementation for this pattern
            "fuser_method": a function that specifies how to fuse this pattern
            "pattern_complex_format": the pattern specified in the reversed nested tuple format (deprecated)
        """
        def _get_dtype_config(obj: Any) -> DTypeConfig:
            """
            Convert the given object into a ``DTypeConfig`` if possible, else throw an exception.
            """
            if isinstance(obj, DTypeConfig):
                return obj
            if isinstance(obj, Dict):
                return DTypeConfig.from_dict(obj)
            raise ValueError("Expected a list of DTypeConfigs in backend_pattern_config_dict[\"%s\"], got '%s'" %
                             (DTYPE_CONFIGS_DICT_KEY, type(obj)))

        conf = cls()
        if PATTERN_DICT_KEY in backend_pattern_config_dict:
            conf.set_pattern(backend_pattern_config_dict[PATTERN_DICT_KEY])
        if OBSERVATION_TYPE_DICT_KEY in backend_pattern_config_dict:
            conf.set_observation_type(backend_pattern_config_dict[OBSERVATION_TYPE_DICT_KEY])
        for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []):
            conf.add_dtype_config(_get_dtype_config(d))
        conf.set_root_module(backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None))
        conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None))
        conf.set_reference_quantized_module(backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None))
        conf.set_fused_module(backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None))
        conf.set_fuser_method(backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None))
        conf._set_root_node_getter(backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None))
        conf._set_extra_inputs_getter(backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None))
        conf._set_num_tensor_args_to_observation_type(
            backend_pattern_config_dict.get(NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, {}))
        conf._set_input_type_to_index(backend_pattern_config_dict.get(INPUT_TYPE_TO_INDEX_DICT_KEY, {}))
        if PATTERN_COMPLEX_FORMAT_DICT_KEY in backend_pattern_config_dict:
            conf._set_pattern_complex_format(backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY])
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``BackendPatternConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.from_dict`.
        """
        backend_pattern_config_dict: Dict[str, Any] = {
            OBSERVATION_TYPE_DICT_KEY: self.observation_type,
            DTYPE_CONFIGS_DICT_KEY: [c.to_dict() for c in self.dtype_configs],
        }
        if self.pattern is not None:
            backend_pattern_config_dict[PATTERN_DICT_KEY] = self.pattern
        if self.root_module is not None:
            backend_pattern_config_dict[ROOT_MODULE_DICT_KEY] = self.root_module
        if self.qat_module is not None:
            backend_pattern_config_dict[QAT_MODULE_DICT_KEY] = self.qat_module
        if self.reference_quantized_module is not None:
            backend_pattern_config_dict[REFERENCE_QUANTIZED_MODULE_DICT_KEY] = self.reference_quantized_module
        if self.fused_module is not None:
            backend_pattern_config_dict[FUSED_MODULE_DICT_KEY] = self.fused_module
        if self.fuser_method is not None:
            backend_pattern_config_dict[FUSER_METHOD_DICT_KEY] = self.fuser_method
        if self._root_node_getter is not None:
            backend_pattern_config_dict[ROOT_NODE_GETTER_DICT_KEY] = self._root_node_getter
        if self._extra_inputs_getter is not None:
            backend_pattern_config_dict[EXTRA_INPUTS_GETTER_DICT_KEY] = self._extra_inputs_getter
        if len(self._num_tensor_args_to_observation_type) > 0:
            backend_pattern_config_dict[NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY] = \
                self._num_tensor_args_to_observation_type
        if len(self._input_type_to_index) > 0:
            backend_pattern_config_dict[INPUT_TYPE_TO_INDEX_DICT_KEY] = self._input_type_to_index
        if self._pattern_complex_format is not None:
            backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY] = self._pattern_complex_format
        return backend_pattern_config_dict
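

# Illustrative usage sketch (assumption, not part of the original file): a quick
# smoke test of the round trip between a BackendPatternConfig and its dict form,
# guarded so it only runs when this module is executed directly.
if __name__ == "__main__":
    _linear_int8 = DTypeConfig(
        input_dtype=torch.quint8,
        output_dtype=torch.quint8,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float,
    )
    _linear_pattern_config = (
        BackendPatternConfig(torch.nn.Linear)
        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        .add_dtype_config(_linear_int8)
        .set_root_module(torch.nn.Linear)
    )
    # to_dict() then from_dict() should preserve the pattern and root module
    _round_tripped = BackendPatternConfig.from_dict(_linear_pattern_config.to_dict())
    assert _round_tripped.pattern is torch.nn.Linear
    assert _round_tripped.root_module is torch.nn.Linear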