encoders.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. from typing import Any, Optional, Tuple, Type
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from ultralytics.nn.modules import LayerNorm2d, MLPBlock
  8. # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
  9. class ImageEncoderViT(nn.Module):
  10. def __init__(
  11. self,
  12. img_size: int = 1024,
  13. patch_size: int = 16,
  14. in_chans: int = 3,
  15. embed_dim: int = 768,
  16. depth: int = 12,
  17. num_heads: int = 12,
  18. mlp_ratio: float = 4.0,
  19. out_chans: int = 256,
  20. qkv_bias: bool = True,
  21. norm_layer: Type[nn.Module] = nn.LayerNorm,
  22. act_layer: Type[nn.Module] = nn.GELU,
  23. use_abs_pos: bool = True,
  24. use_rel_pos: bool = False,
  25. rel_pos_zero_init: bool = True,
  26. window_size: int = 0,
  27. global_attn_indexes: Tuple[int, ...] = (),
  28. ) -> None:
  29. """
  30. Args:
  31. img_size (int): Input image size.
  32. patch_size (int): Patch size.
  33. in_chans (int): Number of input image channels.
  34. embed_dim (int): Patch embedding dimension.
  35. depth (int): Depth of ViT.
  36. num_heads (int): Number of attention heads in each ViT block.
  37. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  38. qkv_bias (bool): If True, add a learnable bias to query, key, value.
  39. norm_layer (nn.Module): Normalization layer.
  40. act_layer (nn.Module): Activation layer.
  41. use_abs_pos (bool): If True, use absolute positional embeddings.
  42. use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
  43. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
  44. window_size (int): Window size for window attention blocks.
  45. global_attn_indexes (list): Indexes for blocks using global attention.
  46. """
  47. super().__init__()
  48. self.img_size = img_size
  49. self.patch_embed = PatchEmbed(
  50. kernel_size=(patch_size, patch_size),
  51. stride=(patch_size, patch_size),
  52. in_chans=in_chans,
  53. embed_dim=embed_dim,
  54. )
  55. self.pos_embed: Optional[nn.Parameter] = None
  56. if use_abs_pos:
  57. # Initialize absolute positional embedding with pretrain image size.
  58. self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))
  59. self.blocks = nn.ModuleList()
  60. for i in range(depth):
  61. block = Block(
  62. dim=embed_dim,
  63. num_heads=num_heads,
  64. mlp_ratio=mlp_ratio,
  65. qkv_bias=qkv_bias,
  66. norm_layer=norm_layer,
  67. act_layer=act_layer,
  68. use_rel_pos=use_rel_pos,
  69. rel_pos_zero_init=rel_pos_zero_init,
  70. window_size=window_size if i not in global_attn_indexes else 0,
  71. input_size=(img_size // patch_size, img_size // patch_size),
  72. )
  73. self.blocks.append(block)
  74. self.neck = nn.Sequential(
  75. nn.Conv2d(
  76. embed_dim,
  77. out_chans,
  78. kernel_size=1,
  79. bias=False,
  80. ),
  81. LayerNorm2d(out_chans),
  82. nn.Conv2d(
  83. out_chans,
  84. out_chans,
  85. kernel_size=3,
  86. padding=1,
  87. bias=False,
  88. ),
  89. LayerNorm2d(out_chans),
  90. )
  91. def forward(self, x: torch.Tensor) -> torch.Tensor:
  92. x = self.patch_embed(x)
  93. if self.pos_embed is not None:
  94. x = x + self.pos_embed
  95. for blk in self.blocks:
  96. x = blk(x)
  97. return self.neck(x.permute(0, 3, 1, 2))
  98. class PromptEncoder(nn.Module):
  99. def __init__(
  100. self,
  101. embed_dim: int,
  102. image_embedding_size: Tuple[int, int],
  103. input_image_size: Tuple[int, int],
  104. mask_in_chans: int,
  105. activation: Type[nn.Module] = nn.GELU,
  106. ) -> None:
  107. """
  108. Encodes prompts for input to SAM's mask decoder.
  109. Args:
  110. embed_dim (int): The prompts' embedding dimension
  111. image_embedding_size (tuple(int, int)): The spatial size of the
  112. image embedding, as (H, W).
  113. input_image_size (int): The padded size of the image as input
  114. to the image encoder, as (H, W).
  115. mask_in_chans (int): The number of hidden channels used for
  116. encoding input masks.
  117. activation (nn.Module): The activation to use when encoding
  118. input masks.
  119. """
  120. super().__init__()
  121. self.embed_dim = embed_dim
  122. self.input_image_size = input_image_size
  123. self.image_embedding_size = image_embedding_size
  124. self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
  125. self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
  126. point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
  127. self.point_embeddings = nn.ModuleList(point_embeddings)
  128. self.not_a_point_embed = nn.Embedding(1, embed_dim)
  129. self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
  130. self.mask_downscaling = nn.Sequential(
  131. nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
  132. LayerNorm2d(mask_in_chans // 4),
  133. activation(),
  134. nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
  135. LayerNorm2d(mask_in_chans),
  136. activation(),
  137. nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
  138. )
  139. self.no_mask_embed = nn.Embedding(1, embed_dim)
  140. def get_dense_pe(self) -> torch.Tensor:
  141. """
  142. Returns the positional encoding used to encode point prompts,
  143. applied to a dense set of points the shape of the image encoding.
  144. Returns:
  145. torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
  146. """
  147. return self.pe_layer(self.image_embedding_size).unsqueeze(0)
  148. def _embed_points(
  149. self,
  150. points: torch.Tensor,
  151. labels: torch.Tensor,
  152. pad: bool,
  153. ) -> torch.Tensor:
  154. """Embeds point prompts."""
  155. points = points + 0.5 # Shift to center of pixel
  156. if pad:
  157. padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
  158. padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
  159. points = torch.cat([points, padding_point], dim=1)
  160. labels = torch.cat([labels, padding_label], dim=1)
  161. point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
  162. point_embedding[labels == -1] = 0.0
  163. point_embedding[labels == -1] += self.not_a_point_embed.weight
  164. point_embedding[labels == 0] += self.point_embeddings[0].weight
  165. point_embedding[labels == 1] += self.point_embeddings[1].weight
  166. return point_embedding
  167. def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
  168. """Embeds box prompts."""
  169. boxes = boxes + 0.5 # Shift to center of pixel
  170. coords = boxes.reshape(-1, 2, 2)
  171. corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
  172. corner_embedding[:, 0, :] += self.point_embeddings[2].weight
  173. corner_embedding[:, 1, :] += self.point_embeddings[3].weight
  174. return corner_embedding
  175. def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
  176. """Embeds mask inputs."""
  177. return self.mask_downscaling(masks)
  178. def _get_batch_size(
  179. self,
  180. points: Optional[Tuple[torch.Tensor, torch.Tensor]],
  181. boxes: Optional[torch.Tensor],
  182. masks: Optional[torch.Tensor],
  183. ) -> int:
  184. """
  185. Gets the batch size of the output given the batch size of the input prompts.
  186. """
  187. if points is not None:
  188. return points[0].shape[0]
  189. elif boxes is not None:
  190. return boxes.shape[0]
  191. elif masks is not None:
  192. return masks.shape[0]
  193. else:
  194. return 1
  195. def _get_device(self) -> torch.device:
  196. return self.point_embeddings[0].weight.device
  197. def forward(
  198. self,
  199. points: Optional[Tuple[torch.Tensor, torch.Tensor]],
  200. boxes: Optional[torch.Tensor],
  201. masks: Optional[torch.Tensor],
  202. ) -> Tuple[torch.Tensor, torch.Tensor]:
  203. """
  204. Embeds different types of prompts, returning both sparse and dense embeddings.
  205. Args:
  206. points (tuple(torch.Tensor, torch.Tensor), None): point coordinates and labels to embed.
  207. boxes (torch.Tensor, None): boxes to embed
  208. masks (torch.Tensor, None): masks to embed
  209. Returns:
  210. torch.Tensor: sparse embeddings for the points and boxes, with shape BxNx(embed_dim), where N is determined
  211. by the number of input points and boxes.
  212. torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W)
  213. """
  214. bs = self._get_batch_size(points, boxes, masks)
  215. sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
  216. if points is not None:
  217. coords, labels = points
  218. point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
  219. sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
  220. if boxes is not None:
  221. box_embeddings = self._embed_boxes(boxes)
  222. sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
  223. if masks is not None:
  224. dense_embeddings = self._embed_masks(masks)
  225. else:
  226. dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1,
  227. 1).expand(bs, -1, self.image_embedding_size[0],
  228. self.image_embedding_size[1])
  229. return sparse_embeddings, dense_embeddings
  230. class PositionEmbeddingRandom(nn.Module):
  231. """
  232. Positional encoding using random spatial frequencies.
  233. """
  234. def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
  235. super().__init__()
  236. if scale is None or scale <= 0.0:
  237. scale = 1.0
  238. self.register_buffer(
  239. 'positional_encoding_gaussian_matrix',
  240. scale * torch.randn((2, num_pos_feats)),
  241. )
  242. def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
  243. """Positionally encode points that are normalized to [0,1]."""
  244. # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
  245. coords = 2 * coords - 1
  246. coords = coords @ self.positional_encoding_gaussian_matrix
  247. coords = 2 * np.pi * coords
  248. # outputs d_1 x ... x d_n x C shape
  249. return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
  250. def forward(self, size: Tuple[int, int]) -> torch.Tensor:
  251. """Generate positional encoding for a grid of the specified size."""
  252. h, w = size
  253. device: Any = self.positional_encoding_gaussian_matrix.device
  254. grid = torch.ones((h, w), device=device, dtype=torch.float32)
  255. y_embed = grid.cumsum(dim=0) - 0.5
  256. x_embed = grid.cumsum(dim=1) - 0.5
  257. y_embed = y_embed / h
  258. x_embed = x_embed / w
  259. pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
  260. return pe.permute(2, 0, 1) # C x H x W
  261. def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
  262. """Positionally encode points that are not normalized to [0,1]."""
  263. coords = coords_input.clone()
  264. coords[:, :, 0] = coords[:, :, 0] / image_size[1]
  265. coords[:, :, 1] = coords[:, :, 1] / image_size[0]
  266. return self._pe_encoding(coords.to(torch.float)) # B x N x C
  267. class Block(nn.Module):
  268. """Transformer blocks with support of window attention and residual propagation blocks"""
  269. def __init__(
  270. self,
  271. dim: int,
  272. num_heads: int,
  273. mlp_ratio: float = 4.0,
  274. qkv_bias: bool = True,
  275. norm_layer: Type[nn.Module] = nn.LayerNorm,
  276. act_layer: Type[nn.Module] = nn.GELU,
  277. use_rel_pos: bool = False,
  278. rel_pos_zero_init: bool = True,
  279. window_size: int = 0,
  280. input_size: Optional[Tuple[int, int]] = None,
  281. ) -> None:
  282. """
  283. Args:
  284. dim (int): Number of input channels.
  285. num_heads (int): Number of attention heads in each ViT block.
  286. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  287. qkv_bias (bool): If True, add a learnable bias to query, key, value.
  288. norm_layer (nn.Module): Normalization layer.
  289. act_layer (nn.Module): Activation layer.
  290. use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
  291. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
  292. window_size (int): Window size for window attention blocks. If it equals 0, then
  293. use global attention.
  294. input_size (tuple(int, int), None): Input resolution for calculating the relative
  295. positional parameter size.
  296. """
  297. super().__init__()
  298. self.norm1 = norm_layer(dim)
  299. self.attn = Attention(
  300. dim,
  301. num_heads=num_heads,
  302. qkv_bias=qkv_bias,
  303. use_rel_pos=use_rel_pos,
  304. rel_pos_zero_init=rel_pos_zero_init,
  305. input_size=input_size if window_size == 0 else (window_size, window_size),
  306. )
  307. self.norm2 = norm_layer(dim)
  308. self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
  309. self.window_size = window_size
  310. def forward(self, x: torch.Tensor) -> torch.Tensor:
  311. shortcut = x
  312. x = self.norm1(x)
  313. # Window partition
  314. if self.window_size > 0:
  315. H, W = x.shape[1], x.shape[2]
  316. x, pad_hw = window_partition(x, self.window_size)
  317. x = self.attn(x)
  318. # Reverse window partition
  319. if self.window_size > 0:
  320. x = window_unpartition(x, self.window_size, pad_hw, (H, W))
  321. x = shortcut + x
  322. return x + self.mlp(self.norm2(x))
  323. class Attention(nn.Module):
  324. """Multi-head Attention block with relative position embeddings."""
  325. def __init__(
  326. self,
  327. dim: int,
  328. num_heads: int = 8,
  329. qkv_bias: bool = True,
  330. use_rel_pos: bool = False,
  331. rel_pos_zero_init: bool = True,
  332. input_size: Optional[Tuple[int, int]] = None,
  333. ) -> None:
  334. """
  335. Args:
  336. dim (int): Number of input channels.
  337. num_heads (int): Number of attention heads.
  338. qkv_bias (bool): If True, add a learnable bias to query, key, value.
  339. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
  340. input_size (tuple(int, int), None): Input resolution for calculating the relative
  341. positional parameter size.
  342. """
  343. super().__init__()
  344. self.num_heads = num_heads
  345. head_dim = dim // num_heads
  346. self.scale = head_dim ** -0.5
  347. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  348. self.proj = nn.Linear(dim, dim)
  349. self.use_rel_pos = use_rel_pos
  350. if self.use_rel_pos:
  351. assert (input_size is not None), 'Input size must be provided if using relative positional encoding.'
  352. # initialize relative positional embeddings
  353. self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
  354. self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
  355. def forward(self, x: torch.Tensor) -> torch.Tensor:
  356. B, H, W, _ = x.shape
  357. # qkv with shape (3, B, nHead, H * W, C)
  358. qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  359. # q, k, v with shape (B * nHead, H * W, C)
  360. q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
  361. attn = (q * self.scale) @ k.transpose(-2, -1)
  362. if self.use_rel_pos:
  363. attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
  364. attn = attn.softmax(dim=-1)
  365. x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
  366. return self.proj(x)
  367. def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
  368. """
  369. Partition into non-overlapping windows with padding if needed.
  370. Args:
  371. x (tensor): input tokens with [B, H, W, C].
  372. window_size (int): window size.
  373. Returns:
  374. windows: windows after partition with [B * num_windows, window_size, window_size, C].
  375. (Hp, Wp): padded height and width before partition
  376. """
  377. B, H, W, C = x.shape
  378. pad_h = (window_size - H % window_size) % window_size
  379. pad_w = (window_size - W % window_size) % window_size
  380. if pad_h > 0 or pad_w > 0:
  381. x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
  382. Hp, Wp = H + pad_h, W + pad_w
  383. x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
  384. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
  385. return windows, (Hp, Wp)
  386. def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int],
  387. hw: Tuple[int, int]) -> torch.Tensor:
  388. """
  389. Window unpartition into original sequences and removing padding.
  390. Args:
  391. windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
  392. window_size (int): window size.
  393. pad_hw (Tuple): padded height and width (Hp, Wp).
  394. hw (Tuple): original height and width (H, W) before padding.
  395. Returns:
  396. x: unpartitioned sequences with [B, H, W, C].
  397. """
  398. Hp, Wp = pad_hw
  399. H, W = hw
  400. B = windows.shape[0] // (Hp * Wp // window_size // window_size)
  401. x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
  402. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
  403. if Hp > H or Wp > W:
  404. x = x[:, :H, :W, :].contiguous()
  405. return x
  406. def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
  407. """
  408. Get relative positional embeddings according to the relative positions of
  409. query and key sizes.
  410. Args:
  411. q_size (int): size of query q.
  412. k_size (int): size of key k.
  413. rel_pos (Tensor): relative position embeddings (L, C).
  414. Returns:
  415. Extracted positional embeddings according to relative positions.
  416. """
  417. max_rel_dist = int(2 * max(q_size, k_size) - 1)
  418. # Interpolate rel pos if needed.
  419. if rel_pos.shape[0] != max_rel_dist:
  420. # Interpolate rel pos.
  421. rel_pos_resized = F.interpolate(
  422. rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
  423. size=max_rel_dist,
  424. mode='linear',
  425. )
  426. rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
  427. else:
  428. rel_pos_resized = rel_pos
  429. # Scale the coords with short length if shapes for q and k are different.
  430. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
  431. k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
  432. relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
  433. return rel_pos_resized[relative_coords.long()]
  434. def add_decomposed_rel_pos(
  435. attn: torch.Tensor,
  436. q: torch.Tensor,
  437. rel_pos_h: torch.Tensor,
  438. rel_pos_w: torch.Tensor,
  439. q_size: Tuple[int, int],
  440. k_size: Tuple[int, int],
  441. ) -> torch.Tensor:
  442. """
  443. Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
  444. https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
  445. Args:
  446. attn (Tensor): attention map.
  447. q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
  448. rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
  449. rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
  450. q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
  451. k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
  452. Returns:
  453. attn (Tensor): attention map with added relative positional embeddings.
  454. """
  455. q_h, q_w = q_size
  456. k_h, k_w = k_size
  457. Rh = get_rel_pos(q_h, k_h, rel_pos_h)
  458. Rw = get_rel_pos(q_w, k_w, rel_pos_w)
  459. B, _, dim = q.shape
  460. r_q = q.reshape(B, q_h, q_w, dim)
  461. rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh)
  462. rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw)
  463. attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
  464. B, q_h * q_w, k_h * k_w)
  465. return attn
  466. class PatchEmbed(nn.Module):
  467. """
  468. Image to Patch Embedding.
  469. """
  470. def __init__(
  471. self,
  472. kernel_size: Tuple[int, int] = (16, 16),
  473. stride: Tuple[int, int] = (16, 16),
  474. padding: Tuple[int, int] = (0, 0),
  475. in_chans: int = 3,
  476. embed_dim: int = 768,
  477. ) -> None:
  478. """
  479. Args:
  480. kernel_size (Tuple): kernel size of the projection layer.
  481. stride (Tuple): stride of the projection layer.
  482. padding (Tuple): padding size of the projection layer.
  483. in_chans (int): Number of input image channels.
  484. embed_dim (int): Patch embedding dimension.
  485. """
  486. super().__init__()
  487. self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
  488. def forward(self, x: torch.Tensor) -> torch.Tensor:
  489. return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C