qconfig.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. from collections import namedtuple
  2. from typing import Optional, Any, Union
  3. import torch
  4. import torch.nn as nn
  5. from torch.ao.quantization.fake_quantize import (
  6. FakeQuantize,
  7. FakeQuantizeBase,
  8. default_fake_quant,
  9. default_dynamic_fake_quant,
  10. default_per_channel_weight_fake_quant,
  11. default_weight_fake_quant,
  12. default_fused_act_fake_quant,
  13. default_fused_wt_fake_quant,
  14. FusedMovingAvgObsFakeQuantize,
  15. default_fused_per_channel_wt_fake_quant,
  16. default_embedding_fake_quant,
  17. default_embedding_fake_quant_4bit,
  18. fused_wt_fake_quant_range_neg_127_to_127,
  19. fused_per_channel_wt_fake_quant_range_neg_127_to_127,
  20. )
  21. from .observer import (
  22. _PartialWrapper,
  23. HistogramObserver,
  24. MovingAverageMinMaxObserver,
  25. NoopObserver,
  26. PlaceholderObserver,
  27. ReuseInputObserver,
  28. default_debug_observer,
  29. default_dynamic_quant_observer,
  30. default_float_qparams_observer,
  31. default_float_qparams_observer_4bit,
  32. default_observer,
  33. default_per_channel_weight_observer,
  34. default_placeholder_observer,
  35. default_weight_observer,
  36. weight_observer_range_neg_127_to_127,
  37. per_channel_weight_observer_range_neg_127_to_127,
  38. default_reuse_input_observer,
  39. ObserverBase,
  40. )
  41. import warnings
  42. import copy
  43. __all__ = [
  44. "QConfig",
  45. # TODO: deprecated, remove
  46. "QConfigDynamic",
  47. "default_qconfig",
  48. "default_debug_qconfig",
  49. "default_per_channel_qconfig",
  50. "default_dynamic_qconfig",
  51. "float16_dynamic_qconfig",
  52. "float16_static_qconfig",
  53. "per_channel_dynamic_qconfig",
  54. "float_qparams_weight_only_qconfig",
  55. "float_qparams_weight_only_qconfig_4bit",
  56. "default_qat_qconfig",
  57. "default_dynamic_qat_qconfig",
  58. "default_weight_only_qconfig",
  59. "default_activation_only_qconfig",
  60. "default_qat_qconfig_v2",
  61. "default_reuse_input_qconfig",
  62. "default_symmetric_qnnpack_qconfig",
  63. "default_per_channel_symmetric_qnnpack_qconfig",
  64. "default_symmetric_qnnpack_qat_qconfig",
  65. "default_per_channel_symmetric_qnnpack_qat_qconfig",
  66. "default_embedding_qat_qconfig",
  67. "default_embedding_qat_qconfig_4bit",
  68. "get_default_qconfig",
  69. "get_default_qat_qconfig",
  70. "get_default_qconfig_dict",
  71. "get_default_qat_qconfig_dict",
  72. "QConfigAny",
  73. "qconfig_equals",
  74. ]
  75. class QConfig(namedtuple('QConfig', ['activation', 'weight'])):
  76. """
  77. Describes how to quantize a layer or a part of the network by providing
  78. settings (observer classes) for activations and weights respectively.
  79. Note that QConfig needs to contain observer **classes** (like MinMaxObserver) or a callable that returns
  80. instances on invocation, not the concrete observer instances themselves.
  81. Quantization preparation function will instantiate observers multiple times for each of the layers.
  82. Observer classes have usually reasonable default arguments, but they can be overwritten with `with_args`
  83. method (that behaves like functools.partial)::
  84. my_qconfig = QConfig(
  85. activation=MinMaxObserver.with_args(dtype=torch.qint8),
  86. weight=default_observer.with_args(dtype=torch.qint8))
  87. """
  88. def __new__(cls, activation, weight):
  89. # catch common mistakes
  90. if isinstance(activation, nn.Module) or isinstance(weight, nn.Module):
  91. raise ValueError("QConfig received observer instance, please pass observer class instead. " +
  92. "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
  93. return super(QConfig, cls).__new__(cls, activation, weight)
  94. class QConfigDynamic(namedtuple('QConfigDynamic', ['activation', 'weight'])):
  95. """
  96. Describes how to dynamically quantize a layer or a part of the network by providing
  97. settings (observer classes) for weights.
  98. It's like QConfig, but for dynamic quantization.
  99. Note that QConfigDynamic needs to contain observer **classes** (like MinMaxObserver) or a callable that returns
  100. instances on invocation, not the concrete observer instances themselves.
  101. Quantization function will instantiate observers multiple times for each of the layers.
  102. Observer classes have usually reasonable default arguments, but they can be overwritten with `with_args`
  103. method (that behaves like functools.partial)::
  104. my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))
  105. """
  106. def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
  107. # catch common mistakes
  108. if isinstance(weight, nn.Module):
  109. raise ValueError("QConfigDynamic received observer instance, please pass observer class instead. " +
  110. "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
  111. warnings.warn("QConfigDynamic is going to be deprecated in PyTorch 1.12, please use QConfig instead")
  112. return super(QConfigDynamic, cls).__new__(cls, activation, weight)
  113. default_qconfig = QConfig(activation=default_observer,
  114. weight=default_weight_observer)
  115. """
  116. Default qconfig configuration.
  117. """
  118. default_debug_qconfig = QConfig(weight=default_weight_observer,
  119. activation=default_debug_observer)
  120. """
  121. Default qconfig configuration for debugging.
  122. """
  123. default_per_channel_qconfig = QConfig(activation=default_observer,
  124. weight=default_per_channel_weight_observer)
  125. """
  126. Default qconfig configuration for per channel weight quantization.
  127. """
  128. default_dynamic_qconfig = QConfig(activation=default_dynamic_quant_observer,
  129. weight=default_weight_observer)
  130. """
  131. Default dynamic qconfig.
  132. """
  133. float16_dynamic_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float16, is_dynamic=True),
  134. weight=PlaceholderObserver.with_args(dtype=torch.float16))
  135. """
  136. Dynamic qconfig with weights quantized to `torch.float16`.
  137. """
  138. float16_static_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float16),
  139. weight=PlaceholderObserver.with_args(dtype=torch.float16))
  140. """
  141. Dynamic qconfig with both activations and weights quantized to `torch.float16`.
  142. """
  143. per_channel_dynamic_qconfig = QConfig(activation=default_dynamic_quant_observer,
  144. weight=default_per_channel_weight_observer)
  145. """
  146. Dynamic qconfig with weights quantized per channel.
  147. """
  148. float_qparams_weight_only_qconfig = QConfig(
  149. activation=default_placeholder_observer,
  150. weight=default_float_qparams_observer)
  151. """
  152. Dynamic qconfig with weights quantized with a floating point zero_point.
  153. """
  154. float_qparams_weight_only_qconfig_4bit = QConfig(
  155. activation=default_placeholder_observer,
  156. weight=default_float_qparams_observer_4bit)
  157. default_qat_qconfig = QConfig(activation=default_fake_quant,
  158. weight=default_weight_fake_quant)
  159. """
  160. Default qconfig for QAT.
  161. """
  162. default_dynamic_qat_qconfig = QConfig(activation=default_dynamic_fake_quant,
  163. weight=default_weight_fake_quant)
  164. """
  165. Default qconfig for dynamic QAT.
  166. """
  167. default_weight_only_qconfig = QConfig(activation=torch.nn.Identity,
  168. weight=default_weight_fake_quant)
  169. """
  170. Default qconfig for quantizing weights only.
  171. """
  172. default_activation_only_qconfig = QConfig(activation=default_fake_quant,
  173. weight=torch.nn.Identity)
  174. """
  175. Default qconfig for quantizing activations only.
  176. """
  177. # QAT config that uses a fused observer + fake quant modules for optimized training performance.
  178. # to modify the activation/weight observers, the default entries in fake_quantize.py can be modified.
  179. default_qat_qconfig_v2 = QConfig(activation=default_fused_act_fake_quant, weight=default_fused_wt_fake_quant)
  180. """
  181. Fused version of `default_qat_config`, has performance benefits.
  182. """
  183. default_reuse_input_qconfig = QConfig(activation=default_reuse_input_observer,
  184. weight=NoopObserver)
  185. """
  186. Default qconfig for operators that reuse the observers from input Tensor, e.g. reshape
  187. """
  188. def get_default_qconfig(backend='x86', version=0):
  189. """
  190. Returns the default PTQ qconfig for the specified backend.
  191. Args:
  192. * `backend` (str): a string representing the target backend. Currently supports
  193. `x86` (default), `fbgemm`, `qnnpack` and `onednn`.
  194. Return:
  195. qconfig
  196. """
  197. supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"]
  198. if backend not in supported_backends:
  199. raise AssertionError(
  200. "backend: " + str(backend) +
  201. " not supported. backend must be one of {}".format(supported_backends)
  202. )
  203. if version == 0:
  204. if backend == 'fbgemm':
  205. qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
  206. weight=default_per_channel_weight_observer)
  207. elif backend == 'qnnpack':
  208. # TODO: make this compatible with xnnpack constraints
  209. qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
  210. weight=default_weight_observer)
  211. elif backend == 'onednn':
  212. qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
  213. weight=default_per_channel_weight_observer)
  214. elif backend == 'x86':
  215. qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
  216. weight=default_per_channel_weight_observer)
  217. else:
  218. # won't reach
  219. qconfig = default_qconfig
  220. else:
  221. raise AssertionError("Version number: " + str(version) +
  222. " in get_default_qconfig is not supported. Version number must be 0")
  223. return qconfig
  224. """
  225. Default, symmetric PTQ qconfig for the specified backend. And a per_channel
  226. variant of the same.
  227. Symmetric here applies to signed weights with zero point = 0, and additional
  228. value restrictions. The activations are also signed 8-bit integers with this
  229. qconfig.
  230. * Once this change is merged [as of 3/17/22], with backend or qengine =
  231. 'qnnpack', some quantized operators with this symmetric qconfig may use
  232. operators from xnnpack library.
  233. ** Support to use xnnpack ops with `qnnpack` backed for asymmetric
  234. qconfig (returned by get_default_qconfig()) is not available yet.
  235. * This qconfig uses signed activations and weights. Weights have added
  236. restrictions such as zero point is forced to be 0, making the weights
  237. symmetric, hence the name. And the 8-bit quantized values are
  238. restricting to to [-127, +127], excluding -128.
  239. * xnnpack has a requantization scale value restriction, 0x1p-32 <=
  240. requantization_scale < 256.0 where, `requantization_scale = (input_scale
  241. * kernel_scale) / (output_scale)`. Using this eps (w/ assumed max value
  242. of 256) is to prevent requantization_scale to go below xnnpack lower
  243. threshold.
  244. """
  245. default_symmetric_qnnpack_qconfig = QConfig(activation=HistogramObserver.with_args(dtype=torch.qint8,
  246. reduce_range=False,
  247. eps=2 ** -12),
  248. weight=weight_observer_range_neg_127_to_127)
  249. default_per_channel_symmetric_qnnpack_qconfig = QConfig(activation=HistogramObserver.with_args(dtype=torch.qint8,
  250. reduce_range=False,
  251. eps=2 ** -12),
  252. weight=per_channel_weight_observer_range_neg_127_to_127)
  253. default_embedding_qat_qconfig = QConfig(activation=NoopObserver.with_args(dtype=torch.float32),
  254. weight=default_embedding_fake_quant)
  255. default_embedding_qat_qconfig_4bit = QConfig(activation=NoopObserver.with_args(dtype=torch.float32),
  256. weight=default_embedding_fake_quant_4bit)
  257. def get_default_qat_qconfig(backend='x86', version=1):
  258. """
  259. Returns the default QAT qconfig for the specified backend.
  260. Args:
  261. * `backend` (str): a string representing the target backend. Currently supports
  262. `x86` (default), `fbgemm`, `qnnpack` and `onednn`.
  263. * `version`: version, for backwards compatibility. Can be `None` or `1`.
  264. Return:
  265. qconfig
  266. """
  267. supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"]
  268. if backend not in supported_backends:
  269. raise AssertionError(
  270. "backend: " + str(backend) +
  271. " not supported. backend must be one of {}".format(supported_backends)
  272. )
  273. # Histogram observer is too slow for quantization aware training
  274. if version == 0:
  275. if backend == 'fbgemm':
  276. qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
  277. quant_min=0,
  278. quant_max=255,
  279. reduce_range=True),
  280. weight=default_per_channel_weight_fake_quant)
  281. elif backend == 'qnnpack':
  282. qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
  283. quant_min=0,
  284. quant_max=255,
  285. reduce_range=False),
  286. weight=default_weight_fake_quant)
  287. elif backend == 'onednn':
  288. qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
  289. quant_min=0,
  290. quant_max=255),
  291. weight=default_per_channel_weight_fake_quant)
  292. elif backend == 'x86':
  293. qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
  294. quant_min=0,
  295. quant_max=255,
  296. reduce_range=True),
  297. weight=default_per_channel_weight_fake_quant)
  298. else:
  299. qconfig = default_qat_qconfig
  300. # Use the fused observe + fake_quant modules for doing QAT.
  301. elif version == 1:
  302. if backend == 'fbgemm':
  303. qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
  304. quant_min=0,
  305. quant_max=255,
  306. reduce_range=True),
  307. weight=default_fused_per_channel_wt_fake_quant)
  308. elif backend == 'qnnpack':
  309. # TODO: make this compatible with xnnpack constraints
  310. qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
  311. quant_min=0,
  312. quant_max=255,
  313. reduce_range=False),
  314. weight=default_fused_wt_fake_quant)
  315. elif backend == 'onednn':
  316. qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
  317. quant_min=0,
  318. quant_max=255),
  319. weight=default_fused_per_channel_wt_fake_quant)
  320. elif backend == 'x86':
  321. qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
  322. quant_min=0,
  323. quant_max=255,
  324. reduce_range=True),
  325. weight=default_fused_per_channel_wt_fake_quant)
  326. else:
  327. qconfig = default_qat_qconfig_v2
  328. else:
  329. raise AssertionError("Version number: " + str(version) +
  330. "in get_default_qat_qconfig is not supported. Version number must be 0 or 1")
  331. return qconfig
  332. """
  333. Default symmetric QAT qconfig for qnnpack. And its per channel weight variant.
  334. """
  335. default_symmetric_qnnpack_qat_qconfig = QConfig(
  336. activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
  337. quant_min=-128,
  338. quant_max=127,
  339. dtype=torch.qint8,
  340. reduce_range=False,
  341. eps=2 ** -12),
  342. weight=fused_wt_fake_quant_range_neg_127_to_127)
  343. default_per_channel_symmetric_qnnpack_qat_qconfig = QConfig(
  344. activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
  345. quant_min=-128,
  346. quant_max=127,
  347. dtype=torch.qint8,
  348. reduce_range=False,
  349. eps=2 ** -12),
  350. weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127)
  351. _default_fp32_placeholder_qconfig = QConfig(
  352. activation=PlaceholderObserver.with_args(dtype=torch.float32),
  353. weight=PlaceholderObserver.with_args(dtype=torch.float32)
  354. )
  355. _default_quint8_placeholder_qconfig = QConfig(
  356. activation=PlaceholderObserver.with_args(dtype=torch.quint8),
  357. # operators using this qconfig doesn't have weights
  358. weight=None,
  359. )
  360. def get_default_qconfig_dict(backend='x86', version=0):
  361. warnings.warn(
  362. "torch.ao.quantization.get_default_qconfig_dict is deprecated and will be removed in "
  363. "a future version. Please use torch.ao.quantization.get_default_qconfig_mapping instead.")
  364. return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict()
  365. def get_default_qat_qconfig_dict(backend='x86', version=1):
  366. warnings.warn(
  367. "torch.ao.quantization.get_default_qat_qconfig_dict is deprecated and will be removed in "
  368. "a future version. Please use torch.ao.quantization.get_default_qat_qconfig_mapping instead.")
  369. return torch.ao.quantization.get_default_qat_qconfig_mapping(backend, version).to_dict()
  370. def _assert_valid_qconfig(qconfig: Optional[QConfig],
  371. mod: torch.nn.Module) -> None:
  372. """
  373. Verifies that this `qconfig` is valid.
  374. """
  375. if qconfig is None:
  376. return
  377. is_conv_transpose_mod = (
  378. isinstance(mod, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)))
  379. if is_conv_transpose_mod:
  380. if qconfig.weight is None:
  381. # for now, we assume that any qconfig for ConvTranspose without a weight is valid
  382. return
  383. example_observer = qconfig.weight()
  384. is_per_channel = (
  385. isinstance(example_observer, (torch.ao.quantization.PerChannelMinMaxObserver,
  386. torch.ao.quantization.MovingAveragePerChannelMinMaxObserver))
  387. )
  388. assert not is_per_channel, \
  389. 'Per channel weight observer is not supported yet for ConvTranspose{n}d.'
  390. QConfigAny = Optional[QConfig]
  391. QConfigAny.__module__ = "torch.ao.quantization.qconfig"
  392. def _add_module_to_qconfig_obs_ctr(
  393. qconfig: QConfigAny,
  394. module: Optional[nn.Module]) -> Any:
  395. r"""This is a helper function for use in quantization prepare that updates a qconfig so that
  396. the constructors stored in the qconfig will create observers on the same device that
  397. 'module' is on. This is intended to be used when the qconfigs are propagated to each
  398. module in order to avoid potential device alignment issues.
  399. Args:
  400. qconfig: QConfig with obs constructors stored in activation and weight
  401. module: module which the qconfig is related to
  402. Return:
  403. qconfig: configured so that obs constructors set to construct on the same device as module
  404. """
  405. if module is None or qconfig is None or qconfig._fields != ('activation', 'weight'):
  406. return qconfig
  407. def get_factory_kwargs_based_on_module_device():
  408. assert isinstance(module, torch.nn.Module)
  409. devices = {p.device for p in module.parameters()} | \
  410. {p.device for p in module.buffers()}
  411. device = next(iter(devices)) if len(devices) > 0 else None
  412. return None if device is None else {'device': device}
  413. def configure_constructor_to_put_obs_on_module_device(original_constructor):
  414. try:
  415. # check if constructor can accept factory_kwargs
  416. check = original_constructor.with_args(factory_kwargs=None)
  417. check()
  418. return original_constructor.with_callable_args(factory_kwargs=get_factory_kwargs_based_on_module_device)
  419. except AttributeError: # qconfig doesn't have activation or weight
  420. return original_constructor
  421. except TypeError: # the class doesn't accept factory_kwargs argument
  422. return original_constructor
  423. activation = configure_constructor_to_put_obs_on_module_device(qconfig.activation)
  424. weight = configure_constructor_to_put_obs_on_module_device(qconfig.weight)
  425. return QConfig(activation, weight)
  426. _ObserverOrFakeQuantizeConstructor = Union[_PartialWrapper, ObserverBase, FakeQuantizeBase]
  427. def _obs_or_fq_ctr_equals(obs_or_fq1: _ObserverOrFakeQuantizeConstructor, obs_or_fq2: _ObserverOrFakeQuantizeConstructor):
  428. if isinstance(obs_or_fq1, _PartialWrapper) and isinstance(obs_or_fq2, _PartialWrapper):
  429. return _partial_wrapper_equals(obs_or_fq1, obs_or_fq2)
  430. return obs_or_fq1 == obs_or_fq2
  431. def _partial_wrapper_equals(obs_or_fq1: _PartialWrapper, obs_or_fq2: _PartialWrapper):
  432. """
  433. Return whether the two partial wrappers are equal,
  434. """
  435. # functools.partial has no __eq__ operator defined so '==' defaults to 'is'
  436. obs_or_fq1_keywords = copy.copy(obs_or_fq1.p.keywords)
  437. obs_or_fq2_keywords = copy.copy(obs_or_fq2.p.keywords)
  438. keywords_equal = True
  439. # compare observer constructor with _obs_or_fq_ctr_equals since direct compare would fail
  440. if "observer" in obs_or_fq1_keywords and "observer" in obs_or_fq2_keywords:
  441. keywords_equal = keywords_equal and _obs_or_fq_ctr_equals(obs_or_fq1_keywords["observer"], obs_or_fq2_keywords["observer"])
  442. obs_or_fq1_keywords.pop("observer")
  443. obs_or_fq2_keywords.pop("observer")
  444. keywords_equal = keywords_equal and obs_or_fq1_keywords == obs_or_fq2_keywords
  445. return obs_or_fq1.p.func == obs_or_fq2.p.func and obs_or_fq1.p.args == obs_or_fq2.p.args and keywords_equal
  446. def qconfig_equals(q1: QConfigAny, q2: QConfigAny):
  447. """
  448. Returns `True` if `q1` equals `q2`, and `False` otherwise.
  449. """
  450. if q1 is None or q2 is None:
  451. return q1 == q2
  452. else:
  453. assert q1 is not None and q2 is not None
  454. try:
  455. # Qconfig weight and activation can be either a partial wrapper,
  456. # or an observer class. Special handling is required (above) for
  457. # comparing partial wrappers.
  458. activation_same = _obs_or_fq_ctr_equals(q1.activation, q2.activation)
  459. weight_same = _obs_or_fq_ctr_equals(q1.weight, q2.weight)
  460. return activation_same and weight_same
  461. except AttributeError:
  462. return q1 == q2
  463. def _activation_is_memoryless(qconfig: QConfig):
  464. """
  465. Return whether the observer for activations defined in the given QConfig is memoryless.
  466. This means a MovingAverage observer with averaging constant equal to 1.
  467. """
  468. def _is_memoryless(observer):
  469. return hasattr(observer, "averaging_constant") and observer.averaging_constant == 1
  470. act = qconfig.activation()
  471. if isinstance(act, FakeQuantizeBase) and hasattr(act, "activation_post_process"):
  472. return _is_memoryless(act.activation_post_process)
  473. else:
  474. return _is_memoryless(act)
  475. def _is_reuse_input_qconfig(qconfig: Optional[QConfig]):
  476. return qconfig is not None and \
  477. isinstance(qconfig.activation(), ReuseInputObserver) and \
  478. isinstance(qconfig.weight(), NoopObserver)