import copy
from typing import Optional, Any, Union, Callable

import torch
from torch import Tensor
from .. import functional as F
from .module import Module
from .activation import MultiheadAttention
from .container import ModuleList
from ..init import xavier_uniform_
from .dropout import Dropout
from .linear import Linear
from .normalization import LayerNorm

__all__ = ['Transformer', 'TransformerEncoder', 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer']

class Transformer(Module):
    r"""A transformer model. Users can modify the attributes as needed. The architecture
    is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010.

    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multiheadattention models (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of encoder/decoder intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        custom_encoder: custom encoder (default=None).
        custom_decoder: custom decoder (default=None).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
            other attention and feedforward operations, otherwise after. Default: ``False`` (after).

    Examples::
        >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
        >>> src = torch.rand((10, 32, 512))
        >>> tgt = torch.rand((20, 32, 512))
        >>> out = transformer_model(src, tgt)

    Note: A full example to apply nn.Transformer module for the word language model is available in
    https://github.com/pytorch/examples/tree/master/word_language_model
    """

    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")

        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first, norm_first,
                                                    **factory_kwargs)
            encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first, norm_first,
                                                    **factory_kwargs)
            decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead
        self.batch_first = batch_first

    def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Take in and process masked source/target sequences.

        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
            src_mask: the additive mask for the src sequence (optional).
            tgt_mask: the additive mask for the tgt sequence (optional).
            memory_mask: the additive mask for the encoder output (optional).
            src_key_padding_mask: the Tensor mask for src keys per batch (optional).
            tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional).
            memory_key_padding_mask: the Tensor mask for memory keys per batch (optional).

        Shape:
            - src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or
              :math:`(N, S, E)` if `batch_first=True`.
            - tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
              :math:`(N, T, E)` if `batch_first=True`.
            - src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`.
            - tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`.
            - memory_mask: :math:`(T, S)`.
            - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
            - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
            - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.

            Note: [src/tgt/memory]_mask ensures that each position is only allowed to attend the
            unmasked positions. If a BoolTensor is provided, positions with ``True``
            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
            is provided, it will be added to the attention weight.
            [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
            the attention. If a BoolTensor is provided, the positions with the
            value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.

            - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
              :math:`(N, T, E)` if `batch_first=True`.

            Note: Due to the multi-head attention architecture in the transformer model,
            the output sequence length of a transformer is the same as the input sequence
            (i.e. target) length of the decoder.

            where S is the source sequence length, T is the target sequence length, N is the
            batch size, E is the feature number

        Examples:
            >>> # xdoctest: +SKIP
            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        """
        is_batched = src.dim() == 3
        if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
            raise RuntimeError("the batch number of src and tgt must be equal")
        elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
            raise RuntimeError("the batch number of src and tgt must be equal")

        if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
            raise RuntimeError("the feature number of src and tgt must be equal to d_model")

        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)
        return output

    @staticmethod
    def generate_square_subsequent_mask(sz: int, device='cpu') -> Tensor:
        r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
        Unmasked positions are filled with float(0.0).
        """
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)

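# A minimal usage sketch, assuming ``from torch import nn``: pairing
# ``generate_square_subsequent_mask`` with a ``Transformer`` forward pass so that each target
# position only attends to earlier target positions. Shapes follow the default seq-first
# layout (S, N, E) / (T, N, E).
#
#   >>> model = nn.Transformer(d_model=512, nhead=8)
#   >>> src = torch.rand(10, 32, 512)                                           # (S, N, E)
#   >>> tgt = torch.rand(20, 32, 512)                                           # (T, N, E)
#   >>> tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0))  # (T, T), -inf above the diagonal
#   >>> out = model(src, tgt, tgt_mask=tgt_mask)                                # (T, N, E)
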
class TransformerEncoder(Module):
    r"""TransformerEncoder is a stack of N encoder layers. Users can build the
    BERT (https://arxiv.org/abs/1810.04805) model with corresponding parameters.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
        enable_nested_tensor: if True, input will automatically convert to nested tensor
            (and convert back on output). This will improve the overall performance of
            TransformerEncoder when padding rate is high. Default: ``True`` (enabled).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True):
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.enable_nested_tensor = enable_nested_tensor
        self.mask_check = mask_check

    def forward(
            self,
            src: Tensor,
            mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            is_causal: Optional[bool] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            is_causal: If specified, applies a causal mask (optional)
                and ignores ``mask`` for computing scaled dot product attention.
                Default: ``False``.
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(mask),
            other_name="mask",
            target_type=src.dtype
        )

        output = src
        convert_to_nested = False
        first_layer = self.layers[0]
        src_key_padding_mask_for_layers = src_key_padding_mask
        why_not_sparsity_fast_path = ''
        str_first_layer = "self.layers[0]"
        if not isinstance(first_layer, torch.nn.TransformerEncoderLayer):
            why_not_sparsity_fast_path = f"{str_first_layer} was not TransformerEncoderLayer"
        elif first_layer.norm_first:
            why_not_sparsity_fast_path = f"{str_first_layer}.norm_first was True"
        elif first_layer.training:
            why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
        elif not first_layer.self_attn.batch_first:
            why_not_sparsity_fast_path = f"{str_first_layer}.self_attn.batch_first was not True"
        elif not first_layer.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = f"{str_first_layer}.self_attn._qkv_same_embed_dim was not True"
        elif not first_layer.activation_relu_or_gelu:
            why_not_sparsity_fast_path = f"{str_first_layer}.activation_relu_or_gelu was not True"
        elif not (first_layer.norm1.eps == first_layer.norm2.eps):
            why_not_sparsity_fast_path = f"{str_first_layer}.norm1.eps was not equal to {str_first_layer}.norm2.eps"
        elif not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif not self.enable_nested_tensor:
            why_not_sparsity_fast_path = "enable_nested_tensor was not True"
        elif src_key_padding_mask is None:
            why_not_sparsity_fast_path = "src_key_padding_mask was None"
        elif (((not hasattr(self, "mask_check")) or self.mask_check)
                and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
            why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask were not left aligned"
        elif output.is_nested:
            why_not_sparsity_fast_path = "NestedTensor input is not supported"
        elif mask is not None:
            why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied"
        elif first_layer.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = "num_head is odd"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"

        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                first_layer.self_attn.in_proj_weight,
                first_layer.self_attn.in_proj_bias,
                first_layer.self_attn.out_proj.weight,
                first_layer.self_attn.out_proj.bias,
                first_layer.norm1.weight,
                first_layer.norm1.bias,
                first_layer.norm2.weight,
                first_layer.norm2.bias,
                first_layer.linear1.weight,
                first_layer.linear1.bias,
                first_layer.linear2.weight,
                first_layer.linear2.bias,
            )

            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif not (src.is_cuda or 'cpu' in str(src.device)):
                why_not_sparsity_fast_path = "src is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")

            if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
                convert_to_nested = True
                output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
                src_key_padding_mask_for_layers = None

        # Prevent type refinement
        make_causal = (is_causal is True)

        if is_causal is None:
            if mask is not None:
                sz = mask.size(0)
                causal_comparison = torch.triu(
                    torch.ones(sz, sz, device=mask.device) * float('-inf'), diagonal=1
                ).to(mask.dtype)

                if torch.equal(mask, causal_comparison):
                    make_causal = True

        is_causal = make_causal

        for mod in self.layers:
            output = mod(output, src_mask=mask, is_causal=is_causal,
                         src_key_padding_mask=src_key_padding_mask_for_layers)

        if convert_to_nested:
            output = output.to_padded_tensor(0.)

        if self.norm is not None:
            output = self.norm(output)

        return output

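# A minimal sketch of when the nested-tensor fast path above can apply, assuming
# ``from torch import nn``: the encoder is in eval mode with grad disabled, the layers use
# ``batch_first=True`` with an even number of heads, and only a left-aligned boolean
# ``src_key_padding_mask`` (no ``mask``) is supplied.
#
#   >>> layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
#   >>> encoder = nn.TransformerEncoder(layer, num_layers=6).eval()
#   >>> src = torch.rand(32, 10, 512)                          # (N, S, E)
#   >>> padding_mask = torch.zeros(32, 10, dtype=torch.bool)
#   >>> padding_mask[:, 7:] = True                             # trailing positions are padding
#   >>> with torch.no_grad():
#   ...     out = encoder(src, src_key_padding_mask=padding_mask)  # eligible for the fast path
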
class TransformerDecoder(Module):
    r"""TransformerDecoder is a stack of N decoder layers.

    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = transformer_decoder(tgt, memory)
    """
    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm=None):
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer in turn.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = tgt

        for mod in self.layers:
            output = mod(output, memory, tgt_mask=tgt_mask,
                         memory_mask=memory_mask,
                         tgt_key_padding_mask=tgt_key_padding_mask,
                         memory_key_padding_mask=memory_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output

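# A minimal usage sketch, assuming ``from torch import nn``: decoding with a causal
# ``tgt_mask`` so the stack can be used autoregressively; ``memory`` is the output of a
# matching encoder stack.
#
#   >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
#   >>> decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
#   >>> memory = torch.rand(10, 32, 512)                       # (S, N, E) encoder output
#   >>> tgt = torch.rand(20, 32, 512)                          # (T, N, E)
#   >>> tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0))
#   >>> out = decoder(tgt, memory, tgt_mask=tgt_mask)          # (T, N, E)
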
class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectively. Otherwise it's done after. Default: ``False`` (after).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)

    Fast path:
        forward() will use a special optimized implementation described in
        `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
        conditions are met:

        - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
          argument ``requires_grad``
        - training is disabled (using ``.eval()``)
        - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
        - activation is one of: ``"relu"``, ``"gelu"``, ``torch.nn.functional.relu``, or ``torch.nn.functional.gelu``
        - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
        - if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
          nor ``src_key_padding_mask`` is passed
        - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
          unless the caller has manually modified one without modifying the other)

        If the optimized implementation is in use, a
        `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
        passed for ``src`` to represent padding more efficiently than using a padding
        mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
        returned, and an additional speedup proportional to the fraction of the input that
        is padding can be expected.

    .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
        https://arxiv.org/abs/2205.14135
    """
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)

        # We can't test self.activation in forward() in TorchScript,
        # so stash some information about it instead.
        if activation is F.relu or isinstance(activation, torch.nn.ReLU):
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        self.activation = activation

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, 'activation'):
            self.activation = F.relu

    def forward(
            self,
            src: Tensor,
            src_mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            is_causal: bool = False) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            is_causal: If specified, applies a causal mask as src_mask.
                Default: ``False``.
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(src_mask),
            other_name="src_mask",
            target_type=src.dtype
        )

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        why_not_sparsity_fast_path = ''
        if not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif self.training:
            why_not_sparsity_fast_path = "training is enabled"
        elif not self.self_attn.batch_first:
            why_not_sparsity_fast_path = "self_attn.batch_first was not True"
        elif not self.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
        elif not self.activation_relu_or_gelu:
            why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
        elif not (self.norm1.eps == self.norm2.eps):
            why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
        elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None):
            why_not_sparsity_fast_path = "src_key_padding_mask and src_mask are not supported with NestedTensor input"
        elif self.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = "num_head is odd"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"

        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                self.self_attn.in_proj_weight,
                self.self_attn.in_proj_bias,
                self.self_attn.out_proj.weight,
                self.self_attn.out_proj.bias,
                self.norm1.weight,
                self.norm1.bias,
                self.norm2.weight,
                self.norm2.bias,
                self.linear1.weight,
                self.linear1.bias,
                self.linear2.weight,
                self.linear2.bias,
            )

            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif not all((x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")

            if not why_not_sparsity_fast_path:
                merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src)
                return torch._transformer_encoder_layer_fwd(
                    src,
                    self.self_attn.embed_dim,
                    self.self_attn.num_heads,
                    self.self_attn.in_proj_weight,
                    self.self_attn.in_proj_bias,
                    self.self_attn.out_proj.weight,
                    self.self_attn.out_proj.bias,
                    self.activation_relu_or_gelu == 2,
                    self.norm_first,
                    self.norm1.eps,
                    self.norm1.weight,
                    self.norm1.bias,
                    self.norm2.weight,
                    self.norm2.bias,
                    self.linear1.weight,
                    self.linear1.bias,
                    self.linear2.weight,
                    self.linear2.bias,
                    merged_mask,
                    mask_type,
                )

        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

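# A minimal sketch of the ``norm_first`` switch implemented in forward() above, assuming
# ``from torch import nn``: post-norm applies LayerNorm after each residual add
# (x = norm(x + sublayer(x))), pre-norm applies it before each sub-block
# (x = x + sublayer(norm(x))).
#
#   >>> post_norm = nn.TransformerEncoderLayer(d_model=512, nhead=8)                   # default, norm after
#   >>> pre_norm = nn.TransformerEncoderLayer(d_model=512, nhead=8, norm_first=True)   # norm before
#   >>> src = torch.rand(10, 32, 512)
#   >>> out_post, out_pre = post_norm(src), pre_norm(src)      # both (10, 32, 512)
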
class TransformerDecoderLayer(Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to self attention, multihead
            attention and feedforward operations, respectively. Otherwise it's done after.
            Default: ``False`` (after).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, memory)

    Alternatively, when ``batch_first`` is ``True``:
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> memory = torch.rand(32, 10, 512)
        >>> tgt = torch.rand(32, 20, 512)
        >>> out = decoder_layer(tgt, memory)
    """
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                                 **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        tgt_is_causal: bool = False,
        memory_is_causal: bool = False,
    ) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            tgt_is_causal: If specified, applies a causal mask as tgt mask.
                Mutually exclusive with providing tgt_mask. Default: ``False``.
            memory_is_causal: If specified, applies a causal mask as memory mask.
                Mutually exclusive with providing memory_mask. Default: ``False``.

        Shape:
            see the docs in Transformer class.
        """
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        x = tgt
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
            x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
            x = x + self._ff_block(self.norm3(x))
        else:
            x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
            x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
            x = self.norm3(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           is_causal=is_causal,
                           need_weights=False)[0]
        return self.dropout1(x)

    # multihead attention block
    def _mha_block(self, x: Tensor, mem: Tensor,
                   attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.multihead_attn(x, mem, mem,
                                attn_mask=attn_mask,
                                key_padding_mask=key_padding_mask,
                                is_causal=is_causal,
                                need_weights=False)[0]
        return self.dropout2(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)

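# A minimal usage sketch, assuming ``from torch import nn``: a single decoder layer in
# batch-first layout, with an explicit causal mask for the self-attention block and full
# cross-attention over ``memory``. A 2-D (T, T) mask is broadcast across batch and heads.
#
#   >>> layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
#   >>> memory = torch.rand(32, 10, 512)                       # (N, S, E)
#   >>> tgt = torch.rand(32, 20, 512)                          # (N, T, E)
#   >>> tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1))  # (T, T)
#   >>> out = layer(tgt, memory, tgt_mask=tgt_mask)            # (N, T, E)
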
def _get_clones(module, N):
    # FIXME: copy.deepcopy() is not defined on nn.module
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError(f"activation should be relu/gelu, not {activation}")
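
# A minimal sketch of what ``_get_clones`` guarantees: each stacked layer is a deep copy of
# the prototype, so the clones start from identical weights but do not share parameters.
#
#   >>> proto = TransformerEncoderLayer(d_model=512, nhead=8)
#   >>> clones = _get_clones(proto, 2)
#   >>> clones[0].linear1.weight is clones[1].linear1.weight
#   False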
|