transformer.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import math
from typing import Tuple, Type

import torch
from torch import Tensor, nn

from ultralytics.nn.modules import MLPBlock


class TwoWayTransformer(nn.Module):

    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using queries whose positional embedding is supplied.

        Args:
            depth (int): number of layers in the transformer
            embedding_dim (int): the channel dimension for the input embeddings
            num_heads (int): the number of heads for multihead attention. Must divide embedding_dim
            mlp_dim (int): the channel dimension internal to the MLP block
            activation (nn.Module): the activation to use in the MLP block
            attention_downsample_rate (int): the rate at which to downsample the internal attention dimension
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                ))

        self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w.
            image_pe (torch.Tensor): the positional encoding to add to the image. Must have the same shape as image_embedding.
            point_embedding (torch.Tensor): the embedding to add to the query points. Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
            (torch.Tensor): the processed point_embedding
            (torch.Tensor): the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys
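
# Example usage (an illustrative sketch only; the shapes and values below are assumptions
# chosen for demonstration, not part of the original module):
#
#     transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
#     image_embedding = torch.randn(1, 256, 64, 64)   # B x C x H x W
#     image_pe = torch.randn(1, 256, 64, 64)          # same shape as image_embedding
#     point_embedding = torch.randn(1, 5, 256)        # B x N_points x C
#     queries, keys = transformer(image_embedding, image_pe, point_embedding)
#     # queries: (1, 5, 256) processed point embeddings; keys: (1, 4096, 256) processed image tokens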


class TwoWayAttentionBlock(nn.Module):

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross-attention of sparse
        inputs to dense inputs, (3) an MLP block on sparse inputs, and (4) cross-attention of dense inputs to sparse
        inputs.

        Args:
            embedding_dim (int): the channel dimension of the embeddings
            num_heads (int): the number of heads in the attention layers
            mlp_dim (int): the hidden dimension of the MLP block
            activation (nn.Module): the activation of the MLP block
            attention_downsample_rate (int): the rate at which to downsample the internal attention dimension
            skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
        """Apply self-attention and cross-attention to queries and keys and return the processed embeddings."""
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys
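
# Illustrative sketch of a single block's data flow (shapes are assumed for demonstration
# and are not part of the original module):
#
#     block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8, skip_first_layer_pe=True)
#     queries = torch.randn(1, 5, 256)      # sparse prompt tokens, B x N_points x C
#     keys = torch.randn(1, 4096, 256)      # dense image tokens, B x (H*W) x C
#     query_pe = torch.randn(1, 5, 256)     # positional encoding for the tokens
#     key_pe = torch.randn(1, 4096, 256)    # positional encoding for the image
#     queries, keys = block(queries=queries, keys=keys, query_pe=query_pe, key_pe=key_pe)
#     # Both outputs keep their input shapes: the tokens attend to the image and the image attends back.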


class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
    values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        """Initialize the Attention layer with the embedding dimension, number of heads, and downsample rate."""
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, 'num_heads must divide embedding_dim // downsample_rate.'

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        """Separate the input tensor into the specified number of attention heads."""
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        """Recombine the separated attention heads into a single tensor."""
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Compute the attention output given the input query, key, and value tensors."""
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # Get output
        out = attn @ v
        out = self._recombine_heads(out)

        return self.out_proj(out)
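

# A minimal smoke test, assuming an environment with torch and ultralytics installed; the
# tensor shapes below are illustrative choices, not values taken from the original module.
if __name__ == '__main__':
    attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)  # internal dim = 256 // 2 = 128
    tokens = torch.randn(2, 10, 256)   # B x N_tokens x C
    image = torch.randn(2, 4096, 256)  # B x N_image_tokens x C
    out = attn(q=tokens, k=image, v=image)
    print(out.shape)  # expected: torch.Size([2, 10, 256]); the output is re-projected to embedding_dim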