swin_transformer.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033
  1. import math
  2. from functools import partial
  3. from typing import Any, Callable, List, Optional
  4. import torch
  5. import torch.nn.functional as F
  6. from torch import nn, Tensor
  7. from ..ops.misc import MLP, Permute
  8. from ..ops.stochastic_depth import StochasticDepth
  9. from ..transforms._presets import ImageClassification, InterpolationMode
  10. from ..utils import _log_api_usage_once
  11. from ._api import register_model, Weights, WeightsEnum
  12. from ._meta import _IMAGENET_CATEGORIES
  13. from ._utils import _ovewrite_named_param, handle_legacy_interface
  14. __all__ = [
  15. "SwinTransformer",
  16. "Swin_T_Weights",
  17. "Swin_S_Weights",
  18. "Swin_B_Weights",
  19. "Swin_V2_T_Weights",
  20. "Swin_V2_S_Weights",
  21. "Swin_V2_B_Weights",
  22. "swin_t",
  23. "swin_s",
  24. "swin_b",
  25. "swin_v2_t",
  26. "swin_v2_s",
  27. "swin_v2_b",
  28. ]
  29. def _patch_merging_pad(x: torch.Tensor) -> torch.Tensor:
  30. H, W, _ = x.shape[-3:]
  31. x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
  32. x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
  33. x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
  34. x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C
  35. x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C
  36. x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C
  37. return x
  38. torch.fx.wrap("_patch_merging_pad")
  39. def _get_relative_position_bias(
  40. relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int]
  41. ) -> torch.Tensor:
  42. N = window_size[0] * window_size[1]
  43. relative_position_bias = relative_position_bias_table[relative_position_index] # type: ignore[index]
  44. relative_position_bias = relative_position_bias.view(N, N, -1)
  45. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
  46. return relative_position_bias
  47. torch.fx.wrap("_get_relative_position_bias")
  48. class PatchMerging(nn.Module):
  49. """Patch Merging Layer.
  50. Args:
  51. dim (int): Number of input channels.
  52. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
  53. """
  54. def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm):
  55. super().__init__()
  56. _log_api_usage_once(self)
  57. self.dim = dim
  58. self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  59. self.norm = norm_layer(4 * dim)
  60. def forward(self, x: Tensor):
  61. """
  62. Args:
  63. x (Tensor): input tensor with expected layout of [..., H, W, C]
  64. Returns:
  65. Tensor with layout of [..., H/2, W/2, 2*C]
  66. """
  67. x = _patch_merging_pad(x)
  68. x = self.norm(x)
  69. x = self.reduction(x) # ... H/2 W/2 2*C
  70. return x
  71. class PatchMergingV2(nn.Module):
  72. """Patch Merging Layer for Swin Transformer V2.
  73. Args:
  74. dim (int): Number of input channels.
  75. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
  76. """
  77. def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm):
  78. super().__init__()
  79. _log_api_usage_once(self)
  80. self.dim = dim
  81. self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  82. self.norm = norm_layer(2 * dim) # difference
  83. def forward(self, x: Tensor):
  84. """
  85. Args:
  86. x (Tensor): input tensor with expected layout of [..., H, W, C]
  87. Returns:
  88. Tensor with layout of [..., H/2, W/2, 2*C]
  89. """
  90. x = _patch_merging_pad(x)
  91. x = self.reduction(x) # ... H/2 W/2 2*C
  92. x = self.norm(x)
  93. return x
def shifted_window_attention(
    input: Tensor,
    qkv_weight: Tensor,
    proj_weight: Tensor,
    relative_position_bias: Tensor,
    window_size: List[int],
    num_heads: int,
    shift_size: List[int],
    attention_dropout: float = 0.0,
    dropout: float = 0.0,
    qkv_bias: Optional[Tensor] = None,
    proj_bias: Optional[Tensor] = None,
    logit_scale: Optional[torch.Tensor] = None,
    training: bool = True,
) -> Tensor:
    """
    Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Pipeline: zero-pad H/W (bottom/right) up to multiples of the window size, optionally
    cyclically shift, split into non-overlapping windows, attend within each window,
    project, then undo the windowing and the shift and crop back to the original H x W.

    Args:
        input (Tensor[N, H, W, C]): The input tensor of 4 dimensions.
        qkv_weight (Tensor[3*C, C]): The weight tensor of query, key, value, in the
            ``F.linear`` layout (out_features, in_features).
        proj_weight (Tensor[C, C]): The weight tensor of projection.
        relative_position_bias (Tensor): The learned relative position bias added to attention.
            Must broadcast against the attention logits of shape
            [N * num_windows, num_heads, Wh*Ww, Wh*Ww].
        window_size (List[int]): Window size as [Wh, Ww].
        num_heads (int): Number of attention heads. Each head uses C // num_heads channels,
            so C is expected to be divisible by num_heads.
        shift_size (List[int]): Shift size for shifted window attention, as [shift_h, shift_w].
            The caller's list is not modified; a shift is treated as 0 along any dimension
            where the window already covers the whole (padded) feature map.
        attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
        dropout (float): Dropout ratio of output. Default: 0.0.
        qkv_bias (Tensor[3*C], optional): The bias tensor of query, key, value. When
            ``logit_scale`` is given, the key third of a *copy* of this bias is zeroed. Default: None.
        proj_bias (Tensor[C], optional): The bias tensor of projection. Default: None.
        logit_scale (Tensor, optional): Logit scale of cosine attention for Swin Transformer V2.
            When given, cosine attention replaces scaled dot-product attention. It is clamped to
            at most log(100), exponentiated, and multiplied into the logits, so it must broadcast
            against [N * num_windows, num_heads, Wh*Ww, Wh*Ww] (e.g. shape [num_heads, 1, 1]).
            Default: None.
        training (bool, optional): Training flag used by the dropout parameters. Default: True.
    Returns:
        Tensor[N, H, W, C]: The output tensor after shifted window attention.
    """
    B, H, W, C = input.shape
    # pad feature maps to multiples of window size (zeros added on the right / bottom only)
    pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
    pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
    x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
    _, pad_H, pad_W, _ = x.shape

    # copy so the adjustments below never mutate the caller's list
    shift_size = shift_size.copy()
    # If window size is larger than feature size, there is no need to shift window
    if window_size[0] >= pad_H:
        shift_size[0] = 0
    if window_size[1] >= pad_W:
        shift_size[1] = 0

    # cyclic shift
    if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))

    # partition windows
    num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
    x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C)  # B*nW, Ws*Ws, C

    # multi-head attention
    if logit_scale is not None and qkv_bias is not None:
        # cosine-attention (V2) path: the key bias is forced to zero, on a clone so the
        # caller's bias tensor is left untouched
        qkv_bias = qkv_bias.clone()
        length = qkv_bias.numel() // 3
        qkv_bias[length : 2 * length].zero_()
    qkv = F.linear(x, qkv_weight, qkv_bias)
    # -> 3, B*nW, num_heads, Ws*Ws, C // num_heads
    qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    if logit_scale is not None:
        # cosine attention
        attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
        # after exponentiation the effective scale is capped at 100
        logit_scale = torch.clamp(logit_scale, max=math.log(100.0)).exp()
        attn = attn * logit_scale
    else:
        # scaled dot-product attention
        q = q * (C // num_heads) ** -0.5
        attn = q.matmul(k.transpose(-2, -1))
    # add relative position bias
    attn = attn + relative_position_bias

    if sum(shift_size) > 0:
        # generate attention mask
        # The cyclic shift wraps border regions around, so one window can mix tokens from up
        # to 3x3 distinct regions. Each region gets its own label; token pairs whose labels
        # differ receive -100 on their logits so softmax effectively ignores them.
        attn_mask = x.new_zeros((pad_H, pad_W))
        h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None))
        w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None))
        count = 0
        for h in h_slices:
            for w in w_slices:
                attn_mask[h[0] : h[1], w[0] : w[1]] = count
                count += 1
        attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1])
        attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1])
        attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        # split the batch*window axis so the per-window mask broadcasts over batch and heads
        attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
        attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, num_heads, x.size(1), x.size(1))

    attn = F.softmax(attn, dim=-1)
    attn = F.dropout(attn, p=attention_dropout, training=training)

    # merge heads back into C channels, then apply the output projection
    x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
    x = F.linear(x, proj_weight, proj_bias)
    x = F.dropout(x, p=dropout, training=training)

    # reverse windows
    x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)

    # reverse cyclic shift
    if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))

    # unpad features
    x = x[:, :H, :W, :].contiguous()
    return x


# Register this module-level function by name with torch.fx.wrap (torch.fx tracing hook).
torch.fx.wrap("shifted_window_attention")
  198. class ShiftedWindowAttention(nn.Module):
  199. """
  200. See :func:`shifted_window_attention`.
  201. """
  202. def __init__(
  203. self,
  204. dim: int,
  205. window_size: List[int],
  206. shift_size: List[int],
  207. num_heads: int,
  208. qkv_bias: bool = True,
  209. proj_bias: bool = True,
  210. attention_dropout: float = 0.0,
  211. dropout: float = 0.0,
  212. ):
  213. super().__init__()
  214. if len(window_size) != 2 or len(shift_size) != 2:
  215. raise ValueError("window_size and shift_size must be of length 2")
  216. self.window_size = window_size
  217. self.shift_size = shift_size
  218. self.num_heads = num_heads
  219. self.attention_dropout = attention_dropout
  220. self.dropout = dropout
  221. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  222. self.proj = nn.Linear(dim, dim, bias=proj_bias)
  223. self.define_relative_position_bias_table()
  224. self.define_relative_position_index()
  225. def define_relative_position_bias_table(self):
  226. # define a parameter table of relative position bias
  227. self.relative_position_bias_table = nn.Parameter(
  228. torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads)
  229. ) # 2*Wh-1 * 2*Ww-1, nH
  230. nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
  231. def define_relative_position_index(self):
  232. # get pair-wise relative position index for each token inside the window
  233. coords_h = torch.arange(self.window_size[0])
  234. coords_w = torch.arange(self.window_size[1])
  235. coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww
  236. coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
  237. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
  238. relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
  239. relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
  240. relative_coords[:, :, 1] += self.window_size[1] - 1
  241. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  242. relative_position_index = relative_coords.sum(-1).flatten() # Wh*Ww*Wh*Ww
  243. self.register_buffer("relative_position_index", relative_position_index)
  244. def get_relative_position_bias(self) -> torch.Tensor:
  245. return _get_relative_position_bias(
  246. self.relative_position_bias_table, self.relative_position_index, self.window_size # type: ignore[arg-type]
  247. )
  248. def forward(self, x: Tensor) -> Tensor:
  249. """
  250. Args:
  251. x (Tensor): Tensor with layout of [B, H, W, C]
  252. Returns:
  253. Tensor with same layout as input, i.e. [B, H, W, C]
  254. """
  255. relative_position_bias = self.get_relative_position_bias()
  256. return shifted_window_attention(
  257. x,
  258. self.qkv.weight,
  259. self.proj.weight,
  260. relative_position_bias,
  261. self.window_size,
  262. self.num_heads,
  263. shift_size=self.shift_size,
  264. attention_dropout=self.attention_dropout,
  265. dropout=self.dropout,
  266. qkv_bias=self.qkv.bias,
  267. proj_bias=self.proj.bias,
  268. training=self.training,
  269. )
  270. class ShiftedWindowAttentionV2(ShiftedWindowAttention):
  271. """
  272. See :func:`shifted_window_attention_v2`.
  273. """
  274. def __init__(
  275. self,
  276. dim: int,
  277. window_size: List[int],
  278. shift_size: List[int],
  279. num_heads: int,
  280. qkv_bias: bool = True,
  281. proj_bias: bool = True,
  282. attention_dropout: float = 0.0,
  283. dropout: float = 0.0,
  284. ):
  285. super().__init__(
  286. dim,
  287. window_size,
  288. shift_size,
  289. num_heads,
  290. qkv_bias=qkv_bias,
  291. proj_bias=proj_bias,
  292. attention_dropout=attention_dropout,
  293. dropout=dropout,
  294. )
  295. self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
  296. # mlp to generate continuous relative position bias
  297. self.cpb_mlp = nn.Sequential(
  298. nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
  299. )
  300. if qkv_bias:
  301. length = self.qkv.bias.numel() // 3
  302. self.qkv.bias[length : 2 * length].data.zero_()
  303. def define_relative_position_bias_table(self):
  304. # get relative_coords_table
  305. relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
  306. relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
  307. relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
  308. relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
  309. relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
  310. relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
  311. relative_coords_table *= 8 # normalize to -8, 8
  312. relative_coords_table = (
  313. torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / 3.0
  314. )
  315. self.register_buffer("relative_coords_table", relative_coords_table)
  316. def get_relative_position_bias(self) -> torch.Tensor:
  317. relative_position_bias = _get_relative_position_bias(
  318. self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads),
  319. self.relative_position_index, # type: ignore[arg-type]
  320. self.window_size,
  321. )
  322. relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
  323. return relative_position_bias
  324. def forward(self, x: Tensor):
  325. """
  326. Args:
  327. x (Tensor): Tensor with layout of [B, H, W, C]
  328. Returns:
  329. Tensor with same layout as input, i.e. [B, H, W, C]
  330. """
  331. relative_position_bias = self.get_relative_position_bias()
  332. return shifted_window_attention(
  333. x,
  334. self.qkv.weight,
  335. self.proj.weight,
  336. relative_position_bias,
  337. self.window_size,
  338. self.num_heads,
  339. shift_size=self.shift_size,
  340. attention_dropout=self.attention_dropout,
  341. dropout=self.dropout,
  342. qkv_bias=self.qkv.bias,
  343. proj_bias=self.proj.bias,
  344. logit_scale=self.logit_scale,
  345. training=self.training,
  346. )
  347. class SwinTransformerBlock(nn.Module):
  348. """
  349. Swin Transformer Block.
  350. Args:
  351. dim (int): Number of input channels.
  352. num_heads (int): Number of attention heads.
  353. window_size (List[int]): Window size.
  354. shift_size (List[int]): Shift size for shifted window attention.
  355. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
  356. dropout (float): Dropout rate. Default: 0.0.
  357. attention_dropout (float): Attention dropout rate. Default: 0.0.
  358. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
  359. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
  360. attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention
  361. """
  362. def __init__(
  363. self,
  364. dim: int,
  365. num_heads: int,
  366. window_size: List[int],
  367. shift_size: List[int],
  368. mlp_ratio: float = 4.0,
  369. dropout: float = 0.0,
  370. attention_dropout: float = 0.0,
  371. stochastic_depth_prob: float = 0.0,
  372. norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
  373. attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention,
  374. ):
  375. super().__init__()
  376. _log_api_usage_once(self)
  377. self.norm1 = norm_layer(dim)
  378. self.attn = attn_layer(
  379. dim,
  380. window_size,
  381. shift_size,
  382. num_heads,
  383. attention_dropout=attention_dropout,
  384. dropout=dropout,
  385. )
  386. self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
  387. self.norm2 = norm_layer(dim)
  388. self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)
  389. for m in self.mlp.modules():
  390. if isinstance(m, nn.Linear):
  391. nn.init.xavier_uniform_(m.weight)
  392. if m.bias is not None:
  393. nn.init.normal_(m.bias, std=1e-6)
  394. def forward(self, x: Tensor):
  395. x = x + self.stochastic_depth(self.attn(self.norm1(x)))
  396. x = x + self.stochastic_depth(self.mlp(self.norm2(x)))
  397. return x
  398. class SwinTransformerBlockV2(SwinTransformerBlock):
  399. """
  400. Swin Transformer V2 Block.
  401. Args:
  402. dim (int): Number of input channels.
  403. num_heads (int): Number of attention heads.
  404. window_size (List[int]): Window size.
  405. shift_size (List[int]): Shift size for shifted window attention.
  406. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
  407. dropout (float): Dropout rate. Default: 0.0.
  408. attention_dropout (float): Attention dropout rate. Default: 0.0.
  409. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
  410. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
  411. attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttentionV2.
  412. """
  413. def __init__(
  414. self,
  415. dim: int,
  416. num_heads: int,
  417. window_size: List[int],
  418. shift_size: List[int],
  419. mlp_ratio: float = 4.0,
  420. dropout: float = 0.0,
  421. attention_dropout: float = 0.0,
  422. stochastic_depth_prob: float = 0.0,
  423. norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
  424. attn_layer: Callable[..., nn.Module] = ShiftedWindowAttentionV2,
  425. ):
  426. super().__init__(
  427. dim,
  428. num_heads,
  429. window_size,
  430. shift_size,
  431. mlp_ratio=mlp_ratio,
  432. dropout=dropout,
  433. attention_dropout=attention_dropout,
  434. stochastic_depth_prob=stochastic_depth_prob,
  435. norm_layer=norm_layer,
  436. attn_layer=attn_layer,
  437. )
  438. def forward(self, x: Tensor):
  439. # Here is the difference, we apply norm after the attention in V2.
  440. # In V1 we applied norm before the attention.
  441. x = x + self.stochastic_depth(self.norm1(self.attn(x)))
  442. x = x + self.stochastic_depth(self.norm2(self.mlp(x)))
  443. return x
  444. class SwinTransformer(nn.Module):
  445. """
  446. Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
  447. Shifted Windows" <https://arxiv.org/abs/2103.14030>`_ paper.
  448. Args:
  449. patch_size (List[int]): Patch size.
  450. embed_dim (int): Patch embedding dimension.
  451. depths (List(int)): Depth of each Swin Transformer layer.
  452. num_heads (List(int)): Number of attention heads in different layers.
  453. window_size (List[int]): Window size.
  454. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
  455. dropout (float): Dropout rate. Default: 0.0.
  456. attention_dropout (float): Attention dropout rate. Default: 0.0.
  457. stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1.
  458. num_classes (int): Number of classes for classification head. Default: 1000.
  459. block (nn.Module, optional): SwinTransformer Block. Default: None.
  460. norm_layer (nn.Module, optional): Normalization layer. Default: None.
  461. downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging.
  462. """
  463. def __init__(
  464. self,
  465. patch_size: List[int],
  466. embed_dim: int,
  467. depths: List[int],
  468. num_heads: List[int],
  469. window_size: List[int],
  470. mlp_ratio: float = 4.0,
  471. dropout: float = 0.0,
  472. attention_dropout: float = 0.0,
  473. stochastic_depth_prob: float = 0.1,
  474. num_classes: int = 1000,
  475. norm_layer: Optional[Callable[..., nn.Module]] = None,
  476. block: Optional[Callable[..., nn.Module]] = None,
  477. downsample_layer: Callable[..., nn.Module] = PatchMerging,
  478. ):
  479. super().__init__()
  480. _log_api_usage_once(self)
  481. self.num_classes = num_classes
  482. if block is None:
  483. block = SwinTransformerBlock
  484. if norm_layer is None:
  485. norm_layer = partial(nn.LayerNorm, eps=1e-5)
  486. layers: List[nn.Module] = []
  487. # split image into non-overlapping patches
  488. layers.append(
  489. nn.Sequential(
  490. nn.Conv2d(
  491. 3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1])
  492. ),
  493. Permute([0, 2, 3, 1]),
  494. norm_layer(embed_dim),
  495. )
  496. )
  497. total_stage_blocks = sum(depths)
  498. stage_block_id = 0
  499. # build SwinTransformer blocks
  500. for i_stage in range(len(depths)):
  501. stage: List[nn.Module] = []
  502. dim = embed_dim * 2**i_stage
  503. for i_layer in range(depths[i_stage]):
  504. # adjust stochastic depth probability based on the depth of the stage block
  505. sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1)
  506. stage.append(
  507. block(
  508. dim,
  509. num_heads[i_stage],
  510. window_size=window_size,
  511. shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size],
  512. mlp_ratio=mlp_ratio,
  513. dropout=dropout,
  514. attention_dropout=attention_dropout,
  515. stochastic_depth_prob=sd_prob,
  516. norm_layer=norm_layer,
  517. )
  518. )
  519. stage_block_id += 1
  520. layers.append(nn.Sequential(*stage))
  521. # add patch merging layer
  522. if i_stage < (len(depths) - 1):
  523. layers.append(downsample_layer(dim, norm_layer))
  524. self.features = nn.Sequential(*layers)
  525. num_features = embed_dim * 2 ** (len(depths) - 1)
  526. self.norm = norm_layer(num_features)
  527. self.permute = Permute([0, 3, 1, 2]) # B H W C -> B C H W
  528. self.avgpool = nn.AdaptiveAvgPool2d(1)
  529. self.flatten = nn.Flatten(1)
  530. self.head = nn.Linear(num_features, num_classes)
  531. for m in self.modules():
  532. if isinstance(m, nn.Linear):
  533. nn.init.trunc_normal_(m.weight, std=0.02)
  534. if m.bias is not None:
  535. nn.init.zeros_(m.bias)
  536. def forward(self, x):
  537. x = self.features(x)
  538. x = self.norm(x)
  539. x = self.permute(x)
  540. x = self.avgpool(x)
  541. x = self.flatten(x)
  542. x = self.head(x)
  543. return x
  544. def _swin_transformer(
  545. patch_size: List[int],
  546. embed_dim: int,
  547. depths: List[int],
  548. num_heads: List[int],
  549. window_size: List[int],
  550. stochastic_depth_prob: float,
  551. weights: Optional[WeightsEnum],
  552. progress: bool,
  553. **kwargs: Any,
  554. ) -> SwinTransformer:
  555. if weights is not None:
  556. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  557. model = SwinTransformer(
  558. patch_size=patch_size,
  559. embed_dim=embed_dim,
  560. depths=depths,
  561. num_heads=num_heads,
  562. window_size=window_size,
  563. stochastic_depth_prob=stochastic_depth_prob,
  564. **kwargs,
  565. )
  566. if weights is not None:
  567. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  568. return model
  569. _COMMON_META = {
  570. "categories": _IMAGENET_CATEGORIES,
  571. }
class Swin_T_Weights(WeightsEnum):
    # Pretrained checkpoints for the swin_t builder.
    IMAGENET1K_V1 = Weights(
        # NOTE(review): the suffix after the model name looks like a content hash; keep it in sync
        # with the published file (_swin_transformer loads with check_hash=True).
        url="https://download.pytorch.org/models/swin_t-704ceda3.pth",
        # evaluation preprocessing: crop/resize sizes and interpolation passed to ImageClassification
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,  # shared "categories" entry
            "num_params": 28288354,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.474,
                    "acc@5": 95.776,
                }
            },
            # NOTE(review): units of "_ops" and "_file_size" are not stated here (presumably GFLOPs / MB).
            "_ops": 4.491,
            "_file_size": 108.19,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    # DEFAULT names the same Weights object as IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
class Swin_S_Weights(WeightsEnum):
    # Pretrained checkpoints for the swin_s builder.
    IMAGENET1K_V1 = Weights(
        # NOTE(review): URL suffix appears to be a content hash checked on download (check_hash=True).
        url="https://download.pytorch.org/models/swin_s-5e29d889.pth",
        # evaluation preprocessing: crop/resize sizes and interpolation passed to ImageClassification
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,  # shared "categories" entry
            "num_params": 49606258,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.196,
                    "acc@5": 96.360,
                }
            },
            # NOTE(review): units of "_ops" and "_file_size" are not stated here (presumably GFLOPs / MB).
            "_ops": 8.741,
            "_file_size": 189.786,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    # DEFAULT names the same Weights object as IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
class Swin_B_Weights(WeightsEnum):
    # Pretrained checkpoints, presumably for the swin_b builder (listed in __all__).
    IMAGENET1K_V1 = Weights(
        # NOTE(review): URL suffix appears to be a content hash checked on download (check_hash=True).
        url="https://download.pytorch.org/models/swin_b-68c6b09e.pth",
        # evaluation preprocessing: crop/resize sizes and interpolation passed to ImageClassification
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,  # shared "categories" entry
            "num_params": 87768224,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.582,
                    "acc@5": 96.640,
                }
            },
            # NOTE(review): units of "_ops" and "_file_size" are not stated here (presumably GFLOPs / MB).
            "_ops": 15.431,
            "_file_size": 335.364,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    # DEFAULT names the same Weights object as IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
class Swin_V2_T_Weights(WeightsEnum):
    # Pretrained checkpoints, presumably for the swin_v2_t builder (listed in __all__).
    IMAGENET1K_V1 = Weights(
        # NOTE(review): URL suffix appears to be a content hash checked on download (check_hash=True).
        url="https://download.pytorch.org/models/swin_v2_t-b137f0e2.pth",
        # evaluation preprocessing: V2 checkpoints use a 256 crop (vs 224 for V1)
        transforms=partial(
            ImageClassification, crop_size=256, resize_size=260, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,  # shared "categories" entry
            "num_params": 28351570,
            "min_size": (256, 256),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.072,
                    "acc@5": 96.132,
                }
            },
            # NOTE(review): units of "_ops" and "_file_size" are not stated here (presumably GFLOPs / MB).
            "_ops": 5.94,
            "_file_size": 108.626,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    # DEFAULT names the same Weights object as IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
class Swin_V2_S_Weights(WeightsEnum):
    # Pretrained checkpoints, presumably for the swin_v2_s builder (listed in __all__).
    IMAGENET1K_V1 = Weights(
        # NOTE(review): URL suffix appears to be a content hash checked on download (check_hash=True).
        url="https://download.pytorch.org/models/swin_v2_s-637d8ceb.pth",
        # evaluation preprocessing: crop/resize sizes and interpolation passed to ImageClassification
        transforms=partial(
            ImageClassification, crop_size=256, resize_size=260, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,  # shared "categories" entry
            "num_params": 49737442,
            "min_size": (256, 256),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.712,
                    "acc@5": 96.816,
                }
            },
            # NOTE(review): units of "_ops" and "_file_size" are not stated here (presumably GFLOPs / MB).
            "_ops": 11.546,
            "_file_size": 190.675,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    # DEFAULT names the same Weights object as IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
class Swin_V2_B_Weights(WeightsEnum):
    # Pretrained checkpoints, presumably for the swin_v2_b builder (listed in __all__).
    IMAGENET1K_V1 = Weights(
        # NOTE(review): URL suffix appears to be a content hash checked on download (check_hash=True).
        url="https://download.pytorch.org/models/swin_v2_b-781e5279.pth",
        # evaluation preprocessing: crop/resize sizes and interpolation passed to ImageClassification
        transforms=partial(
            ImageClassification, crop_size=256, resize_size=272, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,  # shared "categories" entry
            "num_params": 87930848,
            "min_size": (256, 256),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.112,
                    "acc@5": 96.864,
                }
            },
            # NOTE(review): units of "_ops" and "_file_size" are not stated here (presumably GFLOPs / MB).
            "_ops": 20.325,
            "_file_size": 336.372,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    # DEFAULT names the same Weights object as IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
  710. @register_model()
  711. @handle_legacy_interface(weights=("pretrained", Swin_T_Weights.IMAGENET1K_V1))
  712. def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
  713. """
  714. Constructs a swin_tiny architecture from
  715. `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`_.
  716. Args:
  717. weights (:class:`~torchvision.models.Swin_T_Weights`, optional): The
  718. pretrained weights to use. See
  719. :class:`~torchvision.models.Swin_T_Weights` below for
  720. more details, and possible values. By default, no pre-trained
  721. weights are used.
  722. progress (bool, optional): If True, displays a progress bar of the
  723. download to stderr. Default is True.
  724. **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
  725. base class. Please refer to the `source code
  726. <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
  727. for more details about this class.
  728. .. autoclass:: torchvision.models.Swin_T_Weights
  729. :members:
  730. """
  731. weights = Swin_T_Weights.verify(weights)
  732. return _swin_transformer(
  733. patch_size=[4, 4],
  734. embed_dim=96,
  735. depths=[2, 2, 6, 2],
  736. num_heads=[3, 6, 12, 24],
  737. window_size=[7, 7],
  738. stochastic_depth_prob=0.2,
  739. weights=weights,
  740. progress=progress,
  741. **kwargs,
  742. )
  743. @register_model()
  744. @handle_legacy_interface(weights=("pretrained", Swin_S_Weights.IMAGENET1K_V1))
  745. def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
  746. """
  747. Constructs a swin_small architecture from
  748. `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`_.
  749. Args:
  750. weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The
  751. pretrained weights to use. See
  752. :class:`~torchvision.models.Swin_S_Weights` below for
  753. more details, and possible values. By default, no pre-trained
  754. weights are used.
  755. progress (bool, optional): If True, displays a progress bar of the
  756. download to stderr. Default is True.
  757. **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
  758. base class. Please refer to the `source code
  759. <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
  760. for more details about this class.
  761. .. autoclass:: torchvision.models.Swin_S_Weights
  762. :members:
  763. """
  764. weights = Swin_S_Weights.verify(weights)
  765. return _swin_transformer(
  766. patch_size=[4, 4],
  767. embed_dim=96,
  768. depths=[2, 2, 18, 2],
  769. num_heads=[3, 6, 12, 24],
  770. window_size=[7, 7],
  771. stochastic_depth_prob=0.3,
  772. weights=weights,
  773. progress=progress,
  774. **kwargs,
  775. )
  776. @register_model()
  777. @handle_legacy_interface(weights=("pretrained", Swin_B_Weights.IMAGENET1K_V1))
  778. def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
  779. """
  780. Constructs a swin_base architecture from
  781. `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`_.
  782. Args:
  783. weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The
  784. pretrained weights to use. See
  785. :class:`~torchvision.models.Swin_B_Weights` below for
  786. more details, and possible values. By default, no pre-trained
  787. weights are used.
  788. progress (bool, optional): If True, displays a progress bar of the
  789. download to stderr. Default is True.
  790. **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
  791. base class. Please refer to the `source code
  792. <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
  793. for more details about this class.
  794. .. autoclass:: torchvision.models.Swin_B_Weights
  795. :members:
  796. """
  797. weights = Swin_B_Weights.verify(weights)
  798. return _swin_transformer(
  799. patch_size=[4, 4],
  800. embed_dim=128,
  801. depths=[2, 2, 18, 2],
  802. num_heads=[4, 8, 16, 32],
  803. window_size=[7, 7],
  804. stochastic_depth_prob=0.5,
  805. weights=weights,
  806. progress=progress,
  807. **kwargs,
  808. )
  809. @register_model()
  810. @handle_legacy_interface(weights=("pretrained", Swin_V2_T_Weights.IMAGENET1K_V1))
  811. def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
  812. """
  813. Constructs a swin_v2_tiny architecture from
  814. `Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/abs/2111.09883>`_.
  815. Args:
  816. weights (:class:`~torchvision.models.Swin_V2_T_Weights`, optional): The
  817. pretrained weights to use. See
  818. :class:`~torchvision.models.Swin_V2_T_Weights` below for
  819. more details, and possible values. By default, no pre-trained
  820. weights are used.
  821. progress (bool, optional): If True, displays a progress bar of the
  822. download to stderr. Default is True.
  823. **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
  824. base class. Please refer to the `source code
  825. <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
  826. for more details about this class.
  827. .. autoclass:: torchvision.models.Swin_V2_T_Weights
  828. :members:
  829. """
  830. weights = Swin_V2_T_Weights.verify(weights)
  831. return _swin_transformer(
  832. patch_size=[4, 4],
  833. embed_dim=96,
  834. depths=[2, 2, 6, 2],
  835. num_heads=[3, 6, 12, 24],
  836. window_size=[8, 8],
  837. stochastic_depth_prob=0.2,
  838. weights=weights,
  839. progress=progress,
  840. block=SwinTransformerBlockV2,
  841. downsample_layer=PatchMergingV2,
  842. **kwargs,
  843. )
  844. @register_model()
  845. @handle_legacy_interface(weights=("pretrained", Swin_V2_S_Weights.IMAGENET1K_V1))
  846. def swin_v2_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
  847. """
  848. Constructs a swin_v2_small architecture from
  849. `Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/abs/2111.09883>`_.
  850. Args:
  851. weights (:class:`~torchvision.models.Swin_V2_S_Weights`, optional): The
  852. pretrained weights to use. See
  853. :class:`~torchvision.models.Swin_V2_S_Weights` below for
  854. more details, and possible values. By default, no pre-trained
  855. weights are used.
  856. progress (bool, optional): If True, displays a progress bar of the
  857. download to stderr. Default is True.
  858. **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
  859. base class. Please refer to the `source code
  860. <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
  861. for more details about this class.
  862. .. autoclass:: torchvision.models.Swin_V2_S_Weights
  863. :members:
  864. """
  865. weights = Swin_V2_S_Weights.verify(weights)
  866. return _swin_transformer(
  867. patch_size=[4, 4],
  868. embed_dim=96,
  869. depths=[2, 2, 18, 2],
  870. num_heads=[3, 6, 12, 24],
  871. window_size=[8, 8],
  872. stochastic_depth_prob=0.3,
  873. weights=weights,
  874. progress=progress,
  875. block=SwinTransformerBlockV2,
  876. downsample_layer=PatchMergingV2,
  877. **kwargs,
  878. )
  879. @register_model()
  880. @handle_legacy_interface(weights=("pretrained", Swin_V2_B_Weights.IMAGENET1K_V1))
  881. def swin_v2_b(*, weights: Optional[Swin_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
  882. """
  883. Constructs a swin_v2_base architecture from
  884. `Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/abs/2111.09883>`_.
  885. Args:
  886. weights (:class:`~torchvision.models.Swin_V2_B_Weights`, optional): The
  887. pretrained weights to use. See
  888. :class:`~torchvision.models.Swin_V2_B_Weights` below for
  889. more details, and possible values. By default, no pre-trained
  890. weights are used.
  891. progress (bool, optional): If True, displays a progress bar of the
  892. download to stderr. Default is True.
  893. **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
  894. base class. Please refer to the `source code
  895. <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
  896. for more details about this class.
  897. .. autoclass:: torchvision.models.Swin_V2_B_Weights
  898. :members:
  899. """
  900. weights = Swin_V2_B_Weights.verify(weights)
  901. return _swin_transformer(
  902. patch_size=[4, 4],
  903. embed_dim=128,
  904. depths=[2, 2, 18, 2],
  905. num_heads=[4, 8, 16, 32],
  906. window_size=[8, 8],
  907. stochastic_depth_prob=0.5,
  908. weights=weights,
  909. progress=progress,
  910. block=SwinTransformerBlockV2,
  911. downsample_layer=PatchMergingV2,
  912. **kwargs,
  913. )