# mvit.py

import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import torch
import torch.fx
import torch.nn as nn

from ...ops import MLP, StochasticDepth
from ...transforms._presets import VideoClassification
from ...utils import _log_api_usage_once
from .._api import register_model, Weights, WeightsEnum
from .._meta import _KINETICS400_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface


__all__ = [
    "MViT",
    "MViT_V1_B_Weights",
    "mvit_v1_b",
    "MViT_V2_S_Weights",
    "mvit_v2_s",
]


@dataclass
class MSBlockConfig:
    num_heads: int
    input_channels: int
    output_channels: int
    kernel_q: List[int]
    kernel_kv: List[int]
    stride_q: List[int]
    stride_kv: List[int]
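
# Illustrative example (mirrors the first entry of the mvit_v2_s configuration defined below):
#
#     MSBlockConfig(
#         num_heads=1,
#         input_channels=96,
#         output_channels=96,
#         kernel_q=[3, 3, 3],
#         kernel_kv=[3, 3, 3],
#         stride_q=[1, 1, 1],
#         stride_kv=[1, 8, 8],
#     )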


def _prod(s: Sequence[int]) -> int:
    product = 1
    for v in s:
        product *= v
    return product


def _unsqueeze(x: torch.Tensor, target_dim: int, expand_dim: int) -> Tuple[torch.Tensor, int]:
    tensor_dim = x.dim()
    if tensor_dim == target_dim - 1:
        x = x.unsqueeze(expand_dim)
    elif tensor_dim != target_dim:
        raise ValueError(f"Unsupported input dimension {x.shape}")
    return x, tensor_dim


def _squeeze(x: torch.Tensor, target_dim: int, expand_dim: int, tensor_dim: int) -> torch.Tensor:
    if tensor_dim == target_dim - 1:
        x = x.squeeze(expand_dim)
    return x


torch.fx.wrap("_unsqueeze")
torch.fx.wrap("_squeeze")


class Pool(nn.Module):
    def __init__(
        self,
        pool: nn.Module,
        norm: Optional[nn.Module],
        activation: Optional[nn.Module] = None,
        norm_before_pool: bool = False,
    ) -> None:
        super().__init__()
        self.pool = pool
        layers = []
        if norm is not None:
            layers.append(norm)
        if activation is not None:
            layers.append(activation)
        self.norm_act = nn.Sequential(*layers) if layers else None
        self.norm_before_pool = norm_before_pool

    def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        x, tensor_dim = _unsqueeze(x, 4, 1)

        # Separate the class token and reshape the input
        class_token, x = torch.tensor_split(x, indices=(1,), dim=2)
        x = x.transpose(2, 3)
        B, N, C = x.shape[:3]
        x = x.reshape((B * N, C) + thw).contiguous()

        # Normalizing prior to pooling is useful when we use BN, which can be absorbed to speed up inference.
        if self.norm_before_pool and self.norm_act is not None:
            x = self.norm_act(x)

        # Apply the pool on the input and add back the class token.
        x = self.pool(x)
        T, H, W = x.shape[2:]
        x = x.reshape(B, N, C, -1).transpose(2, 3)
        x = torch.cat((class_token, x), dim=2)

        if not self.norm_before_pool and self.norm_act is not None:
            x = self.norm_act(x)

        x = _squeeze(x, 4, 1, tensor_dim)
        return x, (T, H, W)
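
# Shape sketch (illustrative): given tokens of shape (B, num_heads, 1 + T * H * W, head_dim) and a
# 3x3x3 pooling op with stride (1, 2, 2) and padding 1, Pool returns tokens of shape
# (B, num_heads, 1 + T * (H // 2) * (W // 2), head_dim) for even H and W, together with the new
# (T, H // 2, W // 2) grid; the class token is split off before pooling and re-attached afterwards.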


def _interpolate(embedding: torch.Tensor, d: int) -> torch.Tensor:
    if embedding.shape[0] == d:
        return embedding

    return (
        nn.functional.interpolate(
            embedding.permute(1, 0).unsqueeze(0),
            size=d,
            mode="linear",
        )
        .squeeze(0)
        .permute(1, 0)
    )
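
# For example (illustrative), _interpolate(torch.zeros(15, 96), 27) linearly resizes the
# 15-position relative-position table to 27 positions, returning a tensor of shape (27, 96).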


def _add_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    q_thw: Tuple[int, int, int],
    k_thw: Tuple[int, int, int],
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    rel_pos_t: torch.Tensor,
) -> torch.Tensor:
    # Modified code from: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932
    q_t, q_h, q_w = q_thw
    k_t, k_h, k_w = k_thw
    dh = int(2 * max(q_h, k_h) - 1)
    dw = int(2 * max(q_w, k_w) - 1)
    dt = int(2 * max(q_t, k_t) - 1)

    # Scale up rel pos if shapes for q and k are different.
    q_h_ratio = max(k_h / q_h, 1.0)
    k_h_ratio = max(q_h / k_h, 1.0)
    dist_h = torch.arange(q_h)[:, None] * q_h_ratio - (torch.arange(k_h)[None, :] + (1.0 - k_h)) * k_h_ratio
    q_w_ratio = max(k_w / q_w, 1.0)
    k_w_ratio = max(q_w / k_w, 1.0)
    dist_w = torch.arange(q_w)[:, None] * q_w_ratio - (torch.arange(k_w)[None, :] + (1.0 - k_w)) * k_w_ratio
    q_t_ratio = max(k_t / q_t, 1.0)
    k_t_ratio = max(q_t / k_t, 1.0)
    dist_t = torch.arange(q_t)[:, None] * q_t_ratio - (torch.arange(k_t)[None, :] + (1.0 - k_t)) * k_t_ratio

    # Interpolate rel pos if needed.
    rel_pos_h = _interpolate(rel_pos_h, dh)
    rel_pos_w = _interpolate(rel_pos_w, dw)
    rel_pos_t = _interpolate(rel_pos_t, dt)
    Rh = rel_pos_h[dist_h.long()]
    Rw = rel_pos_w[dist_w.long()]
    Rt = rel_pos_t[dist_t.long()]

    B, n_head, _, dim = q.shape

    r_q = q[:, :, 1:].reshape(B, n_head, q_t, q_h, q_w, dim)
    rel_h_q = torch.einsum("bythwc,hkc->bythwk", r_q, Rh)  # [B, H, q_t, qh, qw, k_h]
    rel_w_q = torch.einsum("bythwc,wkc->bythwk", r_q, Rw)  # [B, H, q_t, qh, qw, k_w]
    # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim]
    r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(q_t, B * n_head * q_h * q_w, dim)
    # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t]
    rel_q_t = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1)
    # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t]
    rel_q_t = rel_q_t.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5)

    # Combine rel pos.
    rel_pos = (
        rel_h_q[:, :, :, :, :, None, :, None]
        + rel_w_q[:, :, :, :, :, None, None, :]
        + rel_q_t[:, :, :, :, :, :, None, None]
    ).reshape(B, n_head, q_t * q_h * q_w, k_t * k_h * k_w)

    # Add it to attention
    attn[:, :, 1:, 1:] += rel_pos

    return attn


def _add_shortcut(x: torch.Tensor, shortcut: torch.Tensor, residual_with_cls_embed: bool):
    if residual_with_cls_embed:
        x.add_(shortcut)
    else:
        x[:, :, 1:, :] += shortcut[:, :, 1:, :]
    return x


torch.fx.wrap("_add_rel_pos")
torch.fx.wrap("_add_shortcut")


class MultiscaleAttention(nn.Module):
    def __init__(
        self,
        input_size: List[int],
        embed_dim: int,
        output_dim: int,
        num_heads: int,
        kernel_q: List[int],
        kernel_kv: List[int],
        stride_q: List[int],
        stride_kv: List[int],
        residual_pool: bool,
        residual_with_cls_embed: bool,
        rel_pos_embed: bool,
        dropout: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.output_dim = output_dim
        self.num_heads = num_heads
        self.head_dim = output_dim // num_heads
        self.scaler = 1.0 / math.sqrt(self.head_dim)
        self.residual_pool = residual_pool
        self.residual_with_cls_embed = residual_with_cls_embed

        self.qkv = nn.Linear(embed_dim, 3 * output_dim)
        layers: List[nn.Module] = [nn.Linear(output_dim, output_dim)]
        if dropout > 0.0:
            layers.append(nn.Dropout(dropout, inplace=True))
        self.project = nn.Sequential(*layers)

        self.pool_q: Optional[nn.Module] = None
        if _prod(kernel_q) > 1 or _prod(stride_q) > 1:
            padding_q = [int(q // 2) for q in kernel_q]
            self.pool_q = Pool(
                nn.Conv3d(
                    self.head_dim,
                    self.head_dim,
                    kernel_q,  # type: ignore[arg-type]
                    stride=stride_q,  # type: ignore[arg-type]
                    padding=padding_q,  # type: ignore[arg-type]
                    groups=self.head_dim,
                    bias=False,
                ),
                norm_layer(self.head_dim),
            )

        self.pool_k: Optional[nn.Module] = None
        self.pool_v: Optional[nn.Module] = None
        if _prod(kernel_kv) > 1 or _prod(stride_kv) > 1:
            padding_kv = [int(kv // 2) for kv in kernel_kv]
            self.pool_k = Pool(
                nn.Conv3d(
                    self.head_dim,
                    self.head_dim,
                    kernel_kv,  # type: ignore[arg-type]
                    stride=stride_kv,  # type: ignore[arg-type]
                    padding=padding_kv,  # type: ignore[arg-type]
                    groups=self.head_dim,
                    bias=False,
                ),
                norm_layer(self.head_dim),
            )
            self.pool_v = Pool(
                nn.Conv3d(
                    self.head_dim,
                    self.head_dim,
                    kernel_kv,  # type: ignore[arg-type]
                    stride=stride_kv,  # type: ignore[arg-type]
                    padding=padding_kv,  # type: ignore[arg-type]
                    groups=self.head_dim,
                    bias=False,
                ),
                norm_layer(self.head_dim),
            )

        self.rel_pos_h: Optional[nn.Parameter] = None
        self.rel_pos_w: Optional[nn.Parameter] = None
        self.rel_pos_t: Optional[nn.Parameter] = None
        if rel_pos_embed:
            size = max(input_size[1:])
            q_size = size // stride_q[1] if len(stride_q) > 0 else size
            kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
            spatial_dim = 2 * max(q_size, kv_size) - 1
            temporal_dim = 2 * input_size[0] - 1
            self.rel_pos_h = nn.Parameter(torch.zeros(spatial_dim, self.head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(spatial_dim, self.head_dim))
            self.rel_pos_t = nn.Parameter(torch.zeros(temporal_dim, self.head_dim))
            nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
            nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
            nn.init.trunc_normal_(self.rel_pos_t, std=0.02)

    def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        B, N, C = x.shape
        q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(dim=2)

        if self.pool_k is not None:
            k, k_thw = self.pool_k(k, thw)
        else:
            k_thw = thw
        if self.pool_v is not None:
            v = self.pool_v(v, thw)[0]
        if self.pool_q is not None:
            q, thw = self.pool_q(q, thw)

        attn = torch.matmul(self.scaler * q, k.transpose(2, 3))
        if self.rel_pos_h is not None and self.rel_pos_w is not None and self.rel_pos_t is not None:
            attn = _add_rel_pos(
                attn,
                q,
                thw,
                k_thw,
                self.rel_pos_h,
                self.rel_pos_w,
                self.rel_pos_t,
            )
        attn = attn.softmax(dim=-1)

        x = torch.matmul(attn, v)
        if self.residual_pool:
            _add_shortcut(x, q, self.residual_with_cls_embed)
        x = x.transpose(1, 2).reshape(B, -1, self.output_dim)
        x = self.project(x)

        return x, thw
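
# Shape sketch (illustrative): with stride_q=[1, 2, 2] and stride_kv=[1, 4, 4] on an 8x56x56 grid
# (as in the second block of the mvit_v2_s configuration below), the query, and therefore the block
# output, is pooled to an 8x28x28 grid while keys and values are pooled to 8x14x14, so the attention
# matrix has shape (B, num_heads, 1 + 8*28*28, 1 + 8*14*14) including the class token.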


class MultiscaleBlock(nn.Module):
    def __init__(
        self,
        input_size: List[int],
        cnf: MSBlockConfig,
        residual_pool: bool,
        residual_with_cls_embed: bool,
        rel_pos_embed: bool,
        proj_after_attn: bool,
        dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.proj_after_attn = proj_after_attn

        self.pool_skip: Optional[nn.Module] = None
        if _prod(cnf.stride_q) > 1:
            kernel_skip = [s + 1 if s > 1 else s for s in cnf.stride_q]
            padding_skip = [int(k // 2) for k in kernel_skip]
            self.pool_skip = Pool(
                nn.MaxPool3d(kernel_skip, stride=cnf.stride_q, padding=padding_skip), None  # type: ignore[arg-type]
            )

        attn_dim = cnf.output_channels if proj_after_attn else cnf.input_channels

        self.norm1 = norm_layer(cnf.input_channels)
        self.norm2 = norm_layer(attn_dim)
        self.needs_transposal = isinstance(self.norm1, nn.BatchNorm1d)

        self.attn = MultiscaleAttention(
            input_size,
            cnf.input_channels,
            attn_dim,
            cnf.num_heads,
            kernel_q=cnf.kernel_q,
            kernel_kv=cnf.kernel_kv,
            stride_q=cnf.stride_q,
            stride_kv=cnf.stride_kv,
            rel_pos_embed=rel_pos_embed,
            residual_pool=residual_pool,
            residual_with_cls_embed=residual_with_cls_embed,
            dropout=dropout,
            norm_layer=norm_layer,
        )
        self.mlp = MLP(
            attn_dim,
            [4 * attn_dim, cnf.output_channels],
            activation_layer=nn.GELU,
            dropout=dropout,
            inplace=None,
        )

        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

        self.project: Optional[nn.Module] = None
        if cnf.input_channels != cnf.output_channels:
            self.project = nn.Linear(cnf.input_channels, cnf.output_channels)

    def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        x_norm1 = self.norm1(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm1(x)
        x_attn, thw_new = self.attn(x_norm1, thw)
        x = x if self.project is None or not self.proj_after_attn else self.project(x_norm1)
        x_skip = x if self.pool_skip is None else self.pool_skip(x, thw)[0]
        x = x_skip + self.stochastic_depth(x_attn)

        x_norm2 = self.norm2(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm2(x)
        x_proj = x if self.project is None or self.proj_after_attn else self.project(x_norm2)

        return x_proj + self.stochastic_depth(self.mlp(x_norm2)), thw_new


class PositionalEncoding(nn.Module):
    def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int, rel_pos_embed: bool) -> None:
        super().__init__()
        self.spatial_size = spatial_size
        self.temporal_size = temporal_size

        self.class_token = nn.Parameter(torch.zeros(embed_size))
        self.spatial_pos: Optional[nn.Parameter] = None
        self.temporal_pos: Optional[nn.Parameter] = None
        self.class_pos: Optional[nn.Parameter] = None
        if not rel_pos_embed:
            self.spatial_pos = nn.Parameter(torch.zeros(self.spatial_size[0] * self.spatial_size[1], embed_size))
            self.temporal_pos = nn.Parameter(torch.zeros(self.temporal_size, embed_size))
            self.class_pos = nn.Parameter(torch.zeros(embed_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        class_token = self.class_token.expand(x.size(0), -1).unsqueeze(1)
        x = torch.cat((class_token, x), dim=1)

        if self.spatial_pos is not None and self.temporal_pos is not None and self.class_pos is not None:
            hw_size, embed_size = self.spatial_pos.shape
            pos_embedding = torch.repeat_interleave(self.temporal_pos, hw_size, dim=0)
            pos_embedding.add_(self.spatial_pos.unsqueeze(0).expand(self.temporal_size, -1, -1).reshape(-1, embed_size))
            pos_embedding = torch.cat((self.class_pos.unsqueeze(0), pos_embedding), dim=0).unsqueeze(0)
            x.add_(pos_embedding)

        return x
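
# Note (illustrative): when absolute positional embeddings are used (MViTv1), the embedding added to
# the token at grid position (t, h, w) is temporal_pos[t] + spatial_pos[h * W + w] with
# W = spatial_size[1], i.e. the spatial and temporal tables are kept separable instead of storing a
# full T*H*W table.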


class MViT(nn.Module):
    def __init__(
        self,
        spatial_size: Tuple[int, int],
        temporal_size: int,
        block_setting: Sequence[MSBlockConfig],
        residual_pool: bool,
        residual_with_cls_embed: bool,
        rel_pos_embed: bool,
        proj_after_attn: bool,
        dropout: float = 0.5,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        num_classes: int = 400,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        patch_embed_kernel: Tuple[int, int, int] = (3, 7, 7),
        patch_embed_stride: Tuple[int, int, int] = (2, 4, 4),
        patch_embed_padding: Tuple[int, int, int] = (1, 3, 3),
    ) -> None:
        """
        MViT main class.

        Args:
            spatial_size (tuple of ints): The spatial size of the input as ``(H, W)``.
            temporal_size (int): The temporal size ``T`` of the input.
            block_setting (sequence of MSBlockConfig): The network structure.
            residual_pool (bool): If True, use MViTv2 pooling residual connection.
            residual_with_cls_embed (bool): If True, the addition on the residual connection will include
                the class embedding.
            rel_pos_embed (bool): If True, use MViTv2's relative positional embeddings.
            proj_after_attn (bool): If True, apply the projection after the attention.
            dropout (float): Dropout rate. Default: 0.5.
            attention_dropout (float): Attention dropout rate. Default: 0.0.
            stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0.
            num_classes (int): The number of classes.
            block (callable, optional): Module specifying the layer which consists of the attention and mlp.
            norm_layer (callable, optional): Module specifying the normalization layer to use.
            patch_embed_kernel (tuple of ints): The kernel of the convolution that patchifies the input.
            patch_embed_stride (tuple of ints): The stride of the convolution that patchifies the input.
            patch_embed_padding (tuple of ints): The padding of the convolution that patchifies the input.
        """
        super().__init__()
        # This implementation employs a different parameterization scheme than the one used at PyTorch Video:
        # https://github.com/facebookresearch/pytorchvideo/blob/718d0a4/pytorchvideo/models/vision_transformers.py
        # We remove any experimental configuration that didn't make it to the final variants of the models. To represent
        # the configuration of the architecture we use the simplified form suggested at Table 1 of the paper.
        _log_api_usage_once(self)
        total_stage_blocks = len(block_setting)
        if total_stage_blocks == 0:
            raise ValueError("The configuration parameter can't be empty.")

        if block is None:
            block = MultiscaleBlock

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)

        # Patch Embedding module
        self.conv_proj = nn.Conv3d(
            in_channels=3,
            out_channels=block_setting[0].input_channels,
            kernel_size=patch_embed_kernel,
            stride=patch_embed_stride,
            padding=patch_embed_padding,
        )

        input_size = [size // stride for size, stride in zip((temporal_size,) + spatial_size, self.conv_proj.stride)]

        # Spatio-Temporal Class Positional Encoding
        self.pos_encoding = PositionalEncoding(
            embed_size=block_setting[0].input_channels,
            spatial_size=(input_size[1], input_size[2]),
            temporal_size=input_size[0],
            rel_pos_embed=rel_pos_embed,
        )

        # Encoder module
        self.blocks = nn.ModuleList()
        for stage_block_id, cnf in enumerate(block_setting):
            # adjust stochastic depth probability based on the depth of the stage block
            sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)

            self.blocks.append(
                block(
                    input_size=input_size,
                    cnf=cnf,
                    residual_pool=residual_pool,
                    residual_with_cls_embed=residual_with_cls_embed,
                    rel_pos_embed=rel_pos_embed,
                    proj_after_attn=proj_after_attn,
                    dropout=attention_dropout,
                    stochastic_depth_prob=sd_prob,
                    norm_layer=norm_layer,
                )
            )

            if len(cnf.stride_q) > 0:
                input_size = [size // stride for size, stride in zip(input_size, cnf.stride_q)]
        self.norm = norm_layer(block_setting[-1].output_channels)

        # Classifier module
        self.head = nn.Sequential(
            nn.Dropout(dropout, inplace=True),
            nn.Linear(block_setting[-1].output_channels, num_classes),
        )

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.LayerNorm):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1.0)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, PositionalEncoding):
                for weights in m.parameters():
                    nn.init.trunc_normal_(weights, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Convert if necessary (B, C, H, W) -> (B, C, 1, H, W)
        x = _unsqueeze(x, 5, 2)[0]

        # patchify and reshape: (B, C, T, H, W) -> (B, embed_channels[0], T', H', W') -> (B, THW', embed_channels[0])
        x = self.conv_proj(x)
        x = x.flatten(2).transpose(1, 2)

        # add positional encoding
        x = self.pos_encoding(x)

        # pass patches through the encoder
        thw = (self.pos_encoding.temporal_size,) + self.pos_encoding.spatial_size
        for block in self.blocks:
            x, thw = block(x, thw)
        x = self.norm(x)

        # classifier "token" as used by standard language architectures
        x = x[:, 0]
        x = self.head(x)

        return x
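
# Shape sketch (illustrative, not part of the original file): with the default patch-embedding
# stride of (2, 4, 4), a (B, 3, 16, 224, 224) clip is patchified into an 8 x 56 x 56 token grid
# (plus one class token) before the encoder blocks, and the head maps the class token to
# ``num_classes`` logits:
#
#     model = MViT(
#         spatial_size=(224, 224),
#         temporal_size=16,
#         block_setting=block_setting,  # a sequence of MSBlockConfig, e.g. as built in mvit_v2_s below
#         residual_pool=True,
#         residual_with_cls_embed=False,
#         rel_pos_embed=True,
#         proj_after_attn=True,
#     )
#     logits = model(torch.rand(1, 3, 16, 224, 224))  # -> torch.Size([1, 400])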


def _mvit(
    block_setting: List[MSBlockConfig],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> MViT:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
        _ovewrite_named_param(kwargs, "spatial_size", weights.meta["min_size"])
        _ovewrite_named_param(kwargs, "temporal_size", weights.meta["min_temporal_size"])
    spatial_size = kwargs.pop("spatial_size", (224, 224))
    temporal_size = kwargs.pop("temporal_size", 16)

    model = MViT(
        spatial_size=spatial_size,
        temporal_size=temporal_size,
        block_setting=block_setting,
        residual_pool=kwargs.pop("residual_pool", False),
        residual_with_cls_embed=kwargs.pop("residual_with_cls_embed", True),
        rel_pos_embed=kwargs.pop("rel_pos_embed", False),
        proj_after_attn=kwargs.pop("proj_after_attn", False),
        stochastic_depth_prob=stochastic_depth_prob,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


class MViT_V1_B_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/mvit_v1_b-dbeb1030.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.45, 0.45, 0.45),
            std=(0.225, 0.225, 0.225),
        ),
        meta={
            "min_size": (224, 224),
            "min_temporal_size": 16,
            "categories": _KINETICS400_CATEGORIES,
            "recipe": "https://github.com/facebookresearch/pytorchvideo/blob/main/docs/source/model_zoo.md",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`"
            ),
            "num_params": 36610672,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 78.477,
                    "acc@5": 93.582,
                }
            },
            "_ops": 70.599,
            "_file_size": 139.764,
        },
    )
    DEFAULT = KINETICS400_V1


class MViT_V2_S_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/mvit_v2_s-ae3be167.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.45, 0.45, 0.45),
            std=(0.225, 0.225, 0.225),
        ),
        meta={
            "min_size": (224, 224),
            "min_temporal_size": 16,
            "categories": _KINETICS400_CATEGORIES,
            "recipe": "https://github.com/facebookresearch/SlowFast/blob/main/MODEL_ZOO.md",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`"
            ),
            "num_params": 34537744,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 80.757,
                    "acc@5": 94.665,
                }
            },
            "_ops": 64.224,
            "_file_size": 131.884,
        },
    )
    DEFAULT = KINETICS400_V1


@register_model()
@handle_legacy_interface(weights=("pretrained", MViT_V1_B_Weights.KINETICS400_V1))
def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
    """
    Constructs a base MViTV1 architecture from
    `Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__.

    .. betastatus:: video module

    Args:
        weights (:class:`~torchvision.models.video.MViT_V1_B_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.MViT_V1_B_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.MViT``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.MViT_V1_B_Weights
        :members:
    """
    weights = MViT_V1_B_Weights.verify(weights)

    config: Dict[str, List] = {
        "num_heads": [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8],
        "input_channels": [96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768],
        "output_channels": [192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768, 768],
        "kernel_q": [[], [3, 3, 3], [], [3, 3, 3], [], [], [], [], [], [], [], [], [], [], [3, 3, 3], []],
        "kernel_kv": [
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
        ],
        "stride_q": [[], [1, 2, 2], [], [1, 2, 2], [], [], [], [], [], [], [], [], [], [], [1, 2, 2], []],
        "stride_kv": [
            [1, 8, 8],
            [1, 4, 4],
            [1, 4, 4],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 1, 1],
            [1, 1, 1],
        ],
    }

    block_setting = []
    for i in range(len(config["num_heads"])):
        block_setting.append(
            MSBlockConfig(
                num_heads=config["num_heads"][i],
                input_channels=config["input_channels"][i],
                output_channels=config["output_channels"][i],
                kernel_q=config["kernel_q"][i],
                kernel_kv=config["kernel_kv"][i],
                stride_q=config["stride_q"][i],
                stride_kv=config["stride_kv"][i],
            )
        )

    return _mvit(
        spatial_size=(224, 224),
        temporal_size=16,
        block_setting=block_setting,
        residual_pool=False,
        residual_with_cls_embed=False,
        stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2),
        weights=weights,
        progress=progress,
        **kwargs,
    )
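
# Usage sketch (illustrative): build the pretrained model and classify an already-preprocessed
# 16-frame clip of shape (B, C, T, H, W).
#
#     weights = MViT_V1_B_Weights.KINETICS400_V1
#     model = mvit_v1_b(weights=weights).eval()
#     clip = torch.rand(1, 3, 16, 224, 224)
#     with torch.no_grad():
#         probs = model(clip).softmax(dim=-1)  # (1, 400) Kinetics-400 class probabilities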


@register_model()
@handle_legacy_interface(weights=("pretrained", MViT_V2_S_Weights.KINETICS400_V1))
def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
    """Constructs a small MViTV2 architecture from
    `Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__ and
    `MViTv2: Improved Multiscale Vision Transformers for Classification
    and Detection <https://arxiv.org/abs/2112.01526>`__.

    .. betastatus:: video module

    Args:
        weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.MViT_V2_S_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.MViT``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.MViT_V2_S_Weights
        :members:
    """
    weights = MViT_V2_S_Weights.verify(weights)

    config: Dict[str, List] = {
        "num_heads": [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8],
        "input_channels": [96, 96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768],
        "output_channels": [96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768],
        "kernel_q": [
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
        ],
        "kernel_kv": [
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
        ],
        "stride_q": [
            [1, 1, 1],
            [1, 2, 2],
            [1, 1, 1],
            [1, 2, 2],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 2, 2],
            [1, 1, 1],
        ],
        "stride_kv": [
            [1, 8, 8],
            [1, 4, 4],
            [1, 4, 4],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 1, 1],
            [1, 1, 1],
        ],
    }

    block_setting = []
    for i in range(len(config["num_heads"])):
        block_setting.append(
            MSBlockConfig(
                num_heads=config["num_heads"][i],
                input_channels=config["input_channels"][i],
                output_channels=config["output_channels"][i],
                kernel_q=config["kernel_q"][i],
                kernel_kv=config["kernel_kv"][i],
                stride_q=config["stride_q"][i],
                stride_kv=config["stride_kv"][i],
            )
        )

    return _mvit(
        spatial_size=(224, 224),
        temporal_size=16,
        block_setting=block_setting,
        residual_pool=True,
        residual_with_cls_embed=False,
        rel_pos_embed=True,
        proj_after_attn=True,
        stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2),
        weights=weights,
        progress=progress,
        **kwargs,
    )
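
# Usage sketch (illustrative): the same pattern applies to the MViTv2-S builder; the preprocessing
# preset bundled with the weights can be used to turn raw (T, C, H, W) frames into the normalized
# (C, T, H, W) clip the model expects (add a batch dimension before the forward pass).
#
#     weights = MViT_V2_S_Weights.KINETICS400_V1
#     model = mvit_v2_s(weights=weights).eval()
#     preprocess = weights.transforms()
#     clip = preprocess(torch.rand(16, 3, 256, 256))        # (T, C, H, W) -> (C, T, H, W)
#     with torch.no_grad():
#         probs = model(clip.unsqueeze(0)).softmax(dim=-1)  # (1, 400)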