transformer.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741
  1. import copy
  2. from typing import Optional, Any, Union, Callable
  3. import torch
  4. from torch import Tensor
  5. from .. import functional as F
  6. from .module import Module
  7. from .activation import MultiheadAttention
  8. from .container import ModuleList
  9. from ..init import xavier_uniform_
  10. from .dropout import Dropout
  11. from .linear import Linear
  12. from .normalization import LayerNorm
# Names exported by a star-import of this module: the five public transformer building blocks.
__all__ = ['Transformer', 'TransformerEncoder', 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer']
  14. class Transformer(Module):
  15. r"""A transformer model. User is able to modify the attributes as needed. The architecture
  16. is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
  17. Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
  18. Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
  19. Processing Systems, pages 6000-6010.
  20. Args:
  21. d_model: the number of expected features in the encoder/decoder inputs (default=512).
  22. nhead: the number of heads in the multiheadattention models (default=8).
  23. num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
  24. num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
  25. dim_feedforward: the dimension of the feedforward network model (default=2048).
  26. dropout: the dropout value (default=0.1).
  27. activation: the activation function of encoder/decoder intermediate layer, can be a string
  28. ("relu" or "gelu") or a unary callable. Default: relu
  29. custom_encoder: custom encoder (default=None).
  30. custom_decoder: custom decoder (default=None).
  31. layer_norm_eps: the eps value in layer normalization components (default=1e-5).
  32. batch_first: If ``True``, then the input and output tensors are provided
  33. as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
  34. norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
  35. other attention and feedforward operations, otherwise after. Default: ``False`` (after).
  36. Examples::
  37. >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
  38. >>> src = torch.rand((10, 32, 512))
  39. >>> tgt = torch.rand((20, 32, 512))
  40. >>> out = transformer_model(src, tgt)
  41. Note: A full example to apply nn.Transformer module for the word language model is available in
  42. https://github.com/pytorch/examples/tree/master/word_language_model
  43. """
  44. def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
  45. num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
  46. activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
  47. custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
  48. layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
  49. device=None, dtype=None) -> None:
  50. factory_kwargs = {'device': device, 'dtype': dtype}
  51. super().__init__()
  52. torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
  53. if custom_encoder is not None:
  54. self.encoder = custom_encoder
  55. else:
  56. encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
  57. activation, layer_norm_eps, batch_first, norm_first,
  58. **factory_kwargs)
  59. encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
  60. self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
  61. if custom_decoder is not None:
  62. self.decoder = custom_decoder
  63. else:
  64. decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
  65. activation, layer_norm_eps, batch_first, norm_first,
  66. **factory_kwargs)
  67. decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
  68. self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
  69. self._reset_parameters()
  70. self.d_model = d_model
  71. self.nhead = nhead
  72. self.batch_first = batch_first
  73. def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
  74. memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
  75. tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
  76. r"""Take in and process masked source/target sequences.
  77. Args:
  78. src: the sequence to the encoder (required).
  79. tgt: the sequence to the decoder (required).
  80. src_mask: the additive mask for the src sequence (optional).
  81. tgt_mask: the additive mask for the tgt sequence (optional).
  82. memory_mask: the additive mask for the encoder output (optional).
  83. src_key_padding_mask: the Tensor mask for src keys per batch (optional).
  84. tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional).
  85. memory_key_padding_mask: the Tensor mask for memory keys per batch (optional).
  86. Shape:
  87. - src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or
  88. `(N, S, E)` if `batch_first=True`.
  89. - tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
  90. `(N, T, E)` if `batch_first=True`.
  91. - src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`.
  92. - tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`.
  93. - memory_mask: :math:`(T, S)`.
  94. - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
  95. - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
  96. - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
  97. Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked
  98. positions. If a BoolTensor is provided, positions with ``True``
  99. are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
  100. is provided, it will be added to the attention weight.
  101. [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
  102. the attention. If a BoolTensor is provided, the positions with the
  103. value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
  104. - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
  105. `(N, T, E)` if `batch_first=True`.
  106. Note: Due to the multi-head attention architecture in the transformer model,
  107. the output sequence length of a transformer is same as the input sequence
  108. (i.e. target) length of the decoder.
  109. where S is the source sequence length, T is the target sequence length, N is the
  110. batch size, E is the feature number
  111. Examples:
  112. >>> # xdoctest: +SKIP
  113. >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
  114. """
  115. is_batched = src.dim() == 3
  116. if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
  117. raise RuntimeError("the batch number of src and tgt must be equal")
  118. elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
  119. raise RuntimeError("the batch number of src and tgt must be equal")
  120. if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
  121. raise RuntimeError("the feature number of src and tgt must be equal to d_model")
  122. memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
  123. output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
  124. tgt_key_padding_mask=tgt_key_padding_mask,
  125. memory_key_padding_mask=memory_key_padding_mask)
  126. return output
  127. @staticmethod
  128. def generate_square_subsequent_mask(sz: int, device='cpu') -> Tensor:
  129. r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
  130. Unmasked positions are filled with float(0.0).
  131. """
  132. return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)
  133. def _reset_parameters(self):
  134. r"""Initiate parameters in the transformer model."""
  135. for p in self.parameters():
  136. if p.dim() > 1:
  137. xavier_uniform_(p)
  138. class TransformerEncoder(Module):
  139. r"""TransformerEncoder is a stack of N encoder layers. Users can build the
  140. BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
  141. Args:
  142. encoder_layer: an instance of the TransformerEncoderLayer() class (required).
  143. num_layers: the number of sub-encoder-layers in the encoder (required).
  144. norm: the layer normalization component (optional).
  145. enable_nested_tensor: if True, input will automatically convert to nested tensor
  146. (and convert back on output). This will improve the overall performance of
  147. TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
  148. Examples::
  149. >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
  150. >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
  151. >>> src = torch.rand(10, 32, 512)
  152. >>> out = transformer_encoder(src)
  153. """
  154. __constants__ = ['norm']
  155. def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True):
  156. super().__init__()
  157. torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
  158. self.layers = _get_clones(encoder_layer, num_layers)
  159. self.num_layers = num_layers
  160. self.norm = norm
  161. self.enable_nested_tensor = enable_nested_tensor
  162. self.mask_check = mask_check
  163. def forward(
  164. self,
  165. src: Tensor,
  166. mask: Optional[Tensor] = None,
  167. src_key_padding_mask: Optional[Tensor] = None,
  168. is_causal: Optional[bool] = None) -> Tensor:
  169. r"""Pass the input through the encoder layers in turn.
  170. Args:
  171. src: the sequence to the encoder (required).
  172. mask: the mask for the src sequence (optional).
  173. is_causal: If specified, applies a causal mask as mask (optional)
  174. and ignores attn_mask for computing scaled dot product attention.
  175. Default: ``False``.
  176. src_key_padding_mask: the mask for the src keys per batch (optional).
  177. Shape:
  178. see the docs in Transformer class.
  179. """
  180. src_key_padding_mask = F._canonical_mask(
  181. mask=src_key_padding_mask,
  182. mask_name="src_key_padding_mask",
  183. other_type=F._none_or_dtype(mask),
  184. other_name="mask",
  185. target_type=src.dtype
  186. )
  187. output = src
  188. convert_to_nested = False
  189. first_layer = self.layers[0]
  190. src_key_padding_mask_for_layers = src_key_padding_mask
  191. why_not_sparsity_fast_path = ''
  192. str_first_layer = "self.layers[0]"
  193. if not isinstance(first_layer, torch.nn.TransformerEncoderLayer):
  194. why_not_sparsity_fast_path = f"{str_first_layer} was not TransformerEncoderLayer"
  195. elif first_layer.norm_first :
  196. why_not_sparsity_fast_path = f"{str_first_layer}.norm_first was True"
  197. elif first_layer.training:
  198. why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
  199. elif not first_layer.self_attn.batch_first:
  200. why_not_sparsity_fast_path = f" {str_first_layer}.self_attn.batch_first was not True"
  201. elif not first_layer.self_attn._qkv_same_embed_dim:
  202. why_not_sparsity_fast_path = f"{str_first_layer}.self_attn._qkv_same_embed_dim was not True"
  203. elif not first_layer.activation_relu_or_gelu:
  204. why_not_sparsity_fast_path = f" {str_first_layer}.activation_relu_or_gelu was not True"
  205. elif not (first_layer.norm1.eps == first_layer.norm2.eps) :
  206. why_not_sparsity_fast_path = f"{str_first_layer}.norm1.eps was not equal to {str_first_layer}.norm2.eps"
  207. elif not src.dim() == 3:
  208. why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
  209. elif not self.enable_nested_tensor:
  210. why_not_sparsity_fast_path = "enable_nested_tensor was not True"
  211. elif src_key_padding_mask is None:
  212. why_not_sparsity_fast_path = "src_key_padding_mask was None"
  213. elif (((not hasattr(self, "mask_check")) or self.mask_check)
  214. and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
  215. why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned"
  216. elif output.is_nested:
  217. why_not_sparsity_fast_path = "NestedTensor input is not supported"
  218. elif mask is not None:
  219. why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied"
  220. elif first_layer.self_attn.num_heads % 2 == 1:
  221. why_not_sparsity_fast_path = "num_head is odd"
  222. elif torch.is_autocast_enabled():
  223. why_not_sparsity_fast_path = "autocast is enabled"
  224. if not why_not_sparsity_fast_path:
  225. tensor_args = (
  226. src,
  227. first_layer.self_attn.in_proj_weight,
  228. first_layer.self_attn.in_proj_bias,
  229. first_layer.self_attn.out_proj.weight,
  230. first_layer.self_attn.out_proj.bias,
  231. first_layer.norm1.weight,
  232. first_layer.norm1.bias,
  233. first_layer.norm2.weight,
  234. first_layer.norm2.bias,
  235. first_layer.linear1.weight,
  236. first_layer.linear1.bias,
  237. first_layer.linear2.weight,
  238. first_layer.linear2.bias,
  239. )
  240. if torch.overrides.has_torch_function(tensor_args):
  241. why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
  242. elif not (src.is_cuda or 'cpu' in str(src.device)):
  243. why_not_sparsity_fast_path = "src is neither CUDA nor CPU"
  244. elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
  245. why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
  246. "input/output projection weights or biases requires_grad")
  247. if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
  248. convert_to_nested = True
  249. output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
  250. src_key_padding_mask_for_layers = None
  251. # Prevent type refinement
  252. make_causal = (is_causal is True)
  253. if is_causal is None:
  254. if mask is not None:
  255. sz = mask.size(0)
  256. causal_comparison = torch.triu(
  257. torch.ones(sz, sz, device=mask.device) * float('-inf'), diagonal=1
  258. ).to(mask.dtype)
  259. if torch.equal(mask, causal_comparison):
  260. make_causal = True
  261. is_causal = make_causal
  262. for mod in self.layers:
  263. output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
  264. if convert_to_nested:
  265. output = output.to_padded_tensor(0.)
  266. if self.norm is not None:
  267. output = self.norm(output)
  268. return output
  269. class TransformerDecoder(Module):
  270. r"""TransformerDecoder is a stack of N decoder layers
  271. Args:
  272. decoder_layer: an instance of the TransformerDecoderLayer() class (required).
  273. num_layers: the number of sub-decoder-layers in the decoder (required).
  274. norm: the layer normalization component (optional).
  275. Examples::
  276. >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
  277. >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
  278. >>> memory = torch.rand(10, 32, 512)
  279. >>> tgt = torch.rand(20, 32, 512)
  280. >>> out = transformer_decoder(tgt, memory)
  281. """
  282. __constants__ = ['norm']
  283. def __init__(self, decoder_layer, num_layers, norm=None):
  284. super().__init__()
  285. torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
  286. self.layers = _get_clones(decoder_layer, num_layers)
  287. self.num_layers = num_layers
  288. self.norm = norm
  289. def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
  290. memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
  291. memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
  292. r"""Pass the inputs (and mask) through the decoder layer in turn.
  293. Args:
  294. tgt: the sequence to the decoder (required).
  295. memory: the sequence from the last layer of the encoder (required).
  296. tgt_mask: the mask for the tgt sequence (optional).
  297. memory_mask: the mask for the memory sequence (optional).
  298. tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
  299. memory_key_padding_mask: the mask for the memory keys per batch (optional).
  300. Shape:
  301. see the docs in Transformer class.
  302. """
  303. output = tgt
  304. for mod in self.layers:
  305. output = mod(output, memory, tgt_mask=tgt_mask,
  306. memory_mask=memory_mask,
  307. tgt_key_padding_mask=tgt_key_padding_mask,
  308. memory_key_padding_mask=memory_key_padding_mask)
  309. if self.norm is not None:
  310. output = self.norm(output)
  311. return output
  312. class TransformerEncoderLayer(Module):
  313. r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
  314. This standard encoder layer is based on the paper "Attention Is All You Need".
  315. Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
  316. Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
  317. Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
  318. in a different way during application.
  319. Args:
  320. d_model: the number of expected features in the input (required).
  321. nhead: the number of heads in the multiheadattention models (required).
  322. dim_feedforward: the dimension of the feedforward network model (default=2048).
  323. dropout: the dropout value (default=0.1).
  324. activation: the activation function of the intermediate layer, can be a string
  325. ("relu" or "gelu") or a unary callable. Default: relu
  326. layer_norm_eps: the eps value in layer normalization components (default=1e-5).
  327. batch_first: If ``True``, then the input and output tensors are provided
  328. as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
  329. norm_first: if ``True``, layer norm is done prior to attention and feedforward
  330. operations, respectively. Otherwise it's done after. Default: ``False`` (after).
  331. Examples::
  332. >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
  333. >>> src = torch.rand(10, 32, 512)
  334. >>> out = encoder_layer(src)
  335. Alternatively, when ``batch_first`` is ``True``:
  336. >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
  337. >>> src = torch.rand(32, 10, 512)
  338. >>> out = encoder_layer(src)
  339. Fast path:
  340. forward() will use a special optimized implementation described in
  341. `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
  342. conditions are met:
  343. - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
  344. argument ``requires_grad``
  345. - training is disabled (using ``.eval()``)
  346. - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
  347. - activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu``
  348. - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
  349. - if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
  350. nor ``src_key_padding_mask`` is passed
  351. - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
  352. unless the caller has manually modified one without modifying the other)
  353. If the optimized implementation is in use, a
  354. `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
  355. passed for ``src`` to represent padding more efficiently than using a padding
  356. mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
  357. returned, and an additional speedup proportional to the fraction of the input that
  358. is padding can be expected.
  359. .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
  360. https://arxiv.org/abs/2205.14135
  361. """
  362. __constants__ = ['batch_first', 'norm_first']
  363. def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
  364. activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
  365. layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
  366. device=None, dtype=None) -> None:
  367. factory_kwargs = {'device': device, 'dtype': dtype}
  368. super().__init__()
  369. self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
  370. **factory_kwargs)
  371. # Implementation of Feedforward model
  372. self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
  373. self.dropout = Dropout(dropout)
  374. self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)
  375. self.norm_first = norm_first
  376. self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
  377. self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
  378. self.dropout1 = Dropout(dropout)
  379. self.dropout2 = Dropout(dropout)
  380. # Legacy string support for activation function.
  381. if isinstance(activation, str):
  382. activation = _get_activation_fn(activation)
  383. # We can't test self.activation in forward() in TorchScript,
  384. # so stash some information about it instead.
  385. if activation is F.relu or isinstance(activation, torch.nn.ReLU):
  386. self.activation_relu_or_gelu = 1
  387. elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
  388. self.activation_relu_or_gelu = 2
  389. else:
  390. self.activation_relu_or_gelu = 0
  391. self.activation = activation
  392. def __setstate__(self, state):
  393. super().__setstate__(state)
  394. if not hasattr(self, 'activation'):
  395. self.activation = F.relu
  396. def forward(
  397. self,
  398. src: Tensor,
  399. src_mask: Optional[Tensor] = None,
  400. src_key_padding_mask: Optional[Tensor] = None,
  401. is_causal: bool = False) -> Tensor:
  402. r"""Pass the input through the encoder layer.
  403. Args:
  404. src: the sequence to the encoder layer (required).
  405. src_mask: the mask for the src sequence (optional).
  406. is_causal: If specified, applies a causal mask as src_mask.
  407. Default: ``False``.
  408. src_key_padding_mask: the mask for the src keys per batch (optional).
  409. Shape:
  410. see the docs in Transformer class.
  411. """
  412. src_key_padding_mask = F._canonical_mask(
  413. mask=src_key_padding_mask,
  414. mask_name="src_key_padding_mask",
  415. other_type=F._none_or_dtype(src_mask),
  416. other_name="src_mask",
  417. target_type=src.dtype
  418. )
  419. # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
  420. why_not_sparsity_fast_path = ''
  421. if not src.dim() == 3:
  422. why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
  423. elif self.training:
  424. why_not_sparsity_fast_path = "training is enabled"
  425. elif not self.self_attn.batch_first :
  426. why_not_sparsity_fast_path = "self_attn.batch_first was not True"
  427. elif not self.self_attn._qkv_same_embed_dim :
  428. why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
  429. elif not self.activation_relu_or_gelu:
  430. why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
  431. elif not (self.norm1.eps == self.norm2.eps):
  432. why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
  433. elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None):
  434. why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input"
  435. elif self.self_attn.num_heads % 2 == 1:
  436. why_not_sparsity_fast_path = "num_head is odd"
  437. elif torch.is_autocast_enabled():
  438. why_not_sparsity_fast_path = "autocast is enabled"
  439. if not why_not_sparsity_fast_path:
  440. tensor_args = (
  441. src,
  442. self.self_attn.in_proj_weight,
  443. self.self_attn.in_proj_bias,
  444. self.self_attn.out_proj.weight,
  445. self.self_attn.out_proj.bias,
  446. self.norm1.weight,
  447. self.norm1.bias,
  448. self.norm2.weight,
  449. self.norm2.bias,
  450. self.linear1.weight,
  451. self.linear1.bias,
  452. self.linear2.weight,
  453. self.linear2.bias,
  454. )
  455. # We have to use list comprehensions below because TorchScript does not support
  456. # generator expressions.
  457. if torch.overrides.has_torch_function(tensor_args):
  458. why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
  459. elif not all((x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args):
  460. why_not_sparsity_fast_path = "some Tensor argument is neither CUDA nor CPU"
  461. elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
  462. why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
  463. "input/output projection weights or biases requires_grad")
  464. if not why_not_sparsity_fast_path:
  465. merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src)
  466. return torch._transformer_encoder_layer_fwd(
  467. src,
  468. self.self_attn.embed_dim,
  469. self.self_attn.num_heads,
  470. self.self_attn.in_proj_weight,
  471. self.self_attn.in_proj_bias,
  472. self.self_attn.out_proj.weight,
  473. self.self_attn.out_proj.bias,
  474. self.activation_relu_or_gelu == 2,
  475. self.norm_first,
  476. self.norm1.eps,
  477. self.norm1.weight,
  478. self.norm1.bias,
  479. self.norm2.weight,
  480. self.norm2.bias,
  481. self.linear1.weight,
  482. self.linear1.bias,
  483. self.linear2.weight,
  484. self.linear2.bias,
  485. merged_mask,
  486. mask_type,
  487. )
  488. x = src
  489. if self.norm_first:
  490. x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
  491. x = x + self._ff_block(self.norm2(x))
  492. else:
  493. x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
  494. x = self.norm2(x + self._ff_block(x))
  495. return x
  496. # self-attention block
  497. def _sa_block(self, x: Tensor,
  498. attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
  499. x = self.self_attn(x, x, x,
  500. attn_mask=attn_mask,
  501. key_padding_mask=key_padding_mask,
  502. need_weights=False)[0]
  503. return self.dropout1(x)
  504. # feed forward block
  505. def _ff_block(self, x: Tensor) -> Tensor:
  506. x = self.linear2(self.dropout(self.activation(self.linear1(x))))
  507. return self.dropout2(x)
  508. class TransformerDecoderLayer(Module):
  509. r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
  510. This standard decoder layer is based on the paper "Attention Is All You Need".
  511. Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
  512. Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
  513. Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
  514. in a different way during application.
  515. Args:
  516. d_model: the number of expected features in the input (required).
  517. nhead: the number of heads in the multiheadattention models (required).
  518. dim_feedforward: the dimension of the feedforward network model (default=2048).
  519. dropout: the dropout value (default=0.1).
  520. activation: the activation function of the intermediate layer, can be a string
  521. ("relu" or "gelu") or a unary callable. Default: relu
  522. layer_norm_eps: the eps value in layer normalization components (default=1e-5).
  523. batch_first: If ``True``, then the input and output tensors are provided
  524. as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
  525. norm_first: if ``True``, layer norm is done prior to self attention, multihead
  526. attention and feedforward operations, respectively. Otherwise it's done after.
  527. Default: ``False`` (after).
  528. Examples::
  529. >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
  530. >>> memory = torch.rand(10, 32, 512)
  531. >>> tgt = torch.rand(20, 32, 512)
  532. >>> out = decoder_layer(tgt, memory)
  533. Alternatively, when ``batch_first`` is ``True``:
  534. >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
  535. >>> memory = torch.rand(32, 10, 512)
  536. >>> tgt = torch.rand(32, 20, 512)
  537. >>> out = decoder_layer(tgt, memory)
  538. """
  539. __constants__ = ['batch_first', 'norm_first']
  540. def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
  541. activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
  542. layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
  543. device=None, dtype=None) -> None:
  544. factory_kwargs = {'device': device, 'dtype': dtype}
  545. super().__init__()
  546. self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
  547. **factory_kwargs)
  548. self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
  549. **factory_kwargs)
  550. # Implementation of Feedforward model
  551. self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
  552. self.dropout = Dropout(dropout)
  553. self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)
  554. self.norm_first = norm_first
  555. self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
  556. self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
  557. self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
  558. self.dropout1 = Dropout(dropout)
  559. self.dropout2 = Dropout(dropout)
  560. self.dropout3 = Dropout(dropout)
  561. # Legacy string support for activation function.
  562. if isinstance(activation, str):
  563. self.activation = _get_activation_fn(activation)
  564. else:
  565. self.activation = activation
  566. def __setstate__(self, state):
  567. if 'activation' not in state:
  568. state['activation'] = F.relu
  569. super().__setstate__(state)
  570. def forward(
  571. self,
  572. tgt: Tensor,
  573. memory: Tensor,
  574. tgt_mask: Optional[Tensor] = None,
  575. memory_mask: Optional[Tensor] = None,
  576. tgt_key_padding_mask: Optional[Tensor] = None,
  577. memory_key_padding_mask: Optional[Tensor] = None,
  578. tgt_is_causal: bool = False,
  579. memory_is_causal: bool = False,
  580. ) -> Tensor:
  581. r"""Pass the inputs (and mask) through the decoder layer.
  582. Args:
  583. tgt: the sequence to the decoder layer (required).
  584. memory: the sequence from the last layer of the encoder (required).
  585. tgt_mask: the mask for the tgt sequence (optional).
  586. memory_mask: the mask for the memory sequence (optional).
  587. tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
  588. memory_key_padding_mask: the mask for the memory keys per batch (optional).
  589. tgt_is_causal: If specified, applies a causal mask as tgt mask.
  590. Mutually exclusive with providing tgt_mask. Default: ``False``.
  591. memory_is_causal: If specified, applies a causal mask as tgt mask.
  592. Mutually exclusive with providing memory_mask. Default: ``False``.
  593. Shape:
  594. see the docs in Transformer class.
  595. """
  596. # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
  597. x = tgt
  598. if self.norm_first:
  599. x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
  600. x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
  601. x = x + self._ff_block(self.norm3(x))
  602. else:
  603. x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
  604. x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
  605. x = self.norm3(x + self._ff_block(x))
  606. return x
  607. # self-attention block
  608. def _sa_block(self, x: Tensor,
  609. attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
  610. x = self.self_attn(x, x, x,
  611. attn_mask=attn_mask,
  612. key_padding_mask=key_padding_mask,
  613. is_causal=is_causal,
  614. need_weights=False)[0]
  615. return self.dropout1(x)
  616. # multihead attention block
  617. def _mha_block(self, x: Tensor, mem: Tensor,
  618. attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
  619. x = self.multihead_attn(x, mem, mem,
  620. attn_mask=attn_mask,
  621. key_padding_mask=key_padding_mask,
  622. is_causal=is_causal,
  623. need_weights=False)[0]
  624. return self.dropout2(x)
  625. # feed forward block
  626. def _ff_block(self, x: Tensor) -> Tensor:
  627. x = self.linear2(self.dropout(self.activation(self.linear1(x))))
  628. return self.dropout3(x)
  629. def _get_clones(module, N):
  630. # FIXME: copy.deepcopy() is not defined on nn.module
  631. return ModuleList([copy.deepcopy(module) for i in range(N)])
  632. def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
  633. if activation == "relu":
  634. return F.relu
  635. elif activation == "gelu":
  636. return F.gelu
  637. raise RuntimeError("activation should be relu/gelu, not {}".format(activation))