import torch
import torch.jit  # this is needed to avoid a circular import
from torch import nn
import torch.nn.functional as nnF
from torch import Tensor
from typing import Optional, Tuple

import warnings

__all__ = [
    "MultiheadAttention"
]

class MultiheadAttention(nn.MultiheadAttention):
    _FLOAT_MODULE = nn.MultiheadAttention

    r"""Quantizable implementation of the MultiheadAttention.

    Note::
        Please refer to :class:`~torch.nn.MultiheadAttention` for more
        information.

    Allows the model to jointly attend to information from different
    representation subspaces.
    See reference: Attention Is All You Need

    The original MHA module is not quantizable.
    This reimplements it by explicitly instantiating the linear layers.

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h)W^O

        \text{where } head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
    to :attr:`embed_dim` such that query, key, and value have the same
    number of features.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    Note::
        Please follow the quantization flow to convert the quantizable MHA.
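
        A minimal sketch of one possible eager-mode flow is shown below
        (``float_model`` and the choice of qconfig are illustrative, not part
        of this module's API; refer to the eager-mode quantization
        documentation for the full custom-module configuration)::

            >>> float_model.qconfig = torch.ao.quantization.default_qconfig
            >>> prepared = torch.ao.quantization.prepare(float_model)
            >>> # ... run calibration data through ``prepared`` ...
            >>> quantized = torch.ao.quantization.convert(prepared)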
- """
    __constants__ = ['batch_first']

    def __init__(self, embed_dim: int, num_heads: int,
                 dropout: float = 0., bias: bool = True,
                 add_bias_kv: bool = False, add_zero_attn: bool = False,
                 kdim: Optional[int] = None, vdim: Optional[int] = None,
                 batch_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(embed_dim, num_heads, dropout,
                         bias, add_bias_kv,
                         add_zero_attn, kdim, vdim, batch_first,
                         **factory_kwargs)
        self.linear_Q = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)
        self.linear_K = nn.Linear(self.kdim, self.embed_dim, bias=bias, **factory_kwargs)
        self.linear_V = nn.Linear(self.vdim, self.embed_dim, bias=bias, **factory_kwargs)
        # for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)  # type: ignore[assignment]

        # Functionals
        self.q_scaling_product = torch.ao.nn.quantized.FloatFunctional()
        # note: importing torch.ao.nn.quantized at top creates a circular import

        # Quant/Dequant
        self.quant_attn_output = torch.ao.quantization.QuantStub()
        self.quant_attn_output_weights = torch.ao.quantization.QuantStub()
        self.dequant_q = torch.ao.quantization.DeQuantStub()
        self.dequant_k = torch.ao.quantization.DeQuantStub()
        self.dequant_v = torch.ao.quantization.DeQuantStub()

    def _get_name(self):
        return 'QuantizableMultiheadAttention'

    @classmethod
    def from_float(cls, other):
        assert type(other) == cls._FLOAT_MODULE
        assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
        # Create the observed module with the same hyper-parameters as the float one.
        observed = cls(other.embed_dim, other.num_heads, other.dropout,
                       (other.in_proj_bias is not None),
                       (other.bias_k is not None),
                       other.add_zero_attn, other.kdim, other.vdim,
                       other.batch_first)
        observed.bias_k = other.bias_k
        observed.bias_v = other.bias_v
        observed.qconfig = other.qconfig

        # Set the linear weights
        # for the type: ignores, see https://github.com/pytorch/pytorch/issues/58969
        observed.out_proj.weight = other.out_proj.weight  # type: ignore[has-type]
        observed.out_proj.bias = other.out_proj.bias  # type: ignore[has-type]
        if other._qkv_same_embed_dim:
            # The float module packs the Q/K/V projections into a single
            # ``in_proj_weight``; split it into the separate linear_{Q,K,V}
            # modules used here.
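            # For reference, the packed ``in_proj_weight`` has shape
            # ``(3 * embed_dim, embed_dim)``; e.g. with ``embed_dim = 512``:
            #   rows [0, 512)     -> query projection
            #   rows [512, 1024)  -> key projection
            #   rows [1024, 1536) -> value projection
            # The slices below peel these off one at a time.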
            bias = other.in_proj_bias
            _start = 0
            _end = _start + other.embed_dim
            weight = other.in_proj_weight[_start:_end, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
            observed.linear_Q.weight = torch.nn.Parameter(weight,
                                                          weight.requires_grad)
            observed.linear_Q.bias = bias

            bias = other.in_proj_bias
            _start = _end
            _end = _start + other.embed_dim
            weight = other.in_proj_weight[_start:_end, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
            observed.linear_K.weight = torch.nn.Parameter(weight,
                                                          weight.requires_grad)
            observed.linear_K.bias = bias

            bias = other.in_proj_bias
            _start = _end
            weight = other.in_proj_weight[_start:, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:], bias.requires_grad)
            observed.linear_V.weight = torch.nn.Parameter(weight,
                                                          weight.requires_grad)
            observed.linear_V.bias = bias
        else:
            observed.linear_Q.weight = nn.Parameter(other.q_proj_weight)
            observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
            observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
            if other.in_proj_bias is None:
                observed.linear_Q.bias = None  # type: ignore[assignment]
                observed.linear_K.bias = None  # type: ignore[assignment]
                observed.linear_V.bias = None  # type: ignore[assignment]
            else:
                observed.linear_Q.bias = nn.Parameter(other.in_proj_bias[0:other.embed_dim])
                observed.linear_K.bias = nn.Parameter(other.in_proj_bias[other.embed_dim:(other.embed_dim * 2)])
                observed.linear_V.bias = nn.Parameter(other.in_proj_bias[(other.embed_dim * 2):])
        observed.eval()
        # Explicit prepare
        observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed

    @torch.jit.unused
    def dequantize(self):
        r"""Utility to convert the quantized MHA back to float.

        The motivation for this is that it is not trivial to convert the
        weights from the format that is used in the quantized version back
        to the float one.
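
        Example (a sketch; ``quantized_mha`` stands for an already converted,
        quantized instance of this module)::

            >>> float_mha = quantized_mha.dequantize()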
- """
        fp = self._FLOAT_MODULE(self.embed_dim, self.num_heads, self.dropout,
                                (self.in_proj_bias is not None),
                                (self.bias_k is not None),
                                self.add_zero_attn, self.kdim, self.vdim, self.batch_first)
        assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim
        if self.bias_k is not None:
            fp.bias_k = nn.Parameter(self.bias_k.dequantize())
        if self.bias_v is not None:
            fp.bias_v = nn.Parameter(self.bias_v.dequantize())

        # Set the linear weights
        # Note: Because the linear layers are quantized, mypy does not know how
        # to deal with them -- might need to ignore the typing checks.
        # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
        w, b = self.out_proj._weight_bias()  # type: ignore[operator, has-type]
        fp.out_proj.weight = nn.Parameter(w.dequantize())
        if b is not None:
            fp.out_proj.bias = nn.Parameter(b)

        wQ, bQ = self.linear_Q._weight_bias()  # type: ignore[operator]
        wQ = wQ.dequantize()
        wK, bK = self.linear_K._weight_bias()  # type: ignore[operator]
        wK = wK.dequantize()
        wV, bV = self.linear_V._weight_bias()  # type: ignore[operator]
        wV = wV.dequantize()
        if fp._qkv_same_embed_dim:
            # Pack the separate Q/K/V projections back into the single
            # ``in_proj_weight`` used by the float module.
            _start = 0
            _end = _start + fp.embed_dim
            fp.in_proj_weight[_start:_end, :] = wQ
            if fp.in_proj_bias is not None:
                assert all(bQ == 0)
                fp.in_proj_bias[_start:_end] = bQ

            _start = _end
            _end = _start + fp.embed_dim
            fp.in_proj_weight[_start:_end, :] = wK
            if fp.in_proj_bias is not None:
                assert all(bK == 0)
                fp.in_proj_bias[_start:_end] = bK

            _start = _end
            fp.in_proj_weight[_start:, :] = wV
            if fp.in_proj_bias is not None:
                assert all(bV == 0)
                fp.in_proj_bias[_start:] = bV
        else:
            fp.q_proj_weight = nn.Parameter(wQ)
            fp.k_proj_weight = nn.Parameter(wK)
            fp.v_proj_weight = nn.Parameter(wV)
            if fp.in_proj_bias is None:
                self.linear_Q.bias = None
                self.linear_K.bias = None
                self.linear_V.bias = None
            else:
                fp.in_proj_bias[0:fp.embed_dim] = bQ
                fp.in_proj_bias[fp.embed_dim:(fp.embed_dim * 2)] = bK
                fp.in_proj_bias[(fp.embed_dim * 2):] = bV

        return fp

    @classmethod
    def from_observed(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does float -> observed only
        # See nn.quantized.MultiheadAttention
        raise NotImplementedError("It looks like you are trying to prepare an "
                                  "MHA module. Please see "
                                  "the examples on quantizable MHAs.")

    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True,
                attn_mask: Optional[Tensor] = None,
                average_attn_weights: bool = True,
                is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Note::
            Please refer to :func:`~torch.nn.MultiheadAttention.forward` for more
            information.

        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be
                broadcast across all batches, while a 3D mask allows specifying a different mask for
                each entry in the batch.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a BoolTensor is provided, the positions with the value of ``True`` will be ignored,
              while the positions with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensures that position i is allowed to attend to the
              unmasked positions. If a BoolTensor is provided, positions with ``True`` are not allowed to attend,
              while ``False`` values will be unchanged. If a FloatTensor is provided, it will be added to the
              attention weight.
            - is_causal: If specified, applies a causal mask as the attention mask. Mutually exclusive with
              providing ``attn_mask``. Default: ``False``.
            - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged
              across heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag
              only has an effect when ``need_weights=True``. Default: True (i.e. average weights across heads)

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged
              across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length,
              S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(N, num_heads, L, S)`.
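
        Example (an illustrative sketch; ``N`` and ``S`` are the batch size and
        source sequence length from the table above)::

            >>> # Boolean key_padding_mask that ignores the last source position
            >>> key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
            >>> key_padding_mask[:, -1] = True
            >>> attn_output, attn_weights = multihead_attn(query, key, value,
            ...                                            key_padding_mask=key_padding_mask)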
- """
- return self._forward_impl(query, key, value, key_padding_mask,
- need_weights, attn_mask, average_attn_weights,
- is_causal)

    def _forward_impl(self,
                      query: Tensor,
                      key: Tensor,
                      value: Tensor,
                      key_padding_mask: Optional[Tensor] = None,
                      need_weights: bool = True,
                      attn_mask: Optional[Tensor] = None,
                      average_attn_weights: bool = True,
                      is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
        # This version will not deal with the static key/value pairs.
        # Keeping it here for future changes.
        #
        # TODO: This method shares some duplicate lines with
        # `torch.nn.functional.multi_head_attention_forward`. Will need to refactor.
        static_k = None
        static_v = None

        if attn_mask is not None and is_causal:
            raise AssertionError("Only allow causal mask or attn_mask")

        if is_causal:
            raise AssertionError("causal mask not supported by AO MHA module")

        if self.batch_first:
            query, key, value = [x.transpose(0, 1) for x in (query, key, value)]

        tgt_len, bsz, embed_dim_to_check = query.size()
        assert self.embed_dim == embed_dim_to_check
        # allow MHA to have different sizes for the feature dimension
        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

        head_dim = self.embed_dim // self.num_heads
        assert head_dim * self.num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        scaling = float(head_dim) ** -0.5
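        # Worked example: embed_dim = 512 with num_heads = 8 gives head_dim = 64,
        # so scaling = 64 ** -0.5 = 0.125, i.e. the 1 / sqrt(d_k) factor from
        # "Attention Is All You Need".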

        q = self.linear_Q(query)
        k = self.linear_K(key)
        v = self.linear_V(value)

        q = self.q_scaling_product.mul_scalar(q, scaling)

        if attn_mask is not None:
            if attn_mask.dtype == torch.uint8:
                warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
                attn_mask = attn_mask.to(torch.bool)
            assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
                'Only float and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)

            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)
                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                    raise RuntimeError('The size of the 2D attn_mask is not correct.')
            elif attn_mask.dim() == 3:
                if list(attn_mask.size()) != [bsz * self.num_heads, query.size(0), key.size(0)]:
                    raise RuntimeError('The size of the 3D attn_mask is not correct.')
            else:
                raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
            # attn_mask's dim is 3 now.

        # convert ByteTensor key_padding_mask to bool
        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            key_padding_mask = key_padding_mask.to(torch.bool)

        if self.bias_k is not None and self.bias_v is not None:
            if static_k is None and static_v is None:
                # Explicitly assert that bias_k and bias_v are not None
                # in a way that TorchScript can understand.
                bias_k = self.bias_k
                assert bias_k is not None
                bias_v = self.bias_v
                assert bias_v is not None

                k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
                v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
                if attn_mask is not None:
                    attn_mask = nnF.pad(attn_mask, (0, 1))
                if key_padding_mask is not None:
                    key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
            else:
                assert static_k is None, "bias cannot be added to static key."
                assert static_v is None, "bias cannot be added to static value."
        else:
            assert self.bias_k is None
            assert self.bias_v is None

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
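        # Shape walkthrough (sequence-first layout, ignoring the optional
        # bias_k / bias_v row appended above): q is reshaped from
        # (tgt_len, bsz, embed_dim) to (bsz * num_heads, tgt_len, head_dim);
        # k and v go from (src_len, bsz, embed_dim) to
        # (bsz * num_heads, src_len, head_dim).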

        if static_k is not None:
            assert static_k.size(0) == bsz * self.num_heads
            assert static_k.size(2) == head_dim
            k = static_k

        if static_v is not None:
            assert static_v.size(0) == bsz * self.num_heads
            assert static_v.size(2) == head_dim
            v = static_v

        src_len = k.size(1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
            if k.is_quantized:
                k_zeros = torch.quantize_per_tensor(k_zeros, k.q_scale(), k.q_zero_point(), k.dtype)
            k = torch.cat([k, k_zeros], dim=1)
            v_zeros = torch.zeros((v.size(0), 1) + v.size()[2:])
            if v.is_quantized:
                v_zeros = torch.quantize_per_tensor(v_zeros, v.q_scale(), v.q_zero_point(), v.dtype)
            v = torch.cat([v, v_zeros], dim=1)

            if attn_mask is not None:
                attn_mask = nnF.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = nnF.pad(key_padding_mask, (0, 1))

        # Leaving the quantized zone here
        q = self.dequant_q(q)
        k = self.dequant_k(k)
        v = self.dequant_v(v)
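        # From here until the re-quantization below (quant_attn_output and
        # quant_attn_output_weights), the attention core (bmm, masking,
        # softmax, dropout) operates on dequantized float tensors.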
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_output_weights.masked_fill_(attn_mask, float('-inf'))
            else:
                attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_output_weights = nnF.softmax(
            attn_output_weights, dim=-1)
        attn_output_weights = nnF.dropout(attn_output_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
        if self.batch_first:
            attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
        else:
            attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)

        # Reentering the quantized zone
        attn_output = self.quant_attn_output(attn_output)
        # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
        attn_output = self.out_proj(attn_output)  # type: ignore[has-type]
        attn_output_weights = self.quant_attn_output_weights(attn_output_weights)

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if average_attn_weights:
                attn_output_weights = attn_output_weights.mean(dim=1)
            return attn_output, attn_output_weights
        else:
            return attn_output, None