utils.py

  1. """
  2. Utils shared by different modes of quantization (eager/graph)
  3. """
  4. import functools
  5. import warnings
  6. from collections import OrderedDict
  7. from inspect import getfullargspec, signature
  8. from typing import Any, Callable, Dict, Optional, Tuple, Union
  9. import torch
  10. from torch.ao.quantization.quant_type import QuantType
  11. from torch.fx import Node
  12. from torch.nn.utils.parametrize import is_parametrized
  13. NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
  14. NodePattern.__module__ = "torch.ao.quantization.utils"
  15. # This is the Quantizer class instance from torch/quantization/fx/quantize.py.
  16. # Define separately to prevent circular imports.
  17. # TODO(future PR): improve this.
  18. # make this public once fixed (can't be public as is because setting the module directly
  19. # doesn't work)
  20. QuantizerCls = Any
  21. # Type for fusion patterns, it can be more complicated than the following actually,
  22. # see pattern.md for docs
  23. # TODO: not sure if typing supports recursive data types
  24. Pattern = Union[
  25. Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any
  26. ]
  27. Pattern.__module__ = "torch.ao.quantization.utils"
  28. # TODO: maybe rename this to MatchInputNode
  29. class MatchAllNode:
  30. """ A node pattern that matches all nodes, used in defining
  31. fusion patterns in FX Graph Mode Quantization
  32. """
  33. pass

module_type_list = {
    torch.nn.ReLU,
    torch.nn.ReLU6,
    torch.nn.AdaptiveAvgPool1d,
    torch.nn.AdaptiveAvgPool2d,
    torch.nn.AdaptiveAvgPool3d,
    torch.nn.AvgPool1d,
    torch.nn.AvgPool2d,
    torch.nn.AvgPool3d,
    torch.nn.MaxPool1d,
    torch.nn.MaxPool2d,
    torch.nn.MaxPool3d,
    torch.nn.Identity,
    torch.nn.Hardsigmoid,
    torch.nn.Sigmoid,
    torch.nn.Tanh,
}
func_list = {
    torch.nn.functional.adaptive_avg_pool1d,
    torch.nn.functional.adaptive_avg_pool2d,
    torch.nn.functional.adaptive_avg_pool3d,
    torch.nn.functional.elu,
    torch.nn.functional.hardswish,
    torch.nn.functional.instance_norm,
    torch.nn.functional.layer_norm,
    torch.nn.functional.leaky_relu,
    torch.nn.functional.silu,
    torch.nn.functional.mish,
    torch.nn.functional.dropout,
    torch.nn.functional.max_pool1d,
    torch.nn.functional.max_pool2d,
    torch.nn.functional.max_pool3d,
    torch.nn.functional.relu,
    torch.nn.functional.hardtanh,
    torch.nn.functional.hardtanh_,
    torch.nn.functional.hardsigmoid,
    torch.nn.functional.sigmoid,
    torch.transpose,
    torch.repeat_interleave,
    torch.sigmoid,
    torch.squeeze,
    torch.stack,
    torch.sum,
    torch.tanh,
    torch.unsqueeze,
    torch.cat,
}
method_list = {
    torch.mean,
    'relu',
    'relu_',
    'contiguous',
    'detach',
    'detach_',
    'hardsigmoid',
    'hardsigmoid_',
    'permute',
    'repeat',
    'repeat_interleave',
    'reshape',
    'resize_',
    'shape',
    'sigmoid',
    'sigmoid_',
    'size',
    'squeeze',
    'squeeze_',
    'tanh',
    'tanh_',
    'transpose',
    'unsqueeze',
    'unsqueeze_',
    'view',
}

# TODO: not used now, remove
def check_node(node, modules):
    # TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py
    is_call_function = node.op == "call_function" and node.target in func_list
    is_call_method = node.op == "call_method" and node.target in method_list
    is_call_module = node.op == "call_module" and type(modules[str(node.target)]) in module_type_list
    return is_call_function, is_call_method, is_call_module

def get_combined_dict(default_dict, additional_dict):
    d = default_dict.copy()
    d.update(additional_dict)
    return d

def is_per_tensor(qscheme):
    return qscheme == torch.per_tensor_affine or \
        qscheme == torch.per_tensor_symmetric

def is_per_channel(qscheme):
    return qscheme in [torch.per_channel_affine,
                       torch.per_channel_affine_float_qparams,
                       torch.per_channel_symmetric]

def getattr_from_fqn(obj: Any, fqn: str) -> Any:
    """
    Given an obj and an fqn such as "foo.bar.baz", returns obj.foo.bar.baz.
    """
    return functools.reduce(getattr, fqn.split("."), obj)
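
# Illustrative usage of `getattr_from_fqn` (not in the original file; `model` and its
# nested `sub.linear` attribute below are hypothetical)::
#
#     >> getattr_from_fqn(model, "sub.linear.weight")
#     # equivalent to model.sub.linear.weight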

def to_underlying_dtype(qdtype):
    DTYPE_MAPPING = {
        torch.quint8: torch.uint8,
        torch.qint8: torch.int8,
        torch.qint32: torch.int32,
        torch.quint4x2: torch.uint8,
        torch.quint2x4: torch.uint8,
    }
    # use str() so a failing assert renders the message instead of raising a TypeError
    assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype)
    return DTYPE_MAPPING[qdtype]
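
# Illustrative mapping (not in the original file), based on DTYPE_MAPPING above::
#
#     >> to_underlying_dtype(torch.quint8)
#     torch.uint8
#     >> to_underlying_dtype(torch.quint4x2)
#     torch.uint8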

def get_qparam_dict(observer_or_fake_quant):
    qscheme = observer_or_fake_quant.qscheme if hasattr(observer_or_fake_quant, "qscheme") else None
    dtype = observer_or_fake_quant.dtype
    qparams = {"qscheme": qscheme, "dtype": dtype}
    if not qscheme:
        return qparams

    if is_per_tensor(qscheme):
        qscheme = torch.per_tensor_affine
    elif is_per_channel(qscheme):
        # change symmetric to affine since we do not have symmetric
        # quantized Tensor
        if qscheme == torch.per_channel_symmetric:
            qscheme = torch.per_channel_affine
        qparams["axis"] = observer_or_fake_quant.ch_axis
    else:
        raise RuntimeError(f"Unrecognized qscheme: {qscheme}")
    # update qscheme, since we don't have symmetric quant qscheme
    # in quantized Tensor
    qparams["qscheme"] = qscheme

    scale, zero_point = observer_or_fake_quant.calculate_qparams()
    qparams["scale"] = scale
    qparams["zero_point"] = zero_point
    return qparams
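
# Illustrative usage (not in the original file; assumes `MinMaxObserver` from
# torch.ao.quantization and that the observer has already seen data)::
#
#     >> obs = torch.ao.quantization.MinMaxObserver()
#     >> obs(torch.randn(2, 3))
#     >> get_qparam_dict(obs)
#     {"qscheme": torch.per_tensor_affine, "dtype": torch.quint8,
#      "scale": tensor([...]), "zero_point": tensor([...])}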

def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig):
    """ Get the observed/quantized custom module class that we need
    to swap `custom_module` to
    Input:
        custom_module: input, can be an instance of either a float or observed custom module
        custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping
        qconfig: qconfig configured for the custom module
    Output:
        corresponding observed/quantized custom module class for input custom module instance
    """
    quant_type = get_quant_type(qconfig)
    class_mapping = custom_module_class_mapping.get(quant_type, {})
    assert type(custom_module) in class_mapping, "did not find corresponding observed " \
        "module class for {} in mapping: {}".format(type(custom_module), class_mapping)
    return class_mapping[type(custom_module)]

def activation_dtype(qconfig):
    assert qconfig is not None
    activation = qconfig.activation()
    return activation.dtype

def weight_dtype(qconfig):
    assert qconfig is not None
    weight = qconfig.weight()
    return weight.dtype

def activation_is_statically_quantized(qconfig):
    """ Given a qconfig, decide if the activation needs to be
    statically quantized or not; this includes quantizing to quint8, qint8,
    qint32 and float16
    """
    return (
        activation_dtype(qconfig) in [torch.quint8, torch.qint8, torch.qint32, torch.float16]
        and (not activation_is_dynamically_quantized(qconfig))
    )

def activation_is_dynamically_quantized(qconfig):
    """ Given a qconfig, decide if the activation needs to be
    dynamically quantized or not, this includes dynamically quantizing to
    quint8, qint8 and float16
    """
    activation_dtype, _, activation_is_dynamic = \
        get_qconfig_dtypes(qconfig)
    return activation_is_dynamic

def activation_is_int8_quantized(qconfig):
    """ Given a qconfig, decide if the activation needs to be
    quantized to int8 or not, this includes quantizing to quint8, qint8
    """
    return activation_dtype(qconfig) in [torch.quint8, torch.qint8]

def activation_is_int32_quantized(qconfig):
    """ Given a qconfig, decide if the activation needs to be
    quantized to int32 or not
    """
    return activation_dtype(qconfig) == torch.qint32

def weight_is_quantized(qconfig):
    """ Given a qconfig, decide if the weight needs to be
    quantized or not
    """
    return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16, torch.quint4x2]

def weight_is_statically_quantized(qconfig):
    """ Given a qconfig, decide if the weight needs to be statically
    quantized or not
    """
    return weight_dtype(qconfig) in [torch.quint8, torch.qint8]

def op_is_int8_dynamically_quantized(qconfig) -> bool:
    """ Given a qconfig, returns True if this op is using int8 dynamic
    quantization
    """
    activation_dtype, weight_dtype, activation_is_dynamic = \
        get_qconfig_dtypes(qconfig)
    return (
        activation_dtype is torch.quint8 and
        # for now, the lines below assume fbgemm or qnnpack
        weight_dtype is torch.qint8 and
        activation_is_dynamic
    )

def get_qconfig_dtypes(qconfig):
    r""" Returns the dtype tuple for the given qconfig:
    (activation_dtype, weight_dtype, activation_is_dynamic)
    """
    assert qconfig is not None
    activation = qconfig.activation()
    weight = qconfig.weight()
    act_is_dynamic = activation.is_dynamic if hasattr(activation, 'is_dynamic') else False
    return (activation.dtype, weight.dtype, act_is_dynamic)

def get_quant_type(qconfig):
    assert qconfig is not None
    activation = qconfig.activation()
    weight = qconfig.weight()
    static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]
    if weight.dtype in static_dtypes:
        if hasattr(activation, 'is_dynamic') and activation.is_dynamic:
            return QuantType.DYNAMIC
        elif activation.dtype in static_dtypes:
            return QuantType.STATIC
        else:
            return QuantType.WEIGHT_ONLY

    if weight.dtype == torch.float16:
        if hasattr(activation, 'is_dynamic') and activation.is_dynamic:
            return QuantType.DYNAMIC
        elif activation.dtype == torch.float16:
            return QuantType.STATIC

    raise Exception("Unrecognized dtype combination in get_quant_type: activation({}), "
                    "weight({})".format(activation.dtype, weight.dtype))

def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
    """ Checks if the given minimum and maximum values are valid, meaning that
    they exist and the min value is less than the max value.
    """
    if min_val.numel() == 0 or max_val.numel() == 0:
        warnings.warn(
            "must run observer before calling calculate_qparams. " +
            "Returning default values."
        )
        return False

    if min_val.dim() == 0 or max_val.dim() == 0:
        if min_val == float("inf") and max_val == float("-inf"):
            warnings.warn(
                "must run observer before calling calculate_qparams. " +
                "Returning default values."
            )
            return False

        assert min_val <= max_val, "min {} should be less than or equal to max {}".format(
            min_val, max_val
        )
    else:
        assert torch.all(
            min_val <= max_val
        ), "min {} should be less than or equal to max {}".format(min_val, max_val)

    return True
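
# Illustrative behavior (not in the original file): an observer that has never seen
# data still has min_val=inf / max_val=-inf, so the check warns and returns False::
#
#     >> check_min_max_valid(torch.tensor(float("inf")), torch.tensor(float("-inf")))
#     False
#     >> check_min_max_valid(torch.tensor(-1.0), torch.tensor(1.0))
#     True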

def calculate_qmin_qmax(quant_min: int, quant_max: int, has_customized_qrange: bool, dtype: torch.dtype,
                        reduce_range: bool) -> Tuple[int, int]:
    r"""Calculates the actual qmin and qmax based on the quantization range,
    observer datatype, and whether the range is reduced.
    """
    # TODO(jerryzh): Figure out why custom quant_min/quant_max are still adjusted.
    if has_customized_qrange:
        # This initialization is here to resolve TorchScript compilation issues and to allow
        # refinement to decouple initial_quant_min and initial_quant_max from the quantization range.
        # The actual values of initial_quant_min and initial_quant_max will be reset below.
        if dtype == torch.qint32:
            initial_quant_min, initial_quant_max = 0, 2**31 - 1
        else:
            initial_quant_min, initial_quant_max = 0, 255
        # The assignment of quant_min and quant_max to local variables and the if check below
        # refine them from Optional[int] to valid integers, as required by TorchScript.
        custom_quant_min, custom_quant_max = quant_min, quant_max
        if custom_quant_min is not None and custom_quant_max is not None:
            initial_quant_min, initial_quant_max = (
                custom_quant_min,
                custom_quant_max,
            )

        qrange_len = initial_quant_max - initial_quant_min + 1
        if dtype == torch.qint8:
            assert (
                0 < qrange_len <= 256
            ), "quantization range should be positive and not exceed the maximum bit range (=256)."
        elif dtype == torch.qint32:
            assert (
                0 < qrange_len <= 2**31
            ), "quantization range should be positive and not exceed the maximum bit range (=2147483648)."
        if reduce_range:
            quant_min, quant_max = quant_min // 2, quant_max // 2
    else:
        # Fall back onto default 8-bit qmin and qmax calculation if dynamic range is not used.
        if dtype == torch.qint8:
            if reduce_range:
                quant_min, quant_max = -64, 63
            else:
                quant_min, quant_max = -128, 127
        elif dtype == torch.quint8:
            if reduce_range:
                quant_min, quant_max = 0, 127
            else:
                quant_min, quant_max = 0, 255
        elif dtype == torch.qint32:
            quant_min, quant_max = -1 * (2 ** 31), (2 ** 31) - 1
        else:
            quant_min, quant_max = 0, 15
    return quant_min, quant_max
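
# Illustrative values (not in the original file)::
#
#     >> calculate_qmin_qmax(0, 255, False, torch.quint8, True)   # default range, reduce_range
#     (0, 127)
#     >> calculate_qmin_qmax(0, 15, True, torch.quint8, False)    # customized 4-bit range
#     (0, 15)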

def _parent_name(target):
    """
    Turn 'foo.bar' into ('foo', 'bar')
    """
    r = target.rsplit('.', 1)
    if len(r) == 1:
        return '', r[0]
    else:
        return r[0], r[1]

def has_no_children_ignoring_parametrizations(module):
    """
    Checks if module._modules is empty or, if the module is parametrized,
    that module._modules only contains the 'parametrizations' module
    """
    if len(module._modules) == 0:
        return True
    elif is_parametrized(module):
        return len(module._modules) == 1 and 'parametrizations' in module._modules
    else:
        return False

def _get_path_of_module(root: torch.nn.Module, submodule: torch.nn.Module) -> Optional[str]:
    """ Get the path (fully qualified name) of a submodule

    Example::

    >> class M(torch.nn.Module):
           def __init__(self):
               self.linear = torch.nn.Linear(5, 5)
           def forward(self, x):
               return self.linear(x)

    >> m = M()
    >> l = m.linear
    >> _get_path_of_module(m, l)
    "linear"
    """
    for n, p in root.named_modules():
        if submodule is p:
            return n
    return None

def _get_signature_locals(f: Callable, loc: Dict[str, Any]) -> Dict[str, Any]:
    """ Get local keyword arguments

    Example::

    >> def f(self, a, b=9):
           pass
    >> loc = {"a": 6, "c": 7}
    >> _get_signature_locals(f, loc)
    {"a": 6}
    """
    return {k: v for k, v in loc.items() if k in signature(f).parameters}

def _get_default_kwargs(f: Callable) -> "OrderedDict[str, Any]":
    """ Get all default keyword arguments from function signature

    Example::

    >> def f(self, a, b=9):
           pass
    >> _get_default_kwargs(f)
    {"b": 9}
    """
    kwargs = {}
    for name, param in signature(f).parameters.items():
        if param.default is not param.empty:
            kwargs[name] = param.default
        elif param.kind is param.VAR_POSITIONAL:
            kwargs[name] = ()
        elif param.kind is param.VAR_KEYWORD:
            kwargs[name] = {}
    return OrderedDict(kwargs)

def _normalize_kwargs(func: Callable, loc: Dict[str, Any]) -> "OrderedDict[str, Any]":
    """ Given a function and local function arguments, normalize the keyword
    arguments by filling in default arguments from function signature

    Example::

    >> def f(self, key1=3, key2=3):
           pass
    >> loc = {"key2": 6}
    >> _normalize_kwargs(f, loc)
    {"key1": 3, "key2": 6}
    """
    default_kwargs = _get_default_kwargs(func)
    local_kwargs = _get_signature_locals(func, loc)
    normalized_kwargs = default_kwargs.copy()
    for attr, val in local_kwargs.items():
        if attr in normalized_kwargs:
            # override the default keyword arguments
            normalized_kwargs[attr] = val
    return normalized_kwargs

def validate_qmin_qmax(quant_min: int, quant_max: int) -> None:
    r"""Validates that the user-specified quantization range is properly initialized
    and within the given bound supported by the observer dtype.

    To accommodate lower-bit quantization with respect to the existing torch.qint8 and
    torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing
    in a tuple of initial qmin and qmax values. One use case is where these customized qmin and qmax
    values are used to calculate static estimates of the scale and zero point for aggressive lower-bit
    fake quantization. These estimates are compared against parameters learned through backpropagation.
    Related literature on learning scale and zero point via backpropagation:
    Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
    Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
    """
    # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted
    # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer.
    assert (
        quant_min <= 0 <= quant_max
    ), "User-specified quantization range must include 0."
    assert (
        quant_min < quant_max
    ), "qmin must be strictly less than qmax for user-specified quantization range."

# Functionally equivalent to '_calculate_qparams' in observer.py. Observers must be torchscriptable, however,
# and qscheme, as far as I can tell, is not allowed to be passed as a parameter in torchscript functions. This
# makes refactoring observer to use this utility a massive pain and very gross. For now I'm opting just to
# duplicate, as this code seems unlikely to change (last update over 1 year ago) and when torchscript is fully
# deprecated we can refactor. TODO(jakeszwe, jerryzh168)
def determine_qparams(
        min_val: torch.Tensor, max_val: torch.Tensor, quant_min: int, quant_max: int,
        dtype: torch.dtype, eps: torch.Tensor, has_customized_qrange: bool,
        qscheme: torch.qscheme = torch.per_tensor_affine) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Calculates the quantization parameters, given min and max
    value tensors. Works for both per tensor and per channel cases

    Args:
        min_val: Minimum values per channel
        max_val: Maximum values per channel

    Returns:
        scales: Scales tensor of shape (#channels,)
        zero_points: Zero points tensor of shape (#channels,)
    """
    if not check_min_max_valid(min_val, max_val):
        return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type)

    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

    device = min_val_neg.device
    scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    if (
        qscheme == torch.per_tensor_symmetric
        or qscheme == torch.per_channel_symmetric
    ):
        max_val_pos = torch.max(-min_val_neg, max_val_pos)
        scale = max_val_pos / (float(quant_max - quant_min) / 2)
        scale = torch.max(scale, eps)
        if dtype == torch.uint8 or dtype == torch.quint8:
            if has_customized_qrange:
                # When customized quantization range is used, down-rounded midpoint of the range is chosen.
                zero_point = zero_point.new_full(
                    zero_point.size(), (quant_min + quant_max) // 2
                )
            else:
                zero_point = zero_point.new_full(zero_point.size(), 128)
    elif qscheme == torch.per_channel_affine_float_qparams:
        scale = (max_val - min_val) / float(quant_max - quant_min)
        scale = torch.where(scale > eps, scale, torch.ones_like(scale))
        # We use the quantize function
        # xq = Round(Xf * inv_scale + zero_point),
        # setting zero_point to (-1 * min * inv_scale) we get
        # Xq = Round((Xf - min) * inv_scale)
        zero_point = -1 * min_val / scale
    else:
        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
        scale = torch.max(scale, eps)
        zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
        zero_point = torch.clamp(zero_point, quant_min, quant_max)

    # For scalar values, cast them to Tensors of size 1 to keep the shape
    # consistent with default values in FakeQuantize.
    if len(scale.shape) == 0:
        # TODO: switch to scale.item() after adding JIT support
        scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
    if len(zero_point.shape) == 0:
        # TODO: switch to zero_point.item() after adding JIT support
        zero_point = torch.tensor(
            [int(zero_point)], dtype=zero_point.dtype, device=device
        )
        if qscheme == torch.per_channel_affine_float_qparams:
            zero_point = torch.tensor(
                [float(zero_point)], dtype=zero_point.dtype, device=device
            )

    return scale, zero_point
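
# Illustrative per-tensor affine example (not in the original file); the printed
# values are approximate::
#
#     >> determine_qparams(torch.tensor(-1.0), torch.tensor(1.0), 0, 255,
#                          torch.quint8, torch.tensor(1e-7), False)
#     (tensor([0.0078]), tensor([128]))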

def _get_num_pos_args(f: Callable) -> int:
    """ Get number of positional args for a function

    Example::

    >> def f(self, key1=3, key2=3):
           pass
    >> _get_num_pos_args(f)
    3
    """
    return len(getfullargspec(f).args)

def get_fqn_to_example_inputs(
    model: torch.nn.Module,
    example_inputs: Tuple[Any, ...]
) -> Dict[str, Tuple[Any, ...]]:
    """ Given a model and its example inputs, return a dictionary from
    the fully qualified name of each submodule to the example_inputs for that submodule,
    e.g. {"linear1": (tensor1,), "linear2": (tensor2,), "sub": (tensor3,),
    "sub.linear1": (tensor4,), ...}

    Used to make quantizing submodules easier now that FX Graph Mode Quantization requires
    example inputs.

    Also works for keyword arguments with default values: we flatten keyword
    arguments into positional arguments and fill in the missing keyword args with default
    values, e.g. if we have a forward function:
    def forward(self, x, key1=3, key2=3):
        ...

    and we call it with self.submodule(x, key2=6)
    we'll get example_inputs: (x, 3, 6)

    user can also override `key1` with positional arguments as well:
    for self.submodule(x, 5, key2=6)
    we'll get: (x, 5, 6)

    variable positional arguments and variable keyword arguments in the forward
    function are not supported currently, so please make sure no submodule is using
    them.
    """
    root = model
    fqn_to_example_inputs = {}

    def _patched_module_call(self, *args, **kwargs):
        submodule_example_inputs = list(args).copy()
        normalized_kwargs = _normalize_kwargs(self.forward, kwargs)
        # minus 1 to skip counting `self`
        num_args = _get_num_pos_args(self.forward) - 1
        num_to_pop = num_args - len(submodule_example_inputs)
        while num_to_pop and normalized_kwargs:
            normalized_kwargs.popitem(last=False)
            num_to_pop -= 1
        submodule_example_inputs.extend(normalized_kwargs.values())
        submodule_example_inputs_tuple = tuple(submodule_example_inputs)
        fqn = _get_path_of_module(root, self)
        if fqn is not None:
            fqn_to_example_inputs[fqn] = submodule_example_inputs_tuple
        return orig_module_call(self, *args, **kwargs)

    orig_module_call = torch.nn.Module.__call__
    torch.nn.Module.__call__ = _patched_module_call
    try:
        model(*example_inputs)
    finally:
        # restore the module call even if there is an exception
        torch.nn.Module.__call__ = orig_module_call
    return fqn_to_example_inputs
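
# Illustrative usage (not in the original file; `M` below is a hypothetical module)::
#
#     >> class M(torch.nn.Module):
#            def __init__(self):
#                super().__init__()
#                self.linear = torch.nn.Linear(5, 5)
#            def forward(self, x):
#                return self.linear(x)
#     >> get_fqn_to_example_inputs(M(), (torch.rand(1, 5),))
#     {"": (tensor(...),), "linear": (tensor(...),)}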

def _get_lstm_with_individually_observed_parts(
    float_lstm: torch.nn.LSTM,
    # Use Callable instead of _PartialWrapper here to avoid circular dependencies
    linear_output_obs_ctr: Optional[Callable] = None,
    sigmoid_obs_ctr: Optional[Callable] = None,
    tanh_obs_ctr: Optional[Callable] = None,
    cell_state_obs_ctr: Optional[Callable] = None,
    hidden_state_obs_ctr: Optional[Callable] = None,
) -> torch.ao.nn.quantizable.LSTM:
    """
    Return an observed `torch.ao.nn.quantizable.LSTM` created from a `torch.nn.LSTM`
    with specific observers or fake quantizes assigned to the inner ops or submodules.

    In both eager and FX graph mode quantization, `torch.ao.nn.quantizable.LSTM` is
    used as an observed custom module, which is responsible for inserting its own
    observers. By default, all inner ops inherit the parent custom module's QConfig.
    Users who wish to override this behavior may extend `torch.ao.nn.quantizable.LSTM`
    and use this helper function to customize the observer insertion logic.

    Args:
        `float_lstm`: The float LSTM module
        `linear_output_obs_ctr`: observer or fake quantize for linear outputs Wx + b,
            where W is the weight matrix, b is the bias, and x is either the inputs
            or the hidden state from the previous layer (if any)
        `sigmoid_obs_ctr`: observer or fake quantize for sigmoid activations
        `tanh_obs_ctr`: observer or fake quantize for tanh activations
        `cell_state_obs_ctr`: observer or fake quantize for the cell state
        `hidden_state_obs_ctr`: observer or fake quantize for the hidden state and
            the output

    Return:
        A `torch.ao.nn.quantizable.LSTM` with the specified observers or fake quantizes
        attached to the inner submodules.
    """
    def make_qconfig(obs_ctr: Callable) -> torch.ao.quantization.QConfig:
        """
        Make a QConfig with fixed qparams observers or fake quantizes.
        """
        if isinstance(obs_ctr(), torch.ao.quantization.FakeQuantizeBase):
            weight = torch.ao.quantization.default_weight_fake_quant
        else:
            weight = torch.ao.quantization.default_weight_observer
        return torch.ao.quantization.QConfig(activation=obs_ctr, weight=weight)

    observed_lstm = torch.ao.nn.quantizable.LSTM(
        float_lstm.input_size, float_lstm.hidden_size, float_lstm.num_layers, float_lstm.bias,
        float_lstm.batch_first, float_lstm.dropout, float_lstm.bidirectional)

    # Assign QConfigs with fixed qparams to all inner submodules
    # Module hierarchy: LSTM > _LSTMLayer > _LSTMSingleLayer (forward or backward) > LSTMCell
    for layer in observed_lstm.layers:
        inner_layers = [layer.layer_fw]
        if float_lstm.bidirectional:
            inner_layers.append(layer.layer_bw)
        for inner_layer in inner_layers:
            cell = inner_layer.cell
            if linear_output_obs_ctr is not None:
                qconfig = make_qconfig(linear_output_obs_ctr)
                cell.igates.qconfig = qconfig
                cell.hgates.qconfig = qconfig
            if sigmoid_obs_ctr is not None:
                qconfig = make_qconfig(sigmoid_obs_ctr)
                cell.input_gate.qconfig = qconfig
                cell.forget_gate.qconfig = qconfig
                cell.output_gate.qconfig = qconfig
            if tanh_obs_ctr is not None:
                cell.cell_gate.qconfig = make_qconfig(tanh_obs_ctr)
            if cell_state_obs_ctr is not None:
                cell.fgate_cx_igate_cgate.qconfig = make_qconfig(cell_state_obs_ctr)
                obs = cell_state_obs_ctr()
                if hasattr(obs, "scale") and hasattr(obs, "zero_point"):
                    cell.initial_cell_state_qparams = (obs.scale, obs.zero_point)
                cell.cell_state_dtype = obs.dtype
            if hidden_state_obs_ctr is not None:
                cell.ogate_cy.qconfig = make_qconfig(hidden_state_obs_ctr)
                obs = hidden_state_obs_ctr()
                if hasattr(obs, "scale") and hasattr(obs, "zero_point"):
                    cell.initial_hidden_state_qparams = (obs.scale, obs.zero_point)
                cell.hidden_state_dtype = obs.dtype

    # need to do this here to avoid circular dependency
    from torch.ao.quantization.quantize import _add_observer_
    # Insert the observers based on the previously attached QConfigs
    # Pass in non_leaf_module_list to prevent the observers for sigmoid/tanh from being overridden
    _add_observer_(  # type: ignore[attr-defined]
        observed_lstm,
        non_leaf_module_list=[torch.nn.Sigmoid, torch.nn.Tanh]
    )
    return observed_lstm
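
# Illustrative usage (not in the original file; assumes `FixedQParamsObserver` from
# torch.ao.quantization, and the scale/zero_point values shown are only examples)::
#
#     >> float_lstm = torch.nn.LSTM(input_size=4, hidden_size=4, num_layers=1)
#     >> sigmoid_obs = torch.ao.quantization.FixedQParamsObserver.with_args(
#            scale=1.0 / 256, zero_point=0)
#     >> tanh_obs = torch.ao.quantization.FixedQParamsObserver.with_args(
#            scale=2.0 / 256, zero_point=128)
#     >> observed = _get_lstm_with_individually_observed_parts(
#            float_lstm, sigmoid_obs_ctr=sigmoid_obs, tanh_obs_ctr=tanh_obs)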

__all__ = [
    "NodePattern",
    "Pattern",
    "MatchAllNode",
    "check_node",
    "get_combined_dict",
    "is_per_tensor",
    "is_per_channel",
    "getattr_from_fqn",
    "get_qparam_dict",
    "get_swapped_custom_module_class",
    "activation_dtype",
    "weight_dtype",
    "activation_is_statically_quantized",
    "activation_is_dynamically_quantized",
    "activation_is_int8_quantized",
    "activation_is_int32_quantized",
    "weight_is_quantized",
    "weight_is_statically_quantized",
    "op_is_int8_dynamically_quantized",
    "get_qconfig_dtypes",
    "get_quant_type",
    "check_min_max_valid",
    "calculate_qmin_qmax",
    "has_no_children_ignoring_parametrizations",
    "get_fqn_to_example_inputs",
    "to_underlying_dtype",
    "determine_qparams",
    "validate_qmin_qmax",
]