transformer.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. """
  3. Transformer modules
  4. """
  5. import math
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from torch.nn.init import constant_, xavier_uniform_
  10. from .conv import Conv
  11. from .utils import _get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch
  12. __all__ = ('TransformerEncoderLayer', 'TransformerLayer', 'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'AIFI',
  13. 'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP')
  14. class TransformerEncoderLayer(nn.Module):
  15. """Transformer Encoder."""
  16. def __init__(self, c1, cm=2048, num_heads=8, dropout=0.0, act=nn.GELU(), normalize_before=False):
  17. super().__init__()
  18. from ...utils.torch_utils import TORCH_1_9
  19. if not TORCH_1_9:
  20. raise ModuleNotFoundError(
  21. 'TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True).')
  22. self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
  23. # Implementation of Feedforward model
  24. self.fc1 = nn.Linear(c1, cm)
  25. self.fc2 = nn.Linear(cm, c1)
  26. self.norm1 = nn.LayerNorm(c1)
  27. self.norm2 = nn.LayerNorm(c1)
  28. self.dropout = nn.Dropout(dropout)
  29. self.dropout1 = nn.Dropout(dropout)
  30. self.dropout2 = nn.Dropout(dropout)
  31. self.act = act
  32. self.normalize_before = normalize_before
  33. def with_pos_embed(self, tensor, pos=None):
  34. """Add position embeddings if given."""
  35. return tensor if pos is None else tensor + pos
  36. def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
  37. q = k = self.with_pos_embed(src, pos)
  38. src2 = self.ma(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
  39. src = src + self.dropout1(src2)
  40. src = self.norm1(src)
  41. src2 = self.fc2(self.dropout(self.act(self.fc1(src))))
  42. src = src + self.dropout2(src2)
  43. src = self.norm2(src)
  44. return src
  45. def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
  46. src2 = self.norm1(src)
  47. q = k = self.with_pos_embed(src2, pos)
  48. src2 = self.ma(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
  49. src = src + self.dropout1(src2)
  50. src2 = self.norm2(src)
  51. src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))
  52. src = src + self.dropout2(src2)
  53. return src
  54. def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
  55. """Forward propagates the input through the encoder module."""
  56. if self.normalize_before:
  57. return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
  58. return self.forward_post(src, src_mask, src_key_padding_mask, pos)
  59. class AIFI(TransformerEncoderLayer):
  60. def __init__(self, c1, cm=2048, num_heads=8, dropout=0, act=nn.GELU(), normalize_before=False):
  61. super().__init__(c1, cm, num_heads, dropout, act, normalize_before)
  62. def forward(self, x):
  63. c, h, w = x.shape[1:]
  64. pos_embed = self.build_2d_sincos_position_embedding(w, h, c)
  65. # flatten [B, C, H, W] to [B, HxW, C]
  66. x = super().forward(x.flatten(2).permute(0, 2, 1), pos=pos_embed.to(device=x.device, dtype=x.dtype))
  67. return x.permute(0, 2, 1).view([-1, c, h, w]).contiguous()
  68. @staticmethod
  69. def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.):
  70. grid_w = torch.arange(int(w), dtype=torch.float32)
  71. grid_h = torch.arange(int(h), dtype=torch.float32)
  72. grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij')
  73. assert embed_dim % 4 == 0, \
  74. 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
  75. pos_dim = embed_dim // 4
  76. omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
  77. omega = 1. / (temperature ** omega)
  78. out_w = grid_w.flatten()[..., None] @ omega[None]
  79. out_h = grid_h.flatten()[..., None] @ omega[None]
  80. return torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], 1)[None]
  81. class TransformerLayer(nn.Module):
  82. """Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)."""
  83. def __init__(self, c, num_heads):
  84. """Initializes a self-attention mechanism using linear transformations and multi-head attention."""
  85. super().__init__()
  86. self.q = nn.Linear(c, c, bias=False)
  87. self.k = nn.Linear(c, c, bias=False)
  88. self.v = nn.Linear(c, c, bias=False)
  89. self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
  90. self.fc1 = nn.Linear(c, c, bias=False)
  91. self.fc2 = nn.Linear(c, c, bias=False)
  92. def forward(self, x):
  93. """Apply a transformer block to the input x and return the output."""
  94. x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
  95. x = self.fc2(self.fc1(x)) + x
  96. return x
  97. class TransformerBlock(nn.Module):
  98. """Vision Transformer https://arxiv.org/abs/2010.11929."""
  99. def __init__(self, c1, c2, num_heads, num_layers):
  100. """Initialize a Transformer module with position embedding and specified number of heads and layers."""
  101. super().__init__()
  102. self.conv = None
  103. if c1 != c2:
  104. self.conv = Conv(c1, c2)
  105. self.linear = nn.Linear(c2, c2) # learnable position embedding
  106. self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
  107. self.c2 = c2
  108. def forward(self, x):
  109. """Forward propagates the input through the bottleneck module."""
  110. if self.conv is not None:
  111. x = self.conv(x)
  112. b, _, w, h = x.shape
  113. p = x.flatten(2).permute(2, 0, 1)
  114. return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
  115. class MLPBlock(nn.Module):
  116. def __init__(self, embedding_dim, mlp_dim, act=nn.GELU):
  117. super().__init__()
  118. self.lin1 = nn.Linear(embedding_dim, mlp_dim)
  119. self.lin2 = nn.Linear(mlp_dim, embedding_dim)
  120. self.act = act()
  121. def forward(self, x: torch.Tensor) -> torch.Tensor:
  122. return self.lin2(self.act(self.lin1(x)))
  123. class MLP(nn.Module):
  124. """ Very simple multi-layer perceptron (also called FFN)"""
  125. def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
  126. super().__init__()
  127. self.num_layers = num_layers
  128. h = [hidden_dim] * (num_layers - 1)
  129. self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
  130. def forward(self, x):
  131. for i, layer in enumerate(self.layers):
  132. x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
  133. return x
  134. class LayerNorm2d(nn.Module):
  135. """
  136. LayerNorm2d module from https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py
  137. https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119
  138. """
  139. def __init__(self, num_channels, eps=1e-6):
  140. super().__init__()
  141. self.weight = nn.Parameter(torch.ones(num_channels))
  142. self.bias = nn.Parameter(torch.zeros(num_channels))
  143. self.eps = eps
  144. def forward(self, x):
  145. u = x.mean(1, keepdim=True)
  146. s = (x - u).pow(2).mean(1, keepdim=True)
  147. x = (x - u) / torch.sqrt(s + self.eps)
  148. x = self.weight[:, None, None] * x + self.bias[:, None, None]
  149. return x
  150. class MSDeformAttn(nn.Module):
  151. """
  152. Original Multi-Scale Deformable Attention Module.
  153. https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
  154. """
  155. def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
  156. super().__init__()
  157. if d_model % n_heads != 0:
  158. raise ValueError(f'd_model must be divisible by n_heads, but got {d_model} and {n_heads}')
  159. _d_per_head = d_model // n_heads
  160. # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
  161. assert _d_per_head * n_heads == d_model, '`d_model` must be divisible by `n_heads`'
  162. self.im2col_step = 64
  163. self.d_model = d_model
  164. self.n_levels = n_levels
  165. self.n_heads = n_heads
  166. self.n_points = n_points
  167. self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
  168. self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
  169. self.value_proj = nn.Linear(d_model, d_model)
  170. self.output_proj = nn.Linear(d_model, d_model)
  171. self._reset_parameters()
  172. def _reset_parameters(self):
  173. constant_(self.sampling_offsets.weight.data, 0.)
  174. thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
  175. grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
  176. grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(
  177. 1, self.n_levels, self.n_points, 1)
  178. for i in range(self.n_points):
  179. grid_init[:, :, i, :] *= i + 1
  180. with torch.no_grad():
  181. self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
  182. constant_(self.attention_weights.weight.data, 0.)
  183. constant_(self.attention_weights.bias.data, 0.)
  184. xavier_uniform_(self.value_proj.weight.data)
  185. constant_(self.value_proj.bias.data, 0.)
  186. xavier_uniform_(self.output_proj.weight.data)
  187. constant_(self.output_proj.bias.data, 0.)
  188. def forward(self, query, refer_bbox, value, value_shapes, value_mask=None):
  189. """
  190. https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
  191. Args:
  192. query (torch.Tensor): [bs, query_length, C]
  193. refer_bbox (torch.Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
  194. bottom-right (1, 1), including padding area
  195. value (torch.Tensor): [bs, value_length, C]
  196. value_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
  197. value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
  198. Returns:
  199. output (Tensor): [bs, Length_{query}, C]
  200. """
  201. bs, len_q = query.shape[:2]
  202. len_v = value.shape[1]
  203. assert sum(s[0] * s[1] for s in value_shapes) == len_v
  204. value = self.value_proj(value)
  205. if value_mask is not None:
  206. value = value.masked_fill(value_mask[..., None], float(0))
  207. value = value.view(bs, len_v, self.n_heads, self.d_model // self.n_heads)
  208. sampling_offsets = self.sampling_offsets(query).view(bs, len_q, self.n_heads, self.n_levels, self.n_points, 2)
  209. attention_weights = self.attention_weights(query).view(bs, len_q, self.n_heads, self.n_levels * self.n_points)
  210. attention_weights = F.softmax(attention_weights, -1).view(bs, len_q, self.n_heads, self.n_levels, self.n_points)
  211. # N, Len_q, n_heads, n_levels, n_points, 2
  212. num_points = refer_bbox.shape[-1]
  213. if num_points == 2:
  214. offset_normalizer = torch.as_tensor(value_shapes, dtype=query.dtype, device=query.device).flip(-1)
  215. add = sampling_offsets / offset_normalizer[None, None, None, :, None, :]
  216. sampling_locations = refer_bbox[:, :, None, :, None, :] + add
  217. elif num_points == 4:
  218. add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5
  219. sampling_locations = refer_bbox[:, :, None, :, None, :2] + add
  220. else:
  221. raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {num_points}.')
  222. output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
  223. output = self.output_proj(output)
  224. return output
  225. class DeformableTransformerDecoderLayer(nn.Module):
  226. """
  227. https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
  228. https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py
  229. """
  230. def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0., act=nn.ReLU(), n_levels=4, n_points=4):
  231. super().__init__()
  232. # self attention
  233. self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
  234. self.dropout1 = nn.Dropout(dropout)
  235. self.norm1 = nn.LayerNorm(d_model)
  236. # cross attention
  237. self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
  238. self.dropout2 = nn.Dropout(dropout)
  239. self.norm2 = nn.LayerNorm(d_model)
  240. # ffn
  241. self.linear1 = nn.Linear(d_model, d_ffn)
  242. self.act = act
  243. self.dropout3 = nn.Dropout(dropout)
  244. self.linear2 = nn.Linear(d_ffn, d_model)
  245. self.dropout4 = nn.Dropout(dropout)
  246. self.norm3 = nn.LayerNorm(d_model)
  247. @staticmethod
  248. def with_pos_embed(tensor, pos):
  249. return tensor if pos is None else tensor + pos
  250. def forward_ffn(self, tgt):
  251. tgt2 = self.linear2(self.dropout3(self.act(self.linear1(tgt))))
  252. tgt = tgt + self.dropout4(tgt2)
  253. tgt = self.norm3(tgt)
  254. return tgt
  255. def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None):
  256. # self attention
  257. q = k = self.with_pos_embed(embed, query_pos)
  258. tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1),
  259. attn_mask=attn_mask)[0].transpose(0, 1)
  260. embed = embed + self.dropout1(tgt)
  261. embed = self.norm1(embed)
  262. # cross attention
  263. tgt = self.cross_attn(self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes,
  264. padding_mask)
  265. embed = embed + self.dropout2(tgt)
  266. embed = self.norm2(embed)
  267. # ffn
  268. embed = self.forward_ffn(embed)
  269. return embed
  270. class DeformableTransformerDecoder(nn.Module):
  271. """
  272. https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
  273. """
  274. def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1):
  275. super().__init__()
  276. self.layers = _get_clones(decoder_layer, num_layers)
  277. self.num_layers = num_layers
  278. self.hidden_dim = hidden_dim
  279. self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
  280. def forward(
  281. self,
  282. embed, # decoder embeddings
  283. refer_bbox, # anchor
  284. feats, # image features
  285. shapes, # feature shapes
  286. bbox_head,
  287. score_head,
  288. pos_mlp,
  289. attn_mask=None,
  290. padding_mask=None):
  291. output = embed
  292. dec_bboxes = []
  293. dec_cls = []
  294. last_refined_bbox = None
  295. refer_bbox = refer_bbox.sigmoid()
  296. for i, layer in enumerate(self.layers):
  297. output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox))
  298. # refine bboxes, (bs, num_queries+num_denoising, 4)
  299. refined_bbox = torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(refer_bbox))
  300. if self.training:
  301. dec_cls.append(score_head[i](output))
  302. if i == 0:
  303. dec_bboxes.append(refined_bbox)
  304. else:
  305. dec_bboxes.append(torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(last_refined_bbox)))
  306. elif i == self.eval_idx:
  307. dec_cls.append(score_head[i](output))
  308. dec_bboxes.append(refined_bbox)
  309. break
  310. last_refined_bbox = refined_bbox
  311. refer_bbox = refined_bbox.detach() if self.training else refined_bbox
  312. return torch.stack(dec_bboxes), torch.stack(dec_cls)