12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897 |
- import math
- from dataclasses import dataclass
- from functools import partial
- from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
- import torch
- import torch.fx
- import torch.nn as nn
- from ...ops import MLP, StochasticDepth
- from ...transforms._presets import VideoClassification
- from ...utils import _log_api_usage_once
- from .._api import register_model, Weights, WeightsEnum
- from .._meta import _KINETICS400_CATEGORIES
- from .._utils import _ovewrite_named_param, handle_legacy_interface
# Public names of this module: the model class plus, for each of the two variants,
# its builder function and the enum describing its pre-trained weights.
__all__ = [
    "MViT",
    "MViT_V1_B_Weights",
    "mvit_v1_b",
    "MViT_V2_S_Weights",
    "mvit_v2_s",
]
@dataclass
class MSBlockConfig:
    """Settings of one multiscale block; the model receives a sequence of these as ``block_setting``."""

    # Number of attention heads of the block's attention layer.
    num_heads: int
    # Channel width of the tokens entering and leaving the block.
    input_channels: int
    output_channels: int
    # Kernel sizes of the pooling convolutions applied to the queries and to the keys/values.
    # An empty list has a product of 1 and therefore does not, by itself, enable pooling.
    kernel_q: List[int]
    kernel_kv: List[int]
    # Strides of the same pooling convolutions; the query stride also drives the pooled skip connection.
    stride_q: List[int]
    stride_kv: List[int]
- def _prod(s: Sequence[int]) -> int:
- product = 1
- for v in s:
- product *= v
- return product
- def _unsqueeze(x: torch.Tensor, target_dim: int, expand_dim: int) -> Tuple[torch.Tensor, int]:
- tensor_dim = x.dim()
- if tensor_dim == target_dim - 1:
- x = x.unsqueeze(expand_dim)
- elif tensor_dim != target_dim:
- raise ValueError(f"Unsupported input dimension {x.shape}")
- return x, tensor_dim
- def _squeeze(x: torch.Tensor, target_dim: int, expand_dim: int, tensor_dim: int) -> torch.Tensor:
- if tensor_dim == target_dim - 1:
- x = x.squeeze(expand_dim)
- return x
- torch.fx.wrap("_unsqueeze")
- torch.fx.wrap("_squeeze")
class Pool(nn.Module):
    """
    Pools the patch tokens of a sequence over their (T, H, W) grid while leaving the class token alone.

    Args:
        pool (nn.Module): Pooling operator applied to the tokens once they are laid out as a 5-D volume.
        norm (nn.Module, optional): Normalization applied together with ``activation``.
        activation (nn.Module, optional): Activation applied right after ``norm``.
        norm_before_pool (bool): If True, normalization/activation run on the 5-D volume before
            pooling; otherwise they run on the token sequence after pooling.
    """

    def __init__(
        self,
        pool: nn.Module,
        norm: Optional[nn.Module],
        activation: Optional[nn.Module] = None,
        norm_before_pool: bool = False,
    ) -> None:
        super().__init__()
        self.pool = pool
        # Keep only the stages that were actually provided, in norm -> activation order.
        stages = [stage for stage in (norm, activation) if stage is not None]
        self.norm_act = nn.Sequential(*stages) if stages else None
        self.norm_before_pool = norm_before_pool

    def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        """
        Pool ``x`` and report the new grid size.

        Args:
            x (Tensor): 4-D token tensor, or a 3-D one which is temporarily expanded at dim 1.
            thw (tuple of ints): Grid ``(T, H, W)`` the patch tokens are arranged on.

        Returns:
            The pooled tokens (same rank as the input) and the pooled ``(T, H, W)``.
        """
        x, original_dim = _unsqueeze(x, 4, 1)

        # The first position along dim 2 is the class token; only the remaining tokens are pooled.
        cls_tok, patches = torch.tensor_split(x, indices=(1,), dim=2)
        patches = patches.transpose(2, 3)
        batch, groups, channels = patches.shape[:3]
        # Fold the first two axes together and unroll the token axis into the (T, H, W) volume.
        patches = patches.reshape((batch * groups, channels) + thw).contiguous()

        # Normalizing before pooling is useful with BN, which can be absorbed to speed up inference.
        if self.norm_before_pool and self.norm_act is not None:
            patches = self.norm_act(patches)

        patches = self.pool(patches)
        T, H, W = patches.shape[2:]

        # Back to a token sequence, then re-attach the untouched class token in front.
        patches = patches.reshape(batch, groups, channels, -1).transpose(2, 3)
        x = torch.cat((cls_tok, patches), dim=2)

        if not self.norm_before_pool and self.norm_act is not None:
            x = self.norm_act(x)

        x = _squeeze(x, 4, 1, original_dim)
        return x, (T, H, W)
- def _interpolate(embedding: torch.Tensor, d: int) -> torch.Tensor:
- if embedding.shape[0] == d:
- return embedding
- return (
- nn.functional.interpolate(
- embedding.permute(1, 0).unsqueeze(0),
- size=d,
- mode="linear",
- )
- .squeeze(0)
- .permute(1, 0)
- )
def _add_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    q_thw: Tuple[int, int, int],
    k_thw: Tuple[int, int, int],
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    rel_pos_t: torch.Tensor,
) -> torch.Tensor:
    """
    Add relative positional terms, computed separately along height, width and time, to ``attn``.

    The update is applied in place to ``attn[:, :, 1:, 1:]``, i.e. the first query and the first key
    position (the class token) receive no positional term.

    Args:
        attn (Tensor): Attention logits; modified in place and returned.
        q (Tensor): 4-D queries unpacked as ``(B, n_head, tokens, dim)``; everything after the first
            token is reshaped to ``(B, n_head, q_t, q_h, q_w, dim)``.
        q_thw (tuple of ints): Grid ``(T, H, W)`` of the query tokens.
        k_thw (tuple of ints): Grid ``(T, H, W)`` of the key tokens.
        rel_pos_h (Tensor): Relative position table for the height axis (rows indexed by distance).
        rel_pos_w (Tensor): Relative position table for the width axis.
        rel_pos_t (Tensor): Relative position table for the temporal axis.

    Returns:
        ``attn`` with the relative positional terms added.
    """
    # Modified code from: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932
    q_t, q_h, q_w = q_thw
    k_t, k_h, k_w = k_thw
    # Number of distinct relative offsets per axis: 2 * max(len_q, len_k) - 1.
    dh = int(2 * max(q_h, k_h) - 1)
    dw = int(2 * max(q_w, k_w) - 1)
    dt = int(2 * max(q_t, k_t) - 1)

    # Scale up rel pos if shapes for q and k are different.
    # Each ``dist_*`` is a (len_q, len_k) grid of non-negative offsets used below to index the tables.
    q_h_ratio = max(k_h / q_h, 1.0)
    k_h_ratio = max(q_h / k_h, 1.0)
    dist_h = torch.arange(q_h)[:, None] * q_h_ratio - (torch.arange(k_h)[None, :] + (1.0 - k_h)) * k_h_ratio
    q_w_ratio = max(k_w / q_w, 1.0)
    k_w_ratio = max(q_w / k_w, 1.0)
    dist_w = torch.arange(q_w)[:, None] * q_w_ratio - (torch.arange(k_w)[None, :] + (1.0 - k_w)) * k_w_ratio
    q_t_ratio = max(k_t / q_t, 1.0)
    k_t_ratio = max(q_t / k_t, 1.0)
    dist_t = torch.arange(q_t)[:, None] * q_t_ratio - (torch.arange(k_t)[None, :] + (1.0 - k_t)) * k_t_ratio

    # Interpolate rel pos if needed.
    rel_pos_h = _interpolate(rel_pos_h, dh)
    rel_pos_w = _interpolate(rel_pos_w, dw)
    rel_pos_t = _interpolate(rel_pos_t, dt)
    # Gather one embedding per (query index, key index) pair; offsets are truncated to integers.
    Rh = rel_pos_h[dist_h.long()]
    Rw = rel_pos_w[dist_w.long()]
    Rt = rel_pos_t[dist_t.long()]

    B, n_head, _, dim = q.shape

    # Drop the class token and lay the queries out on their (T, H, W) grid.
    r_q = q[:, :, 1:].reshape(B, n_head, q_t, q_h, q_w, dim)
    rel_h_q = torch.einsum("bythwc,hkc->bythwk", r_q, Rh)  # [B, H, q_t, qh, qw, k_h]
    rel_w_q = torch.einsum("bythwc,wkc->bythwk", r_q, Rw)  # [B, H, q_t, qh, qw, k_w]
    # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim]
    r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(q_t, B * n_head * q_h * q_w, dim)
    # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t]
    rel_q_t = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1)
    # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t]
    rel_q_t = rel_q_t.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5)

    # Combine rel pos.
    # The three terms are broadcast over a trailing (k_t, k_h, k_w) block and then flattened so that
    # the result lines up with the (query token, key token) layout of ``attn``.
    rel_pos = (
        rel_h_q[:, :, :, :, :, None, :, None]
        + rel_w_q[:, :, :, :, :, None, None, :]
        + rel_q_t[:, :, :, :, :, :, None, None]
    ).reshape(B, n_head, q_t * q_h * q_w, k_t * k_h * k_w)

    # Add it to attention
    attn[:, :, 1:, 1:] += rel_pos

    return attn
- def _add_shortcut(x: torch.Tensor, shortcut: torch.Tensor, residual_with_cls_embed: bool):
- if residual_with_cls_embed:
- x.add_(shortcut)
- else:
- x[:, :, 1:, :] += shortcut[:, :, 1:, :]
- return x
- torch.fx.wrap("_add_rel_pos")
- torch.fx.wrap("_add_shortcut")
class MultiscaleAttention(nn.Module):
    """
    Multi-head attention whose queries, keys and values can each be pooled over the (T, H, W) grid.

    Args:
        input_size (list of ints): Token grid ``[T, H, W]`` seen by this layer; only used to size the
            relative positional tables.
        embed_dim (int): Channel width of the incoming tokens.
        output_dim (int): Channel width produced by the layer (split across ``num_heads`` heads).
        num_heads (int): Number of attention heads.
        kernel_q (list of ints): Kernel of the query pooling convolution.
        kernel_kv (list of ints): Kernel of the key/value pooling convolutions.
        stride_q (list of ints): Stride of the query pooling convolution.
        stride_kv (list of ints): Stride of the key/value pooling convolutions.
        residual_pool (bool): If True, the (pooled) queries are added to the attention output.
        residual_with_cls_embed (bool): Whether that addition also covers the class token.
        rel_pos_embed (bool): If True, learn relative positional tables and add them to the logits.
        dropout (float): Dropout applied after the output projection. Default: 0.0.
        norm_layer (callable): Factory of the normalization used inside the pooling operators.
    """

    def __init__(
        self,
        input_size: List[int],
        embed_dim: int,
        output_dim: int,
        num_heads: int,
        kernel_q: List[int],
        kernel_kv: List[int],
        stride_q: List[int],
        stride_kv: List[int],
        residual_pool: bool,
        residual_with_cls_embed: bool,
        rel_pos_embed: bool,
        dropout: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.output_dim = output_dim
        self.num_heads = num_heads
        self.head_dim = output_dim // num_heads
        self.scaler = 1.0 / math.sqrt(self.head_dim)
        self.residual_pool = residual_pool
        self.residual_with_cls_embed = residual_with_cls_embed

        # A single linear layer emits queries, keys and values at once.
        self.qkv = nn.Linear(embed_dim, 3 * output_dim)
        projection: List[nn.Module] = [nn.Linear(output_dim, output_dim)]
        if dropout > 0.0:
            projection.append(nn.Dropout(dropout, inplace=True))
        self.project = nn.Sequential(*projection)

        # Pooling is only instantiated when the kernel or the stride actually does something.
        pools_q = _prod(kernel_q) > 1 or _prod(stride_q) > 1
        self.pool_q: Optional[nn.Module] = self._make_pool(kernel_q, stride_q, norm_layer) if pools_q else None

        pools_kv = _prod(kernel_kv) > 1 or _prod(stride_kv) > 1
        self.pool_k: Optional[nn.Module] = self._make_pool(kernel_kv, stride_kv, norm_layer) if pools_kv else None
        self.pool_v: Optional[nn.Module] = self._make_pool(kernel_kv, stride_kv, norm_layer) if pools_kv else None

        self.rel_pos_h: Optional[nn.Parameter] = None
        self.rel_pos_w: Optional[nn.Parameter] = None
        self.rel_pos_t: Optional[nn.Parameter] = None
        if rel_pos_embed:
            largest_side = max(input_size[1:])
            q_side = largest_side // stride_q[1] if len(stride_q) > 0 else largest_side
            kv_side = largest_side // stride_kv[1] if len(stride_kv) > 0 else largest_side
            # One table row per possible relative offset along the axis.
            num_spatial = 2 * max(q_side, kv_side) - 1
            num_temporal = 2 * input_size[0] - 1
            self.rel_pos_h = nn.Parameter(torch.zeros(num_spatial, self.head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(num_spatial, self.head_dim))
            self.rel_pos_t = nn.Parameter(torch.zeros(num_temporal, self.head_dim))
            for table in (self.rel_pos_h, self.rel_pos_w, self.rel_pos_t):
                nn.init.trunc_normal_(table, std=0.02)

    def _make_pool(self, kernel: List[int], stride: List[int], norm_layer: Callable[..., nn.Module]) -> Pool:
        """Build a depthwise, bias-free 3-D convolution pooling operator followed by normalization."""
        padding = [int(side // 2) for side in kernel]
        depthwise = nn.Conv3d(
            self.head_dim,
            self.head_dim,
            kernel,  # type: ignore[arg-type]
            stride=stride,  # type: ignore[arg-type]
            padding=padding,  # type: ignore[arg-type]
            groups=self.head_dim,
            bias=False,
        )
        return Pool(depthwise, norm_layer(self.head_dim))

    def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        """
        Args:
            x (Tensor): Tokens of shape ``(B, N, C)``.
            thw (tuple of ints): Grid ``(T, H, W)`` of the patch tokens in ``x``.

        Returns:
            The projected attention output and the grid of the (possibly pooled) queries.
        """
        batch, num_tokens, _ = x.shape
        # (B, N, 3 * heads * head_dim) -> (B, heads, 3, N, head_dim), then split along the "3" axis.
        packed = self.qkv(x).reshape(batch, num_tokens, 3, self.num_heads, self.head_dim).transpose(1, 3)
        q, k, v = packed.unbind(dim=2)

        # Keys and values are pooled from the incoming grid; so are the queries, which define the output grid.
        if self.pool_k is not None:
            k, k_thw = self.pool_k(k, thw)
        else:
            k_thw = thw
        if self.pool_v is not None:
            v = self.pool_v(v, thw)[0]
        if self.pool_q is not None:
            q, thw = self.pool_q(q, thw)

        attn = torch.matmul(self.scaler * q, k.transpose(2, 3))
        if self.rel_pos_h is not None and self.rel_pos_w is not None and self.rel_pos_t is not None:
            attn = _add_rel_pos(
                attn,
                q,
                thw,
                k_thw,
                self.rel_pos_h,
                self.rel_pos_w,
                self.rel_pos_t,
            )
        attn = attn.softmax(dim=-1)

        out = torch.matmul(attn, v)
        if self.residual_pool:
            # In-place residual connection with the (pooled) queries.
            _add_shortcut(out, q, self.residual_with_cls_embed)
        # Merge the heads back into the channel axis before projecting.
        out = out.transpose(1, 2).reshape(batch, -1, self.output_dim)
        out = self.project(out)

        return out, thw
class MultiscaleBlock(nn.Module):
    """
    One encoder block: normalization + multiscale attention, then normalization + MLP, each wrapped
    in a residual connection with stochastic depth.

    Args:
        input_size (list of ints): Token grid ``[T, H, W]`` entering the block.
        cnf (MSBlockConfig): Heads, channel widths and pooling kernels/strides of the block.
        residual_pool (bool): Forwarded to the attention layer.
        residual_with_cls_embed (bool): Forwarded to the attention layer.
        rel_pos_embed (bool): Forwarded to the attention layer.
        proj_after_attn (bool): If True, the channel change happens in the attention branch;
            otherwise it happens in the MLP branch.
        dropout (float): Dropout used by the attention layer and the MLP. Default: 0.0.
        stochastic_depth_prob (float): Drop probability of the residual branches. Default: 0.0.
        norm_layer (callable): Factory of the normalization layers.
    """

    def __init__(
        self,
        input_size: List[int],
        cnf: MSBlockConfig,
        residual_pool: bool,
        residual_with_cls_embed: bool,
        rel_pos_embed: bool,
        proj_after_attn: bool,
        dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.proj_after_attn = proj_after_attn

        # When the queries are strided, the skip connection must be pooled to the same grid.
        self.pool_skip: Optional[nn.Module] = None
        if _prod(cnf.stride_q) > 1:
            skip_kernel = [stride + 1 if stride > 1 else stride for stride in cnf.stride_q]
            skip_padding = [int(side // 2) for side in skip_kernel]
            self.pool_skip = Pool(
                nn.MaxPool3d(skip_kernel, stride=cnf.stride_q, padding=skip_padding), None  # type: ignore[arg-type]
            )

        if proj_after_attn:
            attn_dim = cnf.output_channels
        else:
            attn_dim = cnf.input_channels

        self.norm1 = norm_layer(cnf.input_channels)
        self.norm2 = norm_layer(attn_dim)
        # BatchNorm1d normalizes dim 1, so tokens are transposed around it in ``forward``.
        self.needs_transposal = isinstance(self.norm1, nn.BatchNorm1d)

        self.attn = MultiscaleAttention(
            input_size,
            cnf.input_channels,
            attn_dim,
            cnf.num_heads,
            kernel_q=cnf.kernel_q,
            kernel_kv=cnf.kernel_kv,
            stride_q=cnf.stride_q,
            stride_kv=cnf.stride_kv,
            rel_pos_embed=rel_pos_embed,
            residual_pool=residual_pool,
            residual_with_cls_embed=residual_with_cls_embed,
            dropout=dropout,
            norm_layer=norm_layer,
        )
        self.mlp = MLP(
            attn_dim,
            [4 * attn_dim, cnf.output_channels],
            activation_layer=nn.GELU,
            dropout=dropout,
            inplace=None,
        )

        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

        # Linear shortcut used only when the block changes the channel width.
        self.project: Optional[nn.Module] = None
        if cnf.input_channels != cnf.output_channels:
            self.project = nn.Linear(cnf.input_channels, cnf.output_channels)

    def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        """
        Args:
            x (Tensor): Tokens entering the block.
            thw (tuple of ints): Grid ``(T, H, W)`` of the patch tokens in ``x``.

        Returns:
            The block output and the token grid after the attention layer.
        """
        # --- attention branch ---
        if self.needs_transposal:
            normed = self.norm1(x.transpose(1, 2)).transpose(1, 2)
        else:
            normed = self.norm1(x)
        attended, thw_new = self.attn(normed, thw)

        if self.project is None or not self.proj_after_attn:
            shortcut = x
        else:
            shortcut = self.project(normed)
        if self.pool_skip is None:
            pooled_shortcut = shortcut
        else:
            pooled_shortcut = self.pool_skip(shortcut, thw)[0]
        x = pooled_shortcut + self.stochastic_depth(attended)

        # --- MLP branch ---
        if self.needs_transposal:
            normed_mlp = self.norm2(x.transpose(1, 2)).transpose(1, 2)
        else:
            normed_mlp = self.norm2(x)

        if self.project is None or self.proj_after_attn:
            mlp_shortcut = x
        else:
            mlp_shortcut = self.project(normed_mlp)

        return mlp_shortcut + self.stochastic_depth(self.mlp(normed_mlp)), thw_new
class PositionalEncoding(nn.Module):
    """
    Prepends a learnable class token and, unless relative positional embeddings are used,
    adds learnable absolute positional embeddings (spatial + temporal + class position).

    Args:
        embed_size (int): Channel width of the tokens.
        spatial_size (tuple of ints): Token grid ``(H, W)``.
        temporal_size (int): Number of token frames ``T``.
        rel_pos_embed (bool): If True, no absolute positional parameters are created.
    """

    def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int, rel_pos_embed: bool) -> None:
        super().__init__()
        self.spatial_size = spatial_size
        self.temporal_size = temporal_size

        self.class_token = nn.Parameter(torch.zeros(embed_size))
        self.spatial_pos: Optional[nn.Parameter] = None
        self.temporal_pos: Optional[nn.Parameter] = None
        self.class_pos: Optional[nn.Parameter] = None
        if not rel_pos_embed:
            num_spatial_tokens = self.spatial_size[0] * self.spatial_size[1]
            self.spatial_pos = nn.Parameter(torch.zeros(num_spatial_tokens, embed_size))
            self.temporal_pos = nn.Parameter(torch.zeros(self.temporal_size, embed_size))
            self.class_pos = nn.Parameter(torch.zeros(embed_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with the class token in front and, if enabled, the positional embeddings added."""
        cls_tok = self.class_token.expand(x.size(0), -1).unsqueeze(1)
        x = torch.cat((cls_tok, x), dim=1)

        if self.spatial_pos is not None and self.temporal_pos is not None and self.class_pos is not None:
            num_spatial_tokens, width = self.spatial_pos.shape
            # Frame-major layout: each temporal row is repeated once per spatial location ...
            per_patch = torch.repeat_interleave(self.temporal_pos, num_spatial_tokens, dim=0)
            # ... and the whole spatial table is tiled once per frame.
            tiled_spatial = self.spatial_pos.unsqueeze(0).expand(self.temporal_size, -1, -1)
            per_patch.add_(tiled_spatial.reshape(-1, width))
            with_cls = torch.cat((self.class_pos.unsqueeze(0), per_patch), dim=0)
            x.add_(with_cls.unsqueeze(0))

        return x
class MViT(nn.Module):
    """
    Multiscale Vision Transformer: a 3-D convolutional patch embedding, a class token with optional
    absolute positional encoding, a stack of multiscale blocks and a linear classification head.
    """

    def __init__(
        self,
        spatial_size: Tuple[int, int],
        temporal_size: int,
        block_setting: Sequence[MSBlockConfig],
        residual_pool: bool,
        residual_with_cls_embed: bool,
        rel_pos_embed: bool,
        proj_after_attn: bool,
        dropout: float = 0.5,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        num_classes: int = 400,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        patch_embed_kernel: Tuple[int, int, int] = (3, 7, 7),
        patch_embed_stride: Tuple[int, int, int] = (2, 4, 4),
        patch_embed_padding: Tuple[int, int, int] = (1, 3, 3),
    ) -> None:
        """
        MViT main class.

        Args:
            spatial_size (tuple of ints): The spatial size of the input as ``(H, W)``.
            temporal_size (int): The temporal size ``T`` of the input.
            block_setting (sequence of MSBlockConfig): The Network structure.
            residual_pool (bool): If True, use MViTv2 pooling residual connection.
            residual_with_cls_embed (bool): If True, the addition on the residual connection will include
                the class embedding.
            rel_pos_embed (bool): If True, use MViTv2's relative positional embeddings.
            proj_after_attn (bool): If True, apply the projection after the attention.
            dropout (float): Dropout rate of the classification head. Default: 0.5.
            attention_dropout (float): Attention dropout rate. Default: 0.0.
            stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
            num_classes (int): The number of classes. Default: 400.
            block (callable, optional): Module specifying the layer which consists of the attention and mlp.
                Defaults to ``MultiscaleBlock``.
            norm_layer (callable, optional): Module specifying the normalization layer to use.
                Defaults to ``nn.LayerNorm`` with ``eps=1e-6``.
            patch_embed_kernel (tuple of ints): The kernel of the convolution that patchifies the input.
            patch_embed_stride (tuple of ints): The stride of the convolution that patchifies the input.
            patch_embed_padding (tuple of ints): The padding of the convolution that patchifies the input.

        Raises:
            ValueError: If ``block_setting`` is empty.
        """
        super().__init__()
        # This implementation employs a different parameterization scheme than the one used at PyTorch Video:
        # https://github.com/facebookresearch/pytorchvideo/blob/718d0a4/pytorchvideo/models/vision_transformers.py
        # We remove any experimental configuration that didn't make it to the final variants of the models. To represent
        # the configuration of the architecture we use the simplified form suggested at Table 1 of the paper.
        _log_api_usage_once(self)
        total_stage_blocks = len(block_setting)
        if total_stage_blocks == 0:
            raise ValueError("The configuration parameter can't be empty.")

        if block is None:
            block = MultiscaleBlock

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)

        # Patch Embedding module
        self.conv_proj = nn.Conv3d(
            in_channels=3,
            out_channels=block_setting[0].input_channels,
            kernel_size=patch_embed_kernel,
            stride=patch_embed_stride,
            padding=patch_embed_padding,
        )

        # Token grid (T, H, W) after patchification. ``tuple(...)`` lets callers pass the spatial size
        # as a list as well as a tuple.
        input_size = [
            size // stride for size, stride in zip((temporal_size,) + tuple(spatial_size), self.conv_proj.stride)
        ]

        # Spatio-Temporal Class Positional Encoding
        self.pos_encoding = PositionalEncoding(
            embed_size=block_setting[0].input_channels,
            spatial_size=(input_size[1], input_size[2]),
            temporal_size=input_size[0],
            rel_pos_embed=rel_pos_embed,
        )

        # Encoder module
        self.blocks = nn.ModuleList()
        # The drop probability grows linearly from 0 at the first block to ``stochastic_depth_prob`` at
        # the last one. Clamping the denominator keeps a single-block configuration from dividing by zero.
        sd_denominator = max(total_stage_blocks - 1.0, 1.0)
        for stage_block_id, cnf in enumerate(block_setting):
            # adjust stochastic depth probability based on the depth of the stage block
            sd_prob = stochastic_depth_prob * stage_block_id / sd_denominator

            self.blocks.append(
                block(
                    input_size=input_size,
                    cnf=cnf,
                    residual_pool=residual_pool,
                    residual_with_cls_embed=residual_with_cls_embed,
                    rel_pos_embed=rel_pos_embed,
                    proj_after_attn=proj_after_attn,
                    dropout=attention_dropout,
                    stochastic_depth_prob=sd_prob,
                    norm_layer=norm_layer,
                )
            )

            # Query striding shrinks the grid seen by the following blocks.
            if len(cnf.stride_q) > 0:
                input_size = [size // stride for size, stride in zip(input_size, cnf.stride_q)]
        self.norm = norm_layer(block_setting[-1].output_channels)

        # Classifier module
        self.head = nn.Sequential(
            nn.Dropout(dropout, inplace=True),
            nn.Linear(block_setting[-1].output_channels, num_classes),
        )

        # Weight initialization: truncated normal for linear weights and positional parameters,
        # constants for biases and LayerNorm affine parameters.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.LayerNorm):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1.0)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, PositionalEncoding):
                for weights in m.parameters():
                    nn.init.trunc_normal_(weights, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): Video batch ``(B, C, T, H, W)``; a 4-D ``(B, C, H, W)`` input is treated as a
                single-frame video.

        Returns:
            The class scores produced by the head from the class token.
        """
        # Convert if necessary (B, C, H, W) -> (B, C, 1, H, W)
        x = _unsqueeze(x, 5, 2)[0]
        # patchify and reshape: (B, C, T, H, W) -> (B, embed_channels[0], T', H', W') -> (B, THW', embed_channels[0])
        x = self.conv_proj(x)
        x = x.flatten(2).transpose(1, 2)

        # add positional encoding
        x = self.pos_encoding(x)

        # pass patches through the encoder
        thw = (self.pos_encoding.temporal_size,) + self.pos_encoding.spatial_size
        for block in self.blocks:
            x, thw = block(x, thw)
        x = self.norm(x)

        # classifier "token" as used by standard language architectures
        x = x[:, 0]
        x = self.head(x)

        return x
def _mvit(
    block_setting: List[MSBlockConfig],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> MViT:
    """
    Build an ``MViT`` from ``block_setting`` and optionally load pre-trained weights into it.

    When ``weights`` is given, the number of classes and the spatial/temporal input sizes are taken
    from its metadata via ``_ovewrite_named_param``. Any remaining ``kwargs`` are forwarded to ``MViT``.
    """
    if weights is not None:
        meta = weights.meta
        _ovewrite_named_param(kwargs, "num_classes", len(meta["categories"]))
        # Only square spatial inputs are expected from the weights metadata.
        assert meta["min_size"][0] == meta["min_size"][1]
        _ovewrite_named_param(kwargs, "spatial_size", meta["min_size"])
        _ovewrite_named_param(kwargs, "temporal_size", meta["min_temporal_size"])

    # Pull out every explicitly handled option so that it is not passed twice through ``**kwargs``.
    spatial_size = kwargs.pop("spatial_size", (224, 224))
    temporal_size = kwargs.pop("temporal_size", 16)
    residual_pool = kwargs.pop("residual_pool", False)
    residual_with_cls_embed = kwargs.pop("residual_with_cls_embed", True)
    rel_pos_embed = kwargs.pop("rel_pos_embed", False)
    proj_after_attn = kwargs.pop("proj_after_attn", False)

    model = MViT(
        spatial_size=spatial_size,
        temporal_size=temporal_size,
        block_setting=block_setting,
        residual_pool=residual_pool,
        residual_with_cls_embed=residual_with_cls_embed,
        rel_pos_embed=rel_pos_embed,
        proj_after_attn=proj_after_attn,
        stochastic_depth_prob=stochastic_depth_prob,
        **kwargs,
    )

    if weights is None:
        return model

    state_dict = weights.get_state_dict(progress=progress, check_hash=True)
    model.load_state_dict(state_dict)
    return model
class MViT_V1_B_Weights(WeightsEnum):
    """Pre-trained weights available for ``mvit_v1_b``."""

    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/mvit_v1_b-dbeb1030.pth",
        # Preprocessing preset bound to these weights (resize to 256, crop to 224x224, normalize).
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.45, 0.45, 0.45),
            std=(0.225, 0.225, 0.225),
        ),
        meta={
            # ``min_size`` / ``min_temporal_size`` / ``categories`` are read by ``_mvit`` to set the
            # model's spatial size, temporal size and number of classes.
            "min_size": (224, 224),
            "min_temporal_size": 16,
            "categories": _KINETICS400_CATEGORIES,
            "recipe": "https://github.com/facebookresearch/pytorchvideo/blob/main/docs/source/model_zoo.md",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`"
            ),
            "num_params": 36610672,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 78.477,
                    "acc@5": 93.582,
                }
            },
            # NOTE(review): units are not stated here (presumably GFLOPs and MB) - confirm.
            "_ops": 70.599,
            "_file_size": 139.764,
        },
    )
    # Alias used when callers ask for the default weights.
    DEFAULT = KINETICS400_V1
class MViT_V2_S_Weights(WeightsEnum):
    """Pre-trained weights available for ``mvit_v2_s``."""

    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/mvit_v2_s-ae3be167.pth",
        # Preprocessing preset bound to these weights (resize to 256, crop to 224x224, normalize).
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.45, 0.45, 0.45),
            std=(0.225, 0.225, 0.225),
        ),
        meta={
            # ``min_size`` / ``min_temporal_size`` / ``categories`` are read by ``_mvit`` to set the
            # model's spatial size, temporal size and number of classes.
            "min_size": (224, 224),
            "min_temporal_size": 16,
            "categories": _KINETICS400_CATEGORIES,
            "recipe": "https://github.com/facebookresearch/SlowFast/blob/main/MODEL_ZOO.md",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`"
            ),
            "num_params": 34537744,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 80.757,
                    "acc@5": 94.665,
                }
            },
            # NOTE(review): units are not stated here (presumably GFLOPs and MB) - confirm.
            "_ops": 64.224,
            "_file_size": 131.884,
        },
    )
    # Alias used when callers ask for the default weights.
    DEFAULT = KINETICS400_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", MViT_V1_B_Weights.KINETICS400_V1))
def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
    """
    Constructs a base MViTV1 architecture from
    `Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__.

    .. betastatus:: video module

    Args:
        weights (:class:`~torchvision.models.video.MViT_V1_B_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.MViT_V1_B_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.MViT``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.MViT_V1_B_Weights
        :members:
    """
    weights = MViT_V1_B_Weights.verify(weights)

    # Per-block settings of the 16 encoder blocks; every list below has one entry per block.
    num_heads = [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8]
    input_channels = [96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768]
    output_channels = [192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768, 768]
    # Queries are pooled only in blocks 1, 3 and 14; an empty list means "no query pooling".
    kernel_q: List[List[int]] = [[], [3, 3, 3], [], [3, 3, 3], [], [], [], [], [], [], [], [], [], [], [3, 3, 3], []]
    stride_q: List[List[int]] = [[], [1, 2, 2], [], [1, 2, 2], [], [], [], [], [], [], [], [], [], [], [1, 2, 2], []]
    # Keys/values use a 3x3x3 kernel everywhere, with a stride that shrinks as the grid gets smaller.
    kernel_kv = [[3, 3, 3] for _ in range(16)]
    stride_kv = (
        [[1, 8, 8]]
        + [[1, 4, 4] for _ in range(2)]
        + [[1, 2, 2] for _ in range(11)]
        + [[1, 1, 1] for _ in range(2)]
    )

    block_setting = [
        MSBlockConfig(
            num_heads=heads,
            input_channels=channels_in,
            output_channels=channels_out,
            kernel_q=k_q,
            kernel_kv=k_kv,
            stride_q=s_q,
            stride_kv=s_kv,
        )
        for heads, channels_in, channels_out, k_q, k_kv, s_q, s_kv in zip(
            num_heads, input_channels, output_channels, kernel_q, kernel_kv, stride_q, stride_kv
        )
    ]

    # Taken out of ``kwargs`` first so that it is not forwarded a second time below.
    sd_prob = kwargs.pop("stochastic_depth_prob", 0.2)

    return _mvit(
        spatial_size=(224, 224),
        temporal_size=16,
        block_setting=block_setting,
        residual_pool=False,
        residual_with_cls_embed=False,
        stochastic_depth_prob=sd_prob,
        weights=weights,
        progress=progress,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", MViT_V2_S_Weights.KINETICS400_V1))
def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
    """Constructs a small MViTV2 architecture from
    `Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__ and
    `MViTv2: Improved Multiscale Vision Transformers for Classification
    and Detection <https://arxiv.org/abs/2112.01526>`__.

    .. betastatus:: video module

    Args:
        weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.MViT_V2_S_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.MViT``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.MViT_V2_S_Weights
        :members:
    """
    weights = MViT_V2_S_Weights.verify(weights)

    # Per-block settings of the 16 encoder blocks; every list below has one entry per block.
    num_heads = [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8]
    input_channels = [96, 96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768]
    output_channels = [96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768]
    # Every block pools queries and keys/values with a 3x3x3 kernel.
    kernel_q = [[3, 3, 3] for _ in range(16)]
    kernel_kv = [[3, 3, 3] for _ in range(16)]
    # Queries are spatially strided only in blocks 1, 3 and 14.
    stride_q = [[1, 2, 2] if block_id in (1, 3, 14) else [1, 1, 1] for block_id in range(16)]
    # The key/value stride shrinks as the query grid gets smaller.
    stride_kv = (
        [[1, 8, 8]]
        + [[1, 4, 4] for _ in range(2)]
        + [[1, 2, 2] for _ in range(11)]
        + [[1, 1, 1] for _ in range(2)]
    )

    block_setting = [
        MSBlockConfig(
            num_heads=heads,
            input_channels=channels_in,
            output_channels=channels_out,
            kernel_q=k_q,
            kernel_kv=k_kv,
            stride_q=s_q,
            stride_kv=s_kv,
        )
        for heads, channels_in, channels_out, k_q, k_kv, s_q, s_kv in zip(
            num_heads, input_channels, output_channels, kernel_q, kernel_kv, stride_q, stride_kv
        )
    ]

    # Taken out of ``kwargs`` first so that it is not forwarded a second time below.
    sd_prob = kwargs.pop("stochastic_depth_prob", 0.2)

    return _mvit(
        spatial_size=(224, 224),
        temporal_size=16,
        block_setting=block_setting,
        residual_pool=True,
        residual_with_cls_embed=False,
        rel_pos_embed=True,
        proj_after_attn=True,
        stochastic_depth_prob=sd_prob,
        weights=weights,
        progress=progress,
        **kwargs,
    )
|