# embedding_ops.py

import torch
import torch.nn as nn
from torch import Tensor  # noqa: F401
from torch._jit_internal import Optional, List  # noqa: F401

from .utils import _hide_packed_params_repr
from .utils import _quantize_weight

__all__ = ['EmbeddingPackedParams', 'Embedding', 'EmbeddingBag']


class EmbeddingPackedParams(torch.nn.Module):
    _version = 1

    def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8):
        super().__init__()
        self.dtype = dtype
        if self.dtype in [torch.quint8, torch.quint4x2]:
            scales = torch.ones(num_embeddings, dtype=torch.float)
            zero_points = torch.zeros(num_embeddings, dtype=torch.float)
            wq = torch._empty_per_channel_affine_quantized([num_embeddings, embedding_dim], scales=scales,
                                                           zero_points=zero_points,
                                                           axis=0, dtype=self.dtype)
            self.set_weight(wq)
        else:
            raise NotImplementedError(f'Unsupported dtype on quantized embedding! Supports quint8 and quint4x2. Got dtype: {dtype}')

    @torch.jit.export
    def set_weight(self, weight: torch.Tensor) -> None:
        if self.dtype in [torch.quint8, torch.quint4x2]:
            self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)
        else:
            raise NotImplementedError('Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2.')

    @torch.jit.export
    def _weight(self):
        if self.dtype in [torch.quint8, torch.quint4x2]:
            return torch.ops.quantized.embedding_bag_unpack(self._packed_weight)
        else:
            raise NotImplementedError('Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2.')

    def forward(self, x):
        return x

    # Version 1
    #   self
    #   |--- _packed_weight : Tensor representing weight of EmbeddingPackedParamsBase
    #   |--- dtype : torch.dtype

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'dtype'] = self.dtype
        destination[prefix + '_packed_weight'] = self._weight()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        self.dtype = state_dict[prefix + 'dtype']
        state_dict.pop(prefix + 'dtype')

        weight = state_dict[prefix + '_packed_weight']
        state_dict.pop(prefix + '_packed_weight')
        self.set_weight(weight)

        super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                      missing_keys, unexpected_keys, error_msgs)

    def __repr__(self):
        return self._weight().__repr__()
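
# Note: ``EmbeddingPackedParams`` keeps the weight in the prepacked format produced by
# ``torch.ops.quantized.embedding_bag_prepack`` and recovers the per-channel quantized
# tensor via ``torch.ops.quantized.embedding_bag_unpack`` (see ``set_weight`` / ``_weight``
# above). A rough round-trip sketch, for illustration only:
#
#     wq = torch._empty_per_channel_affine_quantized(
#         [10, 12], scales=torch.ones(10), zero_points=torch.zeros(10),
#         axis=0, dtype=torch.quint8)
#     packed = torch.ops.quantized.embedding_bag_prepack(wq)
#     unpacked = torch.ops.quantized.embedding_bag_unpack(packed)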


class Embedding(torch.nn.Module):
    r"""
    A quantized Embedding module with quantized packed weights as inputs.
    We adopt the same interface as `torch.nn.Embedding`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding for documentation.

    Similar to :class:`~torch.nn.Embedding`, attributes will be randomly
    initialized at module creation time and will be overwritten later

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.Embedding(num_embeddings=10, embedding_dim=12)
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8])
        >>> output = m(indices)
        >>> print(output.size())
        torch.Size([9, 12])

    """
    _version = 1

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 sparse: bool = False, _weight: Optional[Tensor] = None, dtype=torch.quint8) -> None:
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.dtype = dtype

        if _weight is None:
            scales = torch.ones(num_embeddings, dtype=torch.float)
            zero_points = torch.zeros(num_embeddings, dtype=torch.float)
            qweight = torch._empty_per_channel_affine_quantized([num_embeddings, embedding_dim],
                                                                scales=scales, zero_points=zero_points,
                                                                axis=0, dtype=torch.quint8)
        else:
            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                'Shape of weight does not match num_embeddings and embedding_dim'
            qweight = _weight

        self._packed_params = EmbeddingPackedParams(num_embeddings, embedding_dim, dtype)
        self._packed_params.set_weight(qweight)

    def forward(self, indices: Tensor) -> Tensor:
        if self.dtype == torch.quint4x2:
            return torch.ops.quantized.embedding_4bit(self._packed_params._packed_weight, indices)
        else:
            return torch.ops.quantized.embedding_byte(self._packed_params._packed_weight, indices)

    def _get_name(self):
        return 'QuantizedEmbedding'

    def __repr__(self):
        return _hide_packed_params_repr(self, EmbeddingPackedParams)

    def extra_repr(self):
        extra_repr_str = 'num_embeddings={}, embedding_dim={}, dtype={}, qscheme={}'.format(
            self.num_embeddings, self.embedding_dim, self._packed_params.dtype, self.weight().qscheme()
        )
        return extra_repr_str

    def set_weight(self, w: torch.Tensor) -> None:
        self._packed_params.set_weight(w)

    def weight(self):
        return self._packed_params._weight()

    @classmethod
    def from_float(cls, mod):
        r"""Create a quantized embedding module from a float module

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by user
        """
        if hasattr(mod, 'weight_fake_quant'):
            assert type(mod) == torch.ao.nn.qat.Embedding, 'nnq.' + cls.__name__ + '.from_float ' + \
                'with fake quant only works for ' + torch.ao.nn.qat.Embedding.__name__
            weight_observer = mod.weight_fake_quant
            activation_post_process = mod.activation_post_process
        else:
            assert type(mod) == nn.Embedding, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
                nn.Embedding.__name__
            assert hasattr(mod, 'qconfig'), 'Embedding input float module must have qconfig defined'
            from torch.ao.quantization import float_qparams_weight_only_qconfig
            if mod.qconfig is not None and mod.qconfig.weight is not None:  # type: ignore[union-attr]
                weight_observer = mod.qconfig.weight()  # type: ignore[union-attr, operator]
            else:
                weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype
        is_float_qparams_qconfig = weight_observer.qscheme == torch.per_channel_affine_float_qparams
        assert is_float_qparams_qconfig, \
            'Embedding quantization is only supported with float_qparams_weight_only_qconfig.'

        assert dtype == torch.quint8 or dtype == torch.quint4x2, \
            f'The only supported dtype for nnq.Embedding is torch.quint8 and torch.quint4x2, got {dtype}'

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Create quantized Embedding module and pass in the quantized weight
        qembedding = Embedding(mod.num_embeddings, mod.embedding_dim)
        qembedding.set_weight(qweight)
        return qembedding
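
    # A minimal usage sketch of the ``from_float`` path above (comments only, not executed):
    # the float ``nn.Embedding`` is expected to carry a weight-only qconfig, e.g.
    #
    #     from torch.ao.quantization import float_qparams_weight_only_qconfig
    #     float_emb = nn.Embedding(10, 12)
    #     float_emb.qconfig = float_qparams_weight_only_qconfig
    #     q_emb = Embedding.from_float(float_emb)
    #
    # The variable names here are illustrative only.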

    @classmethod
    def from_reference(cls, ref_embedding):
        qembedding = cls(
            ref_embedding.num_embeddings,
            ref_embedding.embedding_dim,
            ref_embedding.padding_idx,
            ref_embedding.max_norm,
            ref_embedding.norm_type,
            ref_embedding.scale_grad_by_freq,
            ref_embedding.sparse,
            ref_embedding.get_quantized_weight(),
            ref_embedding.weight_dtype,
        )
        return qembedding


class EmbeddingBag(Embedding):
    r"""
    A quantized EmbeddingBag module with quantized packed weights as inputs.
    We adopt the same interface as `torch.nn.EmbeddingBag`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.EmbeddingBag for documentation.

    Similar to :class:`~torch.nn.EmbeddingBag`, attributes will be randomly
    initialized at module creation time and will be overwritten later

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, mode='sum')
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
        >>> offsets = torch.tensor([0, 19, 20, 28, 28, 32])
        >>> output = m(indices, offsets)
        >>> print(output.size())
        torch.Size([5, 12])

    """
    _version = 1

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 mode: str = 'sum', sparse: bool = False, _weight: Optional[Tensor] = None,
                 include_last_offset: bool = False, dtype=torch.quint8) -> None:
        super().__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype)

        self.mode = mode
        self.pruned_weights = False
        self.include_last_offset = include_last_offset
        self.dtype = dtype

    def forward(self, indices: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None,
                compressed_indices_mapping: Optional[Tensor] = None) -> Tensor:
        if self.dtype == torch.quint4x2:
            return torch.ops.quantized.embedding_bag_4bit(self._packed_params._packed_weight, indices, offsets, False, 0,
                                                          self.pruned_weights, per_sample_weights, compressed_indices_mapping,
                                                          self.include_last_offset)
        else:
            return torch.ops.quantized.embedding_bag_byte(self._packed_params._packed_weight, indices, offsets, False, 0,
                                                          self.pruned_weights, per_sample_weights, compressed_indices_mapping,
                                                          self.include_last_offset)

    def _get_name(self):
        return 'QuantizedEmbeddingBag'

    @classmethod
    def from_float(cls, mod):
        r"""Create a quantized embedding_bag module from a float module

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by user
        """
        if hasattr(mod, 'weight_fake_quant'):
            weight_observer = mod.weight_fake_quant
        else:
            assert type(mod) == nn.EmbeddingBag, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
                nn.EmbeddingBag.__name__
            assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined'
            from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig
            if mod.qconfig is not None and mod.qconfig.weight is not None:  # type: ignore[union-attr]
                weight_observer = mod.qconfig.weight()  # type: ignore[union-attr, operator]
            else:
                weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype
        is_float_qparams_qconfig = weight_observer.qscheme == torch.per_channel_affine_float_qparams
        assert is_float_qparams_qconfig, \
            'EmbeddingBag quantization is only supported with float_qparams_weight_only_qconfig.'

        assert dtype == torch.quint8 or dtype == torch.quint4x2, \
            f'The only supported dtype for nnq.EmbeddingBag is torch.quint8 and torch.quint4x2, got {dtype}'

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Create quantized EmbeddingBag module and pass in the quantized weight
        qembedding_bag = EmbeddingBag(mod.num_embeddings, mod.embedding_dim, dtype=dtype)
        qembedding_bag.set_weight(qweight)
        return qembedding_bag

    @classmethod
    def from_reference(cls, ref_embedding_bag):
        qembedding_bag = cls(
            ref_embedding_bag.num_embeddings,
            ref_embedding_bag.embedding_dim,
            ref_embedding_bag.max_norm,
            ref_embedding_bag.norm_type,
            ref_embedding_bag.scale_grad_by_freq,
            ref_embedding_bag.mode,
            ref_embedding_bag.sparse,
            ref_embedding_bag.get_quantized_weight(),
            ref_embedding_bag.include_last_offset,
            ref_embedding_bag.weight_dtype,
        )
        return qembedding_bag
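

# A minimal, hedged usage sketch of the eager-mode flow these modules support: attach the
# weight-only qconfig to a float ``nn.EmbeddingBag`` and convert it with ``from_float``.
# Variable names below are illustrative only; the qconfig import mirrors the one used
# inside ``from_float`` above.
if __name__ == "__main__":
    from torch.ao.quantization import float_qparams_weight_only_qconfig

    # Float module with the weight-only qconfig attached.
    float_embedding_bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum')
    float_embedding_bag.qconfig = float_qparams_weight_only_qconfig

    # Observe the weight and produce the quantized module.
    quantized_embedding_bag = EmbeddingBag.from_float(float_embedding_bag)

    indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8])
    offsets = torch.tensor([0, 4])
    print(quantized_embedding_bag(indices, offsets).size())  # expected: torch.Size([2, 12])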