# Modified from 2d Swin Transformers in torchvision:
# https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py

from functools import partial
from typing import Any, Callable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from ...transforms._presets import VideoClassification
from ...utils import _log_api_usage_once
from .._api import register_model, Weights, WeightsEnum
from .._meta import _KINETICS400_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface
from ..swin_transformer import PatchMerging, SwinTransformerBlock


__all__ = [
    "SwinTransformer3d",
    "Swin3D_T_Weights",
    "Swin3D_S_Weights",
    "Swin3D_B_Weights",
    "swin3d_t",
    "swin3d_s",
    "swin3d_b",
]

def _get_window_and_shift_size(
    shift_size: List[int], size_dhw: List[int], window_size: List[int]
) -> Tuple[List[int], List[int]]:
    for i in range(3):
        if size_dhw[i] <= window_size[i]:
            # In this case, window_size will adapt to the input size, and no need to shift
            window_size[i] = size_dhw[i]
            shift_size[i] = 0

    return window_size, shift_size


torch.fx.wrap("_get_window_and_shift_size")

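# Example for _get_window_and_shift_size: a 4-frame clip passed with window_size=[8, 7, 7]
# and shift_size=[4, 3, 3] gets its temporal window clamped to the input depth and the
# corresponding shift disabled, i.e. window_size=[4, 7, 7] and shift_size=[0, 3, 3].
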
def _get_relative_position_bias(
    relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int]
) -> Tensor:
    window_vol = window_size[0] * window_size[1] * window_size[2]
    # In 3d case we flatten the relative_position_bias
    relative_position_bias = relative_position_bias_table[
        relative_position_index[:window_vol, :window_vol].flatten()  # type: ignore[index]
    ]
    relative_position_bias = relative_position_bias.view(window_vol, window_vol, -1)
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
    return relative_position_bias


torch.fx.wrap("_get_relative_position_bias")

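# Shape walk-through for _get_relative_position_bias: the table holds
# ((2*Wd - 1) * (2*Wh - 1) * (2*Ww - 1), num_heads) learnable offsets; gathering with the
# flattened (window_vol, window_vol) index yields (window_vol, window_vol, num_heads), which is
# permuted and unsqueezed to (1, num_heads, window_vol, window_vol) so it broadcasts over the
# batched windows of the attention logits.
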
def _compute_pad_size_3d(size_dhw: Tuple[int, int, int], patch_size: Tuple[int, int, int]) -> Tuple[int, int, int]:
    pad_size = [(patch_size[i] - size_dhw[i] % patch_size[i]) % patch_size[i] for i in range(3)]
    return pad_size[0], pad_size[1], pad_size[2]


torch.fx.wrap("_compute_pad_size_3d")

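# Example for _compute_pad_size_3d: size_dhw=(15, 110, 110) with patch_size=(2, 4, 4) returns
# (1, 2, 2), padding each axis up to the next multiple of the patch size; sizes that are already
# multiples need no padding and return (0, 0, 0).
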
def _compute_attention_mask_3d(
    x: Tensor,
    size_dhw: Tuple[int, int, int],
    window_size: Tuple[int, int, int],
    shift_size: Tuple[int, int, int],
) -> Tensor:
    # generate attention mask
    attn_mask = x.new_zeros(*size_dhw)
    num_windows = (size_dhw[0] // window_size[0]) * (size_dhw[1] // window_size[1]) * (size_dhw[2] // window_size[2])
    slices = [
        (
            (0, -window_size[i]),
            (-window_size[i], -shift_size[i]),
            (-shift_size[i], None),
        )
        for i in range(3)
    ]
    count = 0
    for d in slices[0]:
        for h in slices[1]:
            for w in slices[2]:
                attn_mask[d[0] : d[1], h[0] : h[1], w[0] : w[1]] = count
                count += 1

    # Partition window on attn_mask
    attn_mask = attn_mask.view(
        size_dhw[0] // window_size[0],
        window_size[0],
        size_dhw[1] // window_size[1],
        window_size[1],
        size_dhw[2] // window_size[2],
        window_size[2],
    )
    attn_mask = attn_mask.permute(0, 2, 4, 1, 3, 5).reshape(
        num_windows, window_size[0] * window_size[1] * window_size[2]
    )
    attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask


torch.fx.wrap("_compute_attention_mask_3d")

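# The mask above has shape (num_windows, Wd*Wh*Ww, Wd*Wh*Ww): token pairs belonging to the same
# region before the cyclic shift get 0, while pairs that were only stitched together by the shift
# get -100.0, which drives their attention weights to ~0 after the softmax in
# shifted_window_attention_3d.
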
def shifted_window_attention_3d(
    input: Tensor,
    qkv_weight: Tensor,
    proj_weight: Tensor,
    relative_position_bias: Tensor,
    window_size: List[int],
    num_heads: int,
    shift_size: List[int],
    attention_dropout: float = 0.0,
    dropout: float = 0.0,
    qkv_bias: Optional[Tensor] = None,
    proj_bias: Optional[Tensor] = None,
    training: bool = True,
) -> Tensor:
    """
    Window-based multi-head self-attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Args:
        input (Tensor[B, T, H, W, C]): The input tensor, 5-dimensional.
        qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value.
        proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection.
        relative_position_bias (Tensor): The learned relative position bias added to attention.
        window_size (List[int]): 3-dimensional window size (T, H, W).
        num_heads (int): Number of attention heads.
        shift_size (List[int]): Shift size for shifted window attention (T, H, W).
        attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
        dropout (float): Dropout ratio of output. Default: 0.0.
        qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
        proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
        training (bool, optional): Training flag used by the dropout parameters. Default: True.

    Returns:
        Tensor[B, T, H, W, C]: The output tensor after shifted window attention.
    """
    b, t, h, w, c = input.shape
    # pad feature maps to multiples of window size
    pad_size = _compute_pad_size_3d((t, h, w), (window_size[0], window_size[1], window_size[2]))
    x = F.pad(input, (0, 0, 0, pad_size[2], 0, pad_size[1], 0, pad_size[0]))
    _, tp, hp, wp, _ = x.shape
    padded_size = (tp, hp, wp)

    # cyclic shift
    if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))

    # partition windows
    num_windows = (
        (padded_size[0] // window_size[0]) * (padded_size[1] // window_size[1]) * (padded_size[2] // window_size[2])
    )
    x = x.view(
        b,
        padded_size[0] // window_size[0],
        window_size[0],
        padded_size[1] // window_size[1],
        window_size[1],
        padded_size[2] // window_size[2],
        window_size[2],
        c,
    )
    x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape(
        b * num_windows, window_size[0] * window_size[1] * window_size[2], c
    )  # B*nW, Wd*Wh*Ww, C

    # multi-head attention
    qkv = F.linear(x, qkv_weight, qkv_bias)
    qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, c // num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    q = q * (c // num_heads) ** -0.5
    attn = q.matmul(k.transpose(-2, -1))
    # add relative position bias
    attn = attn + relative_position_bias

    if sum(shift_size) > 0:
        # generate attention mask to handle shifted windows with varying size
        attn_mask = _compute_attention_mask_3d(
            x,
            (padded_size[0], padded_size[1], padded_size[2]),
            (window_size[0], window_size[1], window_size[2]),
            (shift_size[0], shift_size[1], shift_size[2]),
        )
        attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
        attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, num_heads, x.size(1), x.size(1))

    attn = F.softmax(attn, dim=-1)
    attn = F.dropout(attn, p=attention_dropout, training=training)

    x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), c)
    x = F.linear(x, proj_weight, proj_bias)
    x = F.dropout(x, p=dropout, training=training)

    # reverse windows
    x = x.view(
        b,
        padded_size[0] // window_size[0],
        padded_size[1] // window_size[1],
        padded_size[2] // window_size[2],
        window_size[0],
        window_size[1],
        window_size[2],
        c,
    )
    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).reshape(b, tp, hp, wp, c)

    # reverse cyclic shift
    if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))

    # unpad features
    x = x[:, :t, :h, :w, :].contiguous()
    return x


torch.fx.wrap("shifted_window_attention_3d")

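# Shape summary for shifted_window_attention_3d, with nW windows of window_vol = Wd*Wh*Ww tokens:
#   windows: (B*nW, window_vol, C)
#   q, k, v: (B*nW, num_heads, window_vol, C // num_heads)
#   attn:    (B*nW, num_heads, window_vol, window_vol), plus the relative position bias and,
#            when shifted, the per-window mask, before being folded back to (B, T, H, W, C).
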
class ShiftedWindowAttention3d(nn.Module):
    """
    See :func:`shifted_window_attention_3d`.
    """

    def __init__(
        self,
        dim: int,
        window_size: List[int],
        shift_size: List[int],
        num_heads: int,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        if len(window_size) != 3 or len(shift_size) != 3:
            raise ValueError("window_size and shift_size must be of length 3")
        self.window_size = window_size  # Wd, Wh, Ww
        self.shift_size = shift_size
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.dropout = dropout

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)

        self.define_relative_position_bias_table()
        self.define_relative_position_index()

    def define_relative_position_bias_table(self) -> None:
        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(
                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1),
                self.num_heads,
            )
        )  # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def define_relative_position_index(self) -> None:
        # get pair-wise relative position index for each token inside the window
        coords_dhw = [torch.arange(self.window_size[i]) for i in range(3)]
        coords = torch.stack(
            torch.meshgrid(coords_dhw[0], coords_dhw[1], coords_dhw[2], indexing="ij")
        )  # 3, Wd, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 3, Wd*Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 3, Wd*Wh*Ww, Wd*Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wd*Wh*Ww, Wd*Wh*Ww, 3
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 2] += self.window_size[2] - 1
        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
        relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
        # We don't flatten the relative_position_index here in 3d case.
        relative_position_index = relative_coords.sum(-1)  # Wd*Wh*Ww, Wd*Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

    def get_relative_position_bias(self, window_size: List[int]) -> torch.Tensor:
        return _get_relative_position_bias(self.relative_position_bias_table, self.relative_position_index, window_size)  # type: ignore

    def forward(self, x: Tensor) -> Tensor:
        _, t, h, w, _ = x.shape
        size_dhw = [t, h, w]
        window_size, shift_size = self.window_size.copy(), self.shift_size.copy()
        # Handle case where window_size is larger than the input tensor
        window_size, shift_size = _get_window_and_shift_size(shift_size, size_dhw, window_size)

        relative_position_bias = self.get_relative_position_bias(window_size)

        return shifted_window_attention_3d(
            x,
            self.qkv.weight,
            self.proj.weight,
            relative_position_bias,
            window_size,
            self.num_heads,
            shift_size=shift_size,
            attention_dropout=self.attention_dropout,
            dropout=self.dropout,
            qkv_bias=self.qkv.bias,
            proj_bias=self.proj.bias,
            training=self.training,
        )

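# A minimal instantiation sketch for ShiftedWindowAttention3d (the sizes below match the first
# stage of swin3d_t and are only illustrative):
#
#     attn = ShiftedWindowAttention3d(dim=96, window_size=[8, 7, 7], shift_size=[4, 3, 3], num_heads=3)
#     out = attn(torch.rand(1, 8, 14, 14, 96))   # same (B, T, H, W, C) shape out
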
# Modified from:
# https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py
class PatchEmbed3d(nn.Module):
    """Video to Patch Embedding.

    Args:
        patch_size (List[int]): Patch token size.
        in_channels (int): Number of input channels. Default: 3
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(
        self,
        patch_size: List[int],
        in_channels: int = 3,
        embed_dim: int = 96,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.tuple_patch_size = (patch_size[0], patch_size[1], patch_size[2])

        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=self.tuple_patch_size,
            stride=self.tuple_patch_size,
        )
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Forward function."""
        # padding
        _, _, t, h, w = x.size()
        pad_size = _compute_pad_size_3d((t, h, w), self.tuple_patch_size)
        x = F.pad(x, (0, pad_size[2], 0, pad_size[1], 0, pad_size[0]))

        x = self.proj(x)  # B C T Wh Ww
        x = x.permute(0, 2, 3, 4, 1)  # B T Wh Ww C
        if self.norm is not None:
            x = self.norm(x)
        return x

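# Example for PatchEmbed3d: with patch_size=[2, 4, 4] and embed_dim=96, a (B, 3, 16, 224, 224)
# clip becomes a (B, 8, 56, 56, 96) token grid; every 2x4x4 tube of pixels is projected to one
# 96-dimensional token by the strided Conv3d.
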
class SwinTransformer3d(nn.Module):
    """
    Implements 3D Swin Transformer from the `"Video Swin Transformer" <https://arxiv.org/abs/2106.13230>`_ paper.

    Args:
        patch_size (List[int]): Patch size.
        embed_dim (int): Patch embedding dimension.
        depths (List[int]): Depth of each Swin Transformer layer.
        num_heads (List[int]): Number of attention heads in different layers.
        window_size (List[int]): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
        stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1.
        num_classes (int): Number of classes for classification head. Default: 400.
        norm_layer (nn.Module, optional): Normalization layer. Default: None.
        block (nn.Module, optional): SwinTransformer Block. Default: None.
        downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging.
        patch_embed (nn.Module, optional): Patch Embedding layer. Default: None.
    """

    def __init__(
        self,
        patch_size: List[int],
        embed_dim: int,
        depths: List[int],
        num_heads: List[int],
        window_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.1,
        num_classes: int = 400,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        block: Optional[Callable[..., nn.Module]] = None,
        downsample_layer: Callable[..., nn.Module] = PatchMerging,
        patch_embed: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.num_classes = num_classes

        if block is None:
            block = partial(SwinTransformerBlock, attn_layer=ShiftedWindowAttention3d)

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-5)

        if patch_embed is None:
            patch_embed = PatchEmbed3d

        # split image into non-overlapping patches
        self.patch_embed = patch_embed(patch_size=patch_size, embed_dim=embed_dim, norm_layer=norm_layer)
        self.pos_drop = nn.Dropout(p=dropout)

        layers: List[nn.Module] = []
        total_stage_blocks = sum(depths)
        stage_block_id = 0
        # build SwinTransformer blocks
        for i_stage in range(len(depths)):
            stage: List[nn.Module] = []
            dim = embed_dim * 2**i_stage
            for i_layer in range(depths[i_stage]):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1)
                stage.append(
                    block(
                        dim,
                        num_heads[i_stage],
                        window_size=window_size,
                        shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size],
                        mlp_ratio=mlp_ratio,
                        dropout=dropout,
                        attention_dropout=attention_dropout,
                        stochastic_depth_prob=sd_prob,
                        norm_layer=norm_layer,
                        attn_layer=ShiftedWindowAttention3d,
                    )
                )
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            # add patch merging layer
            if i_stage < (len(depths) - 1):
                layers.append(downsample_layer(dim, norm_layer))
        self.features = nn.Sequential(*layers)

        self.num_features = embed_dim * 2 ** (len(depths) - 1)
        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool3d(1)
        self.head = nn.Linear(self.num_features, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        # x: B C T H W
        x = self.patch_embed(x)  # B _T _H _W C
        x = self.pos_drop(x)
        x = self.features(x)  # B _T _H _W C
        x = self.norm(x)
        x = x.permute(0, 4, 1, 2, 3)  # B, C, _T, _H, _W
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x

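# A minimal construction sketch for SwinTransformer3d, using the swin3d_t hyper-parameters defined
# further below; the input layout is (B, C, T, H, W) and the output is (B, num_classes):
#
#     model = SwinTransformer3d(
#         patch_size=[2, 4, 4],
#         embed_dim=96,
#         depths=[2, 2, 6, 2],
#         num_heads=[3, 6, 12, 24],
#         window_size=[8, 7, 7],
#     )
#     logits = model(torch.rand(1, 3, 16, 224, 224))   # -> (1, 400) with the default head
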
def _swin_transformer3d(
    patch_size: List[int],
    embed_dim: int,
    depths: List[int],
    num_heads: List[int],
    window_size: List[int],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> SwinTransformer3d:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = SwinTransformer3d(
        patch_size=patch_size,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=window_size,
        stochastic_depth_prob=stochastic_depth_prob,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model

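# The builder above only forces num_classes when pretrained weights are requested, so the head
# matches the checkpoint; with weights=None, any SwinTransformer3d keyword argument (e.g.
# num_classes or dropout) can be forwarded through **kwargs.
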
_COMMON_META = {
    "categories": _KINETICS400_CATEGORIES,
    "min_size": (1, 1),
    "min_temporal_size": 1,
}

class Swin3D_T_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/swin3d_t-7615ae03.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.4850, 0.4560, 0.4060),
            std=(0.2290, 0.2240, 0.2250),
        ),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"
            ),
            "num_params": 28158070,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 77.715,
                    "acc@5": 93.519,
                }
            },
            "_ops": 43.882,
            "_file_size": 121.543,
        },
    )
    DEFAULT = KINETICS400_V1

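# Usage sketch for the weight enums (the same pattern applies to the S and B variants below):
#
#     weights = Swin3D_T_Weights.KINETICS400_V1
#     preprocess = weights.transforms()       # the VideoClassification preset configured above
#     model = swin3d_t(weights=weights).eval()
#     print(weights.meta["_metrics"]["Kinetics-400"]["acc@1"])   # 77.715
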
class Swin3D_S_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/swin3d_s-da41c237.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.4850, 0.4560, 0.4060),
            std=(0.2290, 0.2240, 0.2250),
        ),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"
            ),
            "num_params": 49816678,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 79.521,
                    "acc@5": 94.158,
                }
            },
            "_ops": 82.841,
            "_file_size": 218.288,
        },
    )
    DEFAULT = KINETICS400_V1

class Swin3D_B_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/swin3d_b_1k-24f7c7c6.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.4850, 0.4560, 0.4060),
            std=(0.2290, 0.2240, 0.2250),
        ),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"
            ),
            "num_params": 88048984,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 79.427,
                    "acc@5": 94.386,
                }
            },
            "_ops": 140.667,
            "_file_size": 364.134,
        },
    )
    KINETICS400_IMAGENET22K_V1 = Weights(
        url="https://download.pytorch.org/models/swin3d_b_22k-7c6ae6fa.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.4850, 0.4560, 0.4060),
            std=(0.2290, 0.2240, 0.2250),
        ),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"
            ),
            "num_params": 88048984,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 81.643,
                    "acc@5": 95.574,
                }
            },
            "_ops": 140.667,
            "_file_size": 364.134,
        },
    )
    DEFAULT = KINETICS400_V1

@register_model()
@handle_legacy_interface(weights=("pretrained", Swin3D_T_Weights.KINETICS400_V1))
def swin3d_t(*, weights: Optional[Swin3D_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer3d:
    """
    Constructs a swin_tiny architecture from
    `Video Swin Transformer <https://arxiv.org/abs/2106.13230>`_.

    Args:
        weights (:class:`~torchvision.models.video.Swin3D_T_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.Swin3D_T_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer3d``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.Swin3D_T_Weights
        :members:
    """
    weights = Swin3D_T_Weights.verify(weights)

    return _swin_transformer3d(
        patch_size=[2, 4, 4],
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[8, 7, 7],
        stochastic_depth_prob=0.1,
        weights=weights,
        progress=progress,
        **kwargs,
    )

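# End-to-end sketch, assuming the packaged entry points in torchvision.models.video (the
# VideoClassification preset takes (B, T, C, H, W) frames and returns the (B, C, T, H, W) layout
# expected by the model):
#
#     import torch
#     from torchvision.models.video import swin3d_t, Swin3D_T_Weights
#
#     weights = Swin3D_T_Weights.DEFAULT
#     model = swin3d_t(weights=weights).eval()
#     clip = torch.rand(1, 16, 3, 256, 256)          # B, T, C, H, W frames in [0, 1]
#     batch = weights.transforms()(clip)             # resized, cropped to 224x224, normalized
#     with torch.no_grad():
#         scores = model(batch).softmax(dim=-1)
#     print(weights.meta["categories"][scores.argmax().item()])
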
@register_model()
@handle_legacy_interface(weights=("pretrained", Swin3D_S_Weights.KINETICS400_V1))
def swin3d_s(*, weights: Optional[Swin3D_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer3d:
    """
    Constructs a swin_small architecture from
    `Video Swin Transformer <https://arxiv.org/abs/2106.13230>`_.

    Args:
        weights (:class:`~torchvision.models.video.Swin3D_S_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.Swin3D_S_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer3d``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.Swin3D_S_Weights
        :members:
    """
    weights = Swin3D_S_Weights.verify(weights)

    return _swin_transformer3d(
        patch_size=[2, 4, 4],
        embed_dim=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[8, 7, 7],
        stochastic_depth_prob=0.1,
        weights=weights,
        progress=progress,
        **kwargs,
    )

@register_model()
@handle_legacy_interface(weights=("pretrained", Swin3D_B_Weights.KINETICS400_V1))
def swin3d_b(*, weights: Optional[Swin3D_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer3d:
    """
    Constructs a swin_base architecture from
    `Video Swin Transformer <https://arxiv.org/abs/2106.13230>`_.

    Args:
        weights (:class:`~torchvision.models.video.Swin3D_B_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.Swin3D_B_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer3d``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.Swin3D_B_Weights
        :members:
    """
    weights = Swin3D_B_Weights.verify(weights)

    return _swin_transformer3d(
        patch_size=[2, 4, 4],
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        window_size=[8, 7, 7],
        stochastic_depth_prob=0.1,
        weights=weights,
        progress=progress,
        **kwargs,
    )

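# For training from scratch, the classification head can be resized through **kwargs, e.g. a
# hypothetical 174-class setup:
#
#     model = swin3d_b(weights=None, num_classes=174)
#
# Passing a num_classes that differs from 400 together with pretrained weights raises an error in
# _ovewrite_named_param, since the checkpoints above were trained with a 400-way Kinetics head.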