import torch
import torch.jit  # this is needed to avoid a circular import
from torch import nn
import torch.nn.functional as nnF
from torch import Tensor
from typing import Optional, Tuple
import warnings

__all__ = [
    "MultiheadAttention"
]


class MultiheadAttention(nn.MultiheadAttention):
    _FLOAT_MODULE = nn.MultiheadAttention

    r"""Quantizable implementation of the MultiheadAttention.

    Note::
        Please, refer to :class:`~torch.nn.MultiheadAttention` for more
        information

    Allows the model to jointly attend to information from different
    representation subspaces.
    See reference: Attention Is All You Need

    The original MHA module is not quantizable.
    This reimplements it by explicitly instantiating the linear layers.

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where } head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
    to :attr:`embed_dim` such that query, key, and value have the same
    number of features.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    Note::
        Please, follow the quantization flow to convert the quantizable MHA.
    """
    __constants__ = ['batch_first']

    def __init__(self, embed_dim: int, num_heads: int,
                 dropout: float = 0., bias: bool = True,
                 add_bias_kv: bool = False, add_zero_attn: bool = False,
                 kdim: Optional[int] = None, vdim: Optional[int] = None,
                 batch_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(embed_dim, num_heads, dropout,
                         bias, add_bias_kv,
                         add_zero_attn, kdim, vdim, batch_first,
                         **factory_kwargs)
        self.linear_Q = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)
        self.linear_K = nn.Linear(self.kdim, self.embed_dim, bias=bias, **factory_kwargs)
        self.linear_V = nn.Linear(self.vdim, self.embed_dim, bias=bias, **factory_kwargs)
        # for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)  # type: ignore[assignment]

        # Functionals
        self.q_scaling_product = torch.ao.nn.quantized.FloatFunctional()
        # note: importing torch.ao.nn.quantized at top creates a circular import

        # Quant/Dequant
        self.quant_attn_output = torch.ao.quantization.QuantStub()
        self.quant_attn_output_weights = torch.ao.quantization.QuantStub()
        self.dequant_q = torch.ao.quantization.DeQuantStub()
        self.dequant_k = torch.ao.quantization.DeQuantStub()
        self.dequant_v = torch.ao.quantization.DeQuantStub()

    def _get_name(self):
        return 'QuantizableMultiheadAttention'

    @classmethod
    def from_float(cls, other):
        assert type(other) == cls._FLOAT_MODULE
        assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
        # Setting the dropout to 0.0!
        observed = cls(other.embed_dim, other.num_heads, other.dropout,
                       (other.in_proj_bias is not None),
                       (other.bias_k is not None),
                       other.add_zero_attn, other.kdim, other.vdim,
                       other.batch_first)
        observed.bias_k = other.bias_k
        observed.bias_v = other.bias_v
        observed.qconfig = other.qconfig

        # Set the linear weights
        # for the type: ignores, see https://github.com/pytorch/pytorch/issues/58969
        observed.out_proj.weight = other.out_proj.weight  # type: ignore[has-type]
        observed.out_proj.bias = other.out_proj.bias  # type: ignore[has-type]
        if other._qkv_same_embed_dim:
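            # Note: the float nn.MultiheadAttention packs its in-projection as a single
            # ``in_proj_weight`` of shape (3 * embed_dim, embed_dim), with the Q, K and V
            # blocks stacked along dim 0 (``in_proj_bias`` follows the same layout).
            # The slices below copy each block into its own Linear module.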
            # Use separate params
            bias = other.in_proj_bias
            _start = 0
            _end = _start + other.embed_dim
            weight = other.in_proj_weight[_start:_end, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
            observed.linear_Q.weight = torch.nn.Parameter(weight,
                                                          weight.requires_grad)
            observed.linear_Q.bias = bias

            bias = other.in_proj_bias
            _start = _end
            _end = _start + other.embed_dim
            weight = other.in_proj_weight[_start:_end, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
            observed.linear_K.weight = torch.nn.Parameter(weight,
                                                          weight.requires_grad)
            observed.linear_K.bias = bias

            bias = other.in_proj_bias
            _start = _end
            weight = other.in_proj_weight[_start:, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:], bias.requires_grad)
            observed.linear_V.weight = torch.nn.Parameter(weight,
                                                          weight.requires_grad)
            observed.linear_V.bias = bias
        else:
            observed.linear_Q.weight = nn.Parameter(other.q_proj_weight)
            observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
            observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
            if other.in_proj_bias is None:
                observed.linear_Q.bias = None  # type: ignore[assignment]
                observed.linear_K.bias = None  # type: ignore[assignment]
                observed.linear_V.bias = None  # type: ignore[assignment]
            else:
                observed.linear_Q.bias = nn.Parameter(other.in_proj_bias[0:other.embed_dim])
                observed.linear_K.bias = nn.Parameter(other.in_proj_bias[other.embed_dim:(other.embed_dim * 2)])
                observed.linear_V.bias = nn.Parameter(other.in_proj_bias[(other.embed_dim * 2):])
        observed.eval()
        # Explicit prepare
        observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed

    @torch.jit.unused
    def dequantize(self):
        r"""Utility to convert the quantized MHA back to float.

        The motivation for this is that it is not trivial to convert the weights
        from the format that is used in the quantized version back to the
        float.
        """
        fp = self._FLOAT_MODULE(self.embed_dim, self.num_heads, self.dropout,
                                (self.in_proj_bias is not None),
                                (self.bias_k is not None),
                                self.add_zero_attn, self.kdim, self.vdim, self.batch_first)
        assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim
        if self.bias_k is not None:
            fp.bias_k = nn.Parameter(self.bias_k.dequantize())
        if self.bias_v is not None:
            fp.bias_v = nn.Parameter(self.bias_v.dequantize())

        # Set the linear weights
        # Note: Because the linear layers are quantized, mypy does not know how
        # to deal with them -- might need to ignore the typing checks.
        # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
        w, b = self.out_proj._weight_bias()  # type: ignore[operator, has-type]
        fp.out_proj.weight = nn.Parameter(w.dequantize())
        if b is not None:
            fp.out_proj.bias = nn.Parameter(b)

        wQ, bQ = self.linear_Q._weight_bias()  # type: ignore[operator]
        wQ = wQ.dequantize()
        wK, bK = self.linear_K._weight_bias()  # type: ignore[operator]
        wK = wK.dequantize()
        wV, bV = self.linear_V._weight_bias()  # type: ignore[operator]
        wV = wV.dequantize()
        if fp._qkv_same_embed_dim:
            # Use separate params
            _start = 0
            _end = _start + fp.embed_dim
            fp.in_proj_weight[_start:_end, :] = wQ
            if fp.in_proj_bias is not None:
                assert all(bQ == 0)
                fp.in_proj_bias[_start:_end] = bQ

            _start = _end
            _end = _start + fp.embed_dim
            fp.in_proj_weight[_start:_end, :] = wK
            if fp.in_proj_bias is not None:
                assert all(bK == 0)
                fp.in_proj_bias[_start:_end] = bK

            _start = _end
            fp.in_proj_weight[_start:, :] = wV
            if fp.in_proj_bias is not None:
                assert all(bV == 0)
                fp.in_proj_bias[_start:] = bV
        else:
            fp.q_proj_weight = nn.Parameter(wQ)
            fp.k_proj_weight = nn.Parameter(wK)
            fp.v_proj_weight = nn.Parameter(wV)
            if fp.in_proj_bias is None:
                self.linear_Q.bias = None
                self.linear_K.bias = None
                self.linear_V.bias = None
            else:
                fp.in_proj_bias[0:fp.embed_dim] = bQ
                fp.in_proj_bias[fp.embed_dim:(fp.embed_dim * 2)] = bK
                fp.in_proj_bias[(fp.embed_dim * 2):] = bV
        return fp
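
    # Example (illustrative): ``dequantize()`` is typically called on the converted
    # (quantized) module, e.g. ``fp_mha = quantized_mha.dequantize()``, and returns
    # a float ``nn.MultiheadAttention`` carrying the dequantized weights.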

    @classmethod
    def from_observed(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does float -> observed only
        # See nn.quantized.MultiheadAttention
        raise NotImplementedError("It looks like you are trying to prepare an "
                                  "MHA module. Please, see "
                                  "the examples on quantizable MHAs.")

    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True,
                attn_mask: Optional[Tensor] = None,
                average_attn_weights: bool = True,
                is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Note::
            Please, refer to :func:`~torch.nn.MultiheadAttention.forward` for more
            information

        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast
                across all the batches while a 3D mask allows specifying a different mask for each entry in the batch.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a BoolTensor is provided, the positions with the value of ``True`` will be ignored while the
              positions with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
              positions. If a BoolTensor is provided, positions with ``True``
              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.
            - is_causal: If specified, applies a causal mask as the attention mask. Mutually exclusive with
              providing attn_mask. Default: ``False``.
            - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
              heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
              effect when ``need_weights=True``. Default: True (i.e. average weights across heads)

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged
              across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length,
              S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(N, num_heads, L, S)`.
        """
        return self._forward_impl(query, key, value, key_padding_mask,
                                  need_weights, attn_mask, average_attn_weights,
                                  is_causal)

    def _forward_impl(self,
                      query: Tensor,
                      key: Tensor,
                      value: Tensor,
                      key_padding_mask: Optional[Tensor] = None,
                      need_weights: bool = True,
                      attn_mask: Optional[Tensor] = None,
                      average_attn_weights: bool = True,
                      is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
        # This version will not deal with the static key/value pairs.
        # Keeping it here for future changes.
        #
        # TODO: This method has some duplicate lines with the
        # `torch.nn.functional.multi_head_attention`. Will need to refactor.
        static_k = None
        static_v = None

        if attn_mask is not None and is_causal:
            raise AssertionError("Only allow causal mask or attn_mask")

        if is_causal:
            raise AssertionError("causal mask not supported by AO MHA module")

        if self.batch_first:
            query, key, value = [x.transpose(0, 1) for x in (query, key, value)]

        tgt_len, bsz, embed_dim_to_check = query.size()
        assert self.embed_dim == embed_dim_to_check
        # allow MHA to have different sizes for the feature dimension
        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

        head_dim = self.embed_dim // self.num_heads
        assert head_dim * self.num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        scaling = float(head_dim) ** -0.5

        q = self.linear_Q(query)
        k = self.linear_K(key)
        v = self.linear_V(value)

        q = self.q_scaling_product.mul_scalar(q, scaling)
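        # The 1/sqrt(head_dim) scaling goes through FloatFunctional (rather than a bare
        # ``q * scaling``) so that the multiply is exposed as a module and can be observed
        # and swapped for its quantized counterpart during the prepare/convert flow.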

        if attn_mask is not None:
            if attn_mask.dtype == torch.uint8:
                warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
                attn_mask = attn_mask.to(torch.bool)
            assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
                'Only float and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)

            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)
                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                    raise RuntimeError('The size of the 2D attn_mask is not correct.')
            elif attn_mask.dim() == 3:
                if list(attn_mask.size()) != [bsz * self.num_heads, query.size(0), key.size(0)]:
                    raise RuntimeError('The size of the 3D attn_mask is not correct.')
            else:
                raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
            # attn_mask's dim is 3 now.

        # convert ByteTensor key_padding_mask to bool
        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            key_padding_mask = key_padding_mask.to(torch.bool)

        if self.bias_k is not None and self.bias_v is not None:
            if static_k is None and static_v is None:
                # Explicitly assert that bias_k and bias_v are not None
                # in a way that TorchScript can understand.
                bias_k = self.bias_k
                assert bias_k is not None
                bias_v = self.bias_v
                assert bias_v is not None

                k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
                v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
                if attn_mask is not None:
                    attn_mask = nnF.pad(attn_mask, (0, 1))
                if key_padding_mask is not None:
                    key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
            else:
                assert static_k is None, "bias cannot be added to static key."
                assert static_v is None, "bias cannot be added to static value."
        else:
            assert self.bias_k is None
            assert self.bias_v is None

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)

        if static_k is not None:
            assert static_k.size(0) == bsz * self.num_heads
            assert static_k.size(2) == head_dim
            k = static_k

        if static_v is not None:
            assert static_v.size(0) == bsz * self.num_heads
            assert static_v.size(2) == head_dim
            v = static_v

        src_len = k.size(1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
            if k.is_quantized:
                k_zeros = torch.quantize_per_tensor(k_zeros, k.q_scale(), k.q_zero_point(), k.dtype)
            k = torch.cat([k, k_zeros], dim=1)
            v_zeros = torch.zeros((v.size(0), 1) + v.size()[2:])
            if v.is_quantized:
                v_zeros = torch.quantize_per_tensor(v_zeros, v.q_scale(), v.q_zero_point(), v.dtype)
            v = torch.cat([v, v_zeros], dim=1)

            if attn_mask is not None:
                attn_mask = nnF.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = nnF.pad(key_padding_mask, (0, 1))

        # Leaving the quantized zone here
        q = self.dequant_q(q)
        k = self.dequant_k(k)
        v = self.dequant_v(v)
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_output_weights.masked_fill_(attn_mask, float('-inf'))
            else:
                attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_output_weights = nnF.softmax(
            attn_output_weights, dim=-1)
        attn_output_weights = nnF.dropout(attn_output_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
        if self.batch_first:
            attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
        else:
            attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)

        # Reentering the quantized zone
        attn_output = self.quant_attn_output(attn_output)
        # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
        attn_output = self.out_proj(attn_output)  # type: ignore[has-type]
        attn_output_weights = self.quant_attn_output_weights(attn_output_weights)

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if average_attn_weights:
                attn_output_weights = attn_output_weights.mean(dim=1)
            return attn_output, attn_output_weights
        else:
            return attn_output, None
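

# A minimal sketch of the quantization flow mentioned in the class docstring
# (float -> observed -> quantized). It is illustrative only: it assumes the default
# eager-mode custom-module mappings pair ``nn.MultiheadAttention`` with this class
# during ``prepare`` and with ``torch.ao.nn.quantized.MultiheadAttention`` during
# ``convert``, and that the active quantized engine supports ``default_qconfig``.
if __name__ == "__main__":
    class _SelfAttentionDemo(nn.Module):  # hypothetical wrapper, not part of this module
        def __init__(self, embed_dim: int, num_heads: int) -> None:
            super().__init__()
            self.quant = torch.ao.quantization.QuantStub()
            self.mha = nn.MultiheadAttention(embed_dim, num_heads)
            self.dequant = torch.ao.quantization.DeQuantStub()

        def forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
            x = self.quant(x)
            out, weights = self.mha(x, x, x)
            return self.dequant(out), weights

    model = _SelfAttentionDemo(embed_dim=8, num_heads=2).eval()
    model.qconfig = torch.ao.quantization.default_qconfig
    # float -> observed: nn.MultiheadAttention is swapped for the quantizable class above.
    prepared = torch.ao.quantization.prepare(model)
    # Calibrate the observers with representative data ((seq, batch, feature) layout).
    prepared(torch.randn(4, 1, 8))
    # observed -> quantized
    quantized = torch.ao.quantization.convert(prepared)
    attn_out, attn_weights = quantized(torch.randn(4, 1, 8))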