linear.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. from collections.abc import Iterable
  2. import torch
  3. import torch.nn as nn
  4. import torch.ao.nn.intrinsic as nni
  5. import torch.ao.nn.intrinsic.qat as nniqat
  6. from torch.nn.utils.fusion import fuse_linear_bn_weights
  7. from torch.nn.utils.parametrize import type_before_parametrizations
  8. from typing import Optional
  9. from .utils import _quantize_weight, _hide_packed_params_repr, WeightedQuantizedModule
  10. __all__ = ['LinearPackedParams', 'Linear']
  11. class LinearPackedParams(torch.nn.Module):
  12. _version = 3
  13. def __init__(self, dtype=torch.qint8):
  14. super().__init__()
  15. self.dtype = dtype
  16. if self.dtype == torch.qint8:
  17. wq = torch._empty_affine_quantized([1, 1], scale=1.0, zero_point=0, dtype=torch.qint8)
  18. elif self.dtype == torch.float16:
  19. wq = torch.zeros([1, 1], dtype=torch.float)
  20. self.set_weight_bias(wq, None)
  21. @torch.jit.export
  22. def set_weight_bias(self, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> None:
  23. if self.dtype == torch.qint8:
  24. self._packed_params = torch.ops.quantized.linear_prepack(weight, bias)
  25. elif self.dtype == torch.float16:
  26. self._packed_params = torch.ops.quantized.linear_prepack_fp16(weight, bias)
  27. else:
  28. raise RuntimeError('Unsupported dtype on dynamic quantized linear!')
  29. @torch.jit.export
  30. def _weight_bias(self):
  31. if self.dtype == torch.qint8:
  32. return torch.ops.quantized.linear_unpack(self._packed_params)
  33. elif self.dtype == torch.float16:
  34. return torch.ops.quantized.linear_unpack_fp16(self._packed_params)
  35. else:
  36. raise RuntimeError('Unsupported dtype on dynamic quantized linear!')
  37. def forward(self, x):
  38. return x
  39. # Version 1
  40. # self
  41. # |--- weight : Tensor
  42. # |--- bias : Tensor
  43. #
  44. # Version 2
  45. # self
  46. # |--- weight : Tensor
  47. # |--- bias : Tensor
  48. # |--- dtype : torch.dtype
  49. #
  50. # Version 3
  51. # self
  52. # |--- _packed_params : (Tensor, Tensor) representing (weight, bias)
  53. # of LinearPackedParams
  54. # |--- dtype : torch.dtype
  55. def _save_to_state_dict(self, destination, prefix, keep_vars):
  56. super()._save_to_state_dict(destination, prefix, keep_vars)
  57. destination[prefix + 'dtype'] = self.dtype
  58. destination[prefix + '_packed_params'] = self._weight_bias()
  59. def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
  60. missing_keys, unexpected_keys, error_msgs):
  61. version = local_metadata.get('version', None)
  62. if version is None or version < 2:
  63. self.dtype = torch.qint8
  64. else:
  65. self.dtype = state_dict[prefix + 'dtype']
  66. state_dict.pop(prefix + 'dtype')
  67. if version is None or version < 3:
  68. self.set_weight_bias(state_dict[prefix + 'weight'], state_dict[prefix + 'bias'])
  69. state_dict.pop(prefix + 'weight')
  70. state_dict.pop(prefix + 'bias')
  71. if version == 3:
  72. weight, bias = state_dict[prefix + '_packed_params']
  73. state_dict.pop(prefix + '_packed_params')
  74. self.set_weight_bias(weight, bias)
  75. super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
  76. missing_keys, unexpected_keys, error_msgs)
  77. def __repr__(self):
  78. return self._weight_bias().__repr__()
  79. class Linear(WeightedQuantizedModule):
  80. r"""
  81. A quantized linear module with quantized tensor as inputs and outputs.
  82. We adopt the same interface as `torch.nn.Linear`, please see
  83. https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.
  84. Similar to :class:`~torch.nn.Linear`, attributes will be randomly
  85. initialized at module creation time and will be overwritten later
  86. Attributes:
  87. weight (Tensor): the non-learnable quantized weights of the module of
  88. shape :math:`(\text{out\_features}, \text{in\_features})`.
  89. bias (Tensor): the non-learnable bias of the module of shape :math:`(\text{out\_features})`.
  90. If :attr:`bias` is ``True``, the values are initialized to zero.
  91. scale: `scale` parameter of output Quantized Tensor, type: double
  92. zero_point: `zero_point` parameter for output Quantized Tensor, type: long
  93. Examples::
  94. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
  95. >>> m = nn.quantized.Linear(20, 30)
  96. >>> input = torch.randn(128, 20)
  97. >>> # xdoctest: +SKIP
  98. >>> input = torch.quantize_per_tensor(input, 1.0, 0, torch.quint8)
  99. >>> output = m(input)
  100. >>> print(output.size())
  101. torch.Size([128, 30])
  102. """
  103. _version = 3
  104. _FLOAT_MODULE = (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear)
  105. def __init__(self, in_features, out_features, bias_=True,
  106. dtype=torch.qint8):
  107. super().__init__()
  108. # We don't muck around with buffers or attributes or anything here
  109. # to keep the module simple. *everything* is simply a Python attribute.
  110. # Serialization logic is explicitly handled in the below serialization and
  111. # deserialization modules
  112. self.in_features = in_features
  113. self.out_features = out_features
  114. bias = None
  115. if bias_:
  116. bias = torch.zeros(out_features, dtype=torch.float)
  117. if dtype == torch.qint8:
  118. qweight = torch._empty_affine_quantized(
  119. [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8)
  120. elif dtype == torch.float16:
  121. qweight = torch.zeros([out_features, in_features], dtype=torch.float)
  122. else:
  123. raise RuntimeError('Unsupported dtype specified for quantized Linear!')
  124. self._packed_params = LinearPackedParams(dtype)
  125. self._packed_params.set_weight_bias(qweight, bias)
  126. self.scale = 1.0
  127. self.zero_point = 0
  128. def _get_name(self):
  129. return 'QuantizedLinear'
  130. def extra_repr(self):
  131. return 'in_features={}, out_features={}, scale={}, zero_point={}, qscheme={}'.format(
  132. self.in_features, self.out_features, self.scale, self.zero_point, self.weight().qscheme()
  133. )
  134. def __repr__(self):
  135. return _hide_packed_params_repr(self, LinearPackedParams)
  136. def forward(self, x: torch.Tensor) -> torch.Tensor:
  137. return torch.ops.quantized.linear(
  138. x, self._packed_params._packed_params, self.scale, self.zero_point)
  139. # ===== Serialization methods =====
  140. # The special consideration here is that we have to unpack the weights into their
  141. # regular QTensor form for serialization. Packed weights should not live
  142. # outside the process in which they were created, rather they should be derived
  143. # from the QTensor weight.
  144. #
  145. # Version 1
  146. # self
  147. # |--- scale : float
  148. # |--- zero_point : int
  149. # |--- weight : Tensor
  150. # |--- bias : Tensor
  151. #
  152. # Version 2
  153. # self
  154. # |--- scale : float
  155. # |--- zero_point : int
  156. # |--- _packed_params : Module
  157. # |--- weight : Tensor
  158. # |--- bias : Tensor
  159. #
  160. # Version 3
  161. # self
  162. # |--- scale : float
  163. # |--- zero_point : int
  164. # |--- _packed_params : Module
  165. # |--- _packed_params : (Tensor, Tensor) representing weight, bias
  166. # of LinearPackedParams C++ struct
  167. #
  168. def _save_to_state_dict(self, destination, prefix, keep_vars):
  169. super()._save_to_state_dict(destination, prefix, keep_vars)
  170. destination[prefix + 'scale'] = torch.tensor(self.scale)
  171. destination[prefix + 'zero_point'] = torch.tensor(self.zero_point)
  172. # ===== Deserialization methods =====
  173. # Counterpart to the serialization methods, we must pack the serialized QTensor
  174. # weight into its packed format for use by the FBGEMM ops.
  175. def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
  176. missing_keys, unexpected_keys, error_msgs):
  177. self.scale = float(state_dict[prefix + 'scale'])
  178. state_dict.pop(prefix + 'scale')
  179. self.zero_point = int(state_dict[prefix + 'zero_point'])
  180. state_dict.pop(prefix + 'zero_point')
  181. version = local_metadata.get('version', None)
  182. if version is None or version == 1:
  183. # We moved the parameters into a LinearPackedParameters submodule
  184. weight = state_dict.pop(prefix + 'weight')
  185. bias = state_dict.pop(prefix + 'bias')
  186. state_dict.update({prefix + '_packed_params.weight': weight,
  187. prefix + '_packed_params.bias': bias})
  188. super()._load_from_state_dict(
  189. state_dict, prefix, local_metadata, False,
  190. missing_keys, unexpected_keys, error_msgs)
  191. # Function rather than property to make sure that JIT serialization doesn't
  192. # register this as an attribute
  193. def _weight_bias(self):
  194. return self._packed_params._weight_bias()
  195. def weight(self):
  196. return self._weight_bias()[0]
  197. def bias(self):
  198. return self._weight_bias()[1]
  199. def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
  200. self._packed_params.set_weight_bias(w, b)
  201. @classmethod
  202. def from_float(cls, mod):
  203. r"""Create a quantized module from an observed float module
  204. Args:
  205. mod (Module): a float module, either produced by torch.ao.quantization
  206. utilities or provided by the user
  207. """
  208. if hasattr(mod, 'weight_fake_quant'):
  209. if type_before_parametrizations(mod) == nniqat.LinearBn1d:
  210. mod.weight, mod.bias = fuse_linear_bn_weights(
  211. mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
  212. mod.bn.eps, mod.bn.weight, mod.bn.bias)
  213. weight_post_process = mod.weight_fake_quant
  214. activation_post_process = mod.activation_post_process
  215. else:
  216. # This function does not participate in JIT, so it is OK to ignore
  217. # the type mismatch in assignment. Also, mypy has an issue with
  218. # iterables not being implemented, so we are ignoring those too.
  219. if not isinstance(cls._FLOAT_MODULE, Iterable):
  220. cls._FLOAT_MODULE = [cls._FLOAT_MODULE] # type: ignore[assignment]
  221. supported_modules = ', '.join([float_mod.__name__ for float_mod in cls._FLOAT_MODULE]) # type: ignore[attr-defined]
  222. error_msg = 'nnq.{}.from_float only works for {}, but got: {}'.format(cls.__name__, supported_modules, type(mod))
  223. assert type_before_parametrizations(mod) in cls._FLOAT_MODULE, error_msg.format() # type: ignore[attr-defined]
  224. assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
  225. activation_post_process = mod.activation_post_process
  226. if type_before_parametrizations(mod) == nni.LinearReLU:
  227. mod = mod[0]
  228. weight_post_process = mod.qconfig.weight()
  229. weight_post_process(mod.weight)
  230. dtype = weight_post_process.dtype
  231. act_scale, act_zp = activation_post_process.calculate_qparams()
  232. assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
  233. qweight = _quantize_weight(mod.weight.float(), weight_post_process)
  234. qlinear = cls(mod.in_features,
  235. mod.out_features,
  236. dtype=dtype)
  237. qlinear.set_weight_bias(qweight, mod.bias)
  238. qlinear.scale = float(act_scale)
  239. qlinear.zero_point = int(act_zp)
  240. return qlinear
  241. @classmethod
  242. def from_reference(cls, ref_qlinear, output_scale, output_zero_point):
  243. r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
  244. Args:
  245. ref_qlinear (Module): a reference quantized linear module, either produced by torch.ao.quantization
  246. utilities or provided by the user
  247. output_scale (float): scale for output Tensor
  248. output_zero_point (int): zero point for output Tensor
  249. """
  250. qlinear = cls(
  251. ref_qlinear.in_features,
  252. ref_qlinear.out_features)
  253. qweight = ref_qlinear.get_quantized_weight()
  254. qlinear.set_weight_bias(qweight, ref_qlinear.bias)
  255. qlinear.scale = float(output_scale)
  256. qlinear.zero_point = int(output_zero_point)
  257. return qlinear