import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor

from torchvision.models._api import register_model, Weights, WeightsEnum
from torchvision.models._meta import _IMAGENET_CATEGORIES
from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface
from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation
from torchvision.ops.stochastic_depth import StochasticDepth
from torchvision.transforms._presets import ImageClassification, InterpolationMode
from torchvision.utils import _log_api_usage_once


__all__ = [
    "MaxVit",
    "MaxVit_T_Weights",
    "maxvit_t",
]


def _get_conv_output_shape(input_size: Tuple[int, int], kernel_size: int, stride: int, padding: int) -> Tuple[int, int]:
    return (
        (input_size[0] - kernel_size + 2 * padding) // stride + 1,
        (input_size[1] - kernel_size + 2 * padding) // stride + 1,
    )


def _make_block_input_shapes(input_size: Tuple[int, int], n_blocks: int) -> List[Tuple[int, int]]:
    """Util function to check that the input size is correct for a MaxVit configuration."""
    shapes = []
    block_input_shape = _get_conv_output_shape(input_size, 3, 2, 1)
    for _ in range(n_blocks):
        block_input_shape = _get_conv_output_shape(block_input_shape, 3, 2, 1)
        shapes.append(block_input_shape)
    return shapes


def _get_relative_position_index(height: int, width: int) -> torch.Tensor:
    # Build a (height * width, height * width) table of indices into the
    # (2 * height - 1) * (2 * width - 1) relative position bias table,
    # following the Swin Transformer scheme.
    coords = torch.stack(torch.meshgrid([torch.arange(height), torch.arange(width)]))
    coords_flat = torch.flatten(coords, 1)
    relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :]
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
    # shift coordinates to start from 0 and encode each (dy, dx) pair as a single index
    relative_coords[:, :, 0] += height - 1
    relative_coords[:, :, 1] += width - 1
    relative_coords[:, :, 0] *= 2 * width - 1
    return relative_coords.sum(-1)
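

# A small sanity check of the helper above (illustrative only, not part of the original
# module): for a 7x7 window the index table is [49, 49] and its values index into a
# bias table with (2 * 7 - 1) ** 2 == 169 entries.
#
#     idx = _get_relative_position_index(7, 7)
#     assert idx.shape == (49, 49)
#     assert int(idx.min()) == 0 and int(idx.max()) == 168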


class MBConv(nn.Module):
    """MBConv: Mobile Inverted Residual Bottleneck.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        expansion_ratio (float): Expansion ratio in the bottleneck.
        squeeze_ratio (float): Squeeze ratio in the SE Layer.
        stride (int): Stride of the depthwise convolution.
        activation_layer (Callable[..., nn.Module]): Activation function.
        norm_layer (Callable[..., nn.Module]): Normalization function.
        p_stochastic_dropout (float): Probability of stochastic depth.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        expansion_ratio: float,
        squeeze_ratio: float,
        stride: int,
        activation_layer: Callable[..., nn.Module],
        norm_layer: Callable[..., nn.Module],
        p_stochastic_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        proj: Sequence[nn.Module]
        self.proj: nn.Module

        should_proj = stride != 1 or in_channels != out_channels
        if should_proj:
            proj = [nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=True)]
            if stride == 2:
                proj = [nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)] + proj  # type: ignore
            self.proj = nn.Sequential(*proj)
        else:
            self.proj = nn.Identity()  # type: ignore

        mid_channels = int(out_channels * expansion_ratio)
        sqz_channels = int(out_channels * squeeze_ratio)

        if p_stochastic_dropout:
            self.stochastic_depth = StochasticDepth(p_stochastic_dropout, mode="row")  # type: ignore
        else:
            self.stochastic_depth = nn.Identity()  # type: ignore

        _layers = OrderedDict()
        _layers["pre_norm"] = norm_layer(in_channels)
        _layers["conv_a"] = Conv2dNormActivation(
            in_channels,
            mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            activation_layer=activation_layer,
            norm_layer=norm_layer,
            inplace=None,
        )
        _layers["conv_b"] = Conv2dNormActivation(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            activation_layer=activation_layer,
            norm_layer=norm_layer,
            groups=mid_channels,
            inplace=None,
        )
        _layers["squeeze_excitation"] = SqueezeExcitation(mid_channels, sqz_channels, activation=nn.SiLU)
        _layers["conv_c"] = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, bias=True)

        self.layers = nn.Sequential(_layers)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, C, H, W].
        Returns:
            Tensor: Output tensor with expected layout of [B, C, H / stride, W / stride].
        """
        res = self.proj(x)
        x = self.stochastic_depth(self.layers(x))
        return res + x
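

# Illustrative sketch of the MBConv contract described in the docstring above (the
# concrete values are assumptions chosen for this example, not code from the original
# module): a stride-2 block halves the spatial resolution and changes the channel count.
#
#     block = MBConv(
#         in_channels=64,
#         out_channels=128,
#         expansion_ratio=4.0,
#         squeeze_ratio=0.25,
#         stride=2,
#         activation_layer=nn.GELU,
#         norm_layer=partial(nn.BatchNorm2d, eps=1e-3, momentum=0.99),
#     )
#     y = block(torch.randn(1, 64, 56, 56))   # -> [1, 128, 28, 28]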


class RelativePositionalMultiHeadAttention(nn.Module):
    """Relative Positional Multi-Head Attention.

    Args:
        feat_dim (int): Number of input features.
        head_dim (int): Number of features per head.
        max_seq_len (int): Maximum sequence length.
    """

    def __init__(
        self,
        feat_dim: int,
        head_dim: int,
        max_seq_len: int,
    ) -> None:
        super().__init__()

        if feat_dim % head_dim != 0:
            raise ValueError(f"feat_dim: {feat_dim} must be divisible by head_dim: {head_dim}")

        self.n_heads = feat_dim // head_dim
        self.head_dim = head_dim
        self.size = int(math.sqrt(max_seq_len))
        self.max_seq_len = max_seq_len

        self.to_qkv = nn.Linear(feat_dim, self.n_heads * self.head_dim * 3)
        self.scale_factor = feat_dim**-0.5

        self.merge = nn.Linear(self.head_dim * self.n_heads, feat_dim)
        self.relative_position_bias_table = nn.parameter.Parameter(
            torch.empty(((2 * self.size - 1) * (2 * self.size - 1), self.n_heads), dtype=torch.float32),
        )

        self.register_buffer("relative_position_index", _get_relative_position_index(self.size, self.size))
        # initialize the bias table with a truncated normal distribution
        torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def get_relative_positional_bias(self) -> torch.Tensor:
        bias_index = self.relative_position_index.view(-1)  # type: ignore
        relative_bias = self.relative_position_bias_table[bias_index].view(self.max_seq_len, self.max_seq_len, -1)  # type: ignore
        relative_bias = relative_bias.permute(2, 0, 1).contiguous()
        return relative_bias.unsqueeze(0)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, G, P, D].
        Returns:
            Tensor: Output tensor with expected layout of [B, G, P, D].
        """
        B, G, P, D = x.shape
        H, DH = self.n_heads, self.head_dim

        qkv = self.to_qkv(x)
        q, k, v = torch.chunk(qkv, 3, dim=-1)

        q = q.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
        k = k.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
        v = v.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)

        k = k * self.scale_factor
        dot_prod = torch.einsum("B G H I D, B G H J D -> B G H I J", q, k)
        pos_bias = self.get_relative_positional_bias()

        dot_prod = F.softmax(dot_prod + pos_bias, dim=-1)

        out = torch.einsum("B G H I J, B G H J D -> B G H I D", dot_prod, v)
        out = out.permute(0, 1, 3, 2, 4).reshape(B, G, P, D)

        out = self.merge(out)
        return out
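

# Shape sketch for the attention module above (illustrative values chosen for this
# example, not part of the original module): with a 7x7 partition the sequence length
# is 49, and feat_dim=64 with head_dim=32 yields 2 heads.
#
#     attn = RelativePositionalMultiHeadAttention(feat_dim=64, head_dim=32, max_seq_len=49)
#     y = attn(torch.randn(2, 16, 49, 64))   # [B, G, P, D] in, same layout out: [2, 16, 49, 64]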


class SwapAxes(nn.Module):
    """Permute the axes of a tensor."""

    def __init__(self, a: int, b: int) -> None:
        super().__init__()
        self.a = a
        self.b = b

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = torch.swapaxes(x, self.a, self.b)
        return res


class WindowPartition(nn.Module):
    """
    Partition the input tensor into non-overlapping windows.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, p: int) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, C, H, W].
            p (int): Size of the partitions (side length of a window).
        Returns:
            Tensor: Output tensor with expected layout of [B, (H/P * W/P), P*P, C].
        """
        B, C, H, W = x.shape
        P = p
        # chunk up H and W dimensions
        x = x.reshape(B, C, H // P, P, W // P, P)
        x = x.permute(0, 2, 4, 3, 5, 1)
        # collapse the P * P dimension
        x = x.reshape(B, (H // P) * (W // P), P * P, C)
        return x


class WindowDepartition(nn.Module):
    """
    Departition the input tensor of non-overlapping windows into a feature volume of layout [B, C, H, W].
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, p: int, h_partitions: int, w_partitions: int) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, (H/P * W/P), P*P, C].
            p (int): Size of the partitions (side length of a window).
            h_partitions (int): Number of vertical partitions.
            w_partitions (int): Number of horizontal partitions.
        Returns:
            Tensor: Output tensor with expected layout of [B, C, H, W].
        """
        B, G, PP, C = x.shape
        P = p
        HP, WP = h_partitions, w_partitions
        # split the P * P dimension into two P tile dimensions
        x = x.reshape(B, HP, WP, P, P, C)
        # permute into B, C, HP, P, WP, P
        x = x.permute(0, 5, 1, 3, 2, 4)
        # reshape into B, C, H, W
        x = x.reshape(B, C, HP * P, WP * P)
        return x
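

# Minimal sketch of the partition/departition round trip (an illustrative example,
# assuming a 4x4 feature map and partition size 2; not part of the original module):
#
#     x = torch.randn(1, 8, 4, 4)                        # [B, C, H, W]
#     windows = WindowPartition()(x, p=2)                # [B, (H/2)*(W/2), 2*2, C] == [1, 4, 4, 8]
#     restored = WindowDepartition()(windows, 2, 2, 2)   # back to [1, 8, 4, 4]
#     assert torch.equal(x, restored)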


class PartitionAttentionLayer(nn.Module):
    """
    Layer for partitioning the input tensor into non-overlapping windows and applying attention to each window.

    Args:
        in_channels (int): Number of input channels.
        head_dim (int): Dimension of each attention head.
        partition_size (int): Size of the partitions.
        partition_type (str): Type of partitioning to use. Can be either "grid" or "window".
        grid_size (Tuple[int, int]): Size of the grid to partition the input tensor into.
        mlp_ratio (int): Ratio of the feature size expansion in the MLP layer.
        activation_layer (Callable[..., nn.Module]): Activation function to use.
        norm_layer (Callable[..., nn.Module]): Normalization function to use.
        attention_dropout (float): Dropout probability for the attention layer.
        mlp_dropout (float): Dropout probability for the MLP layer.
        p_stochastic_dropout (float): Probability of stochastic depth.
    """

    def __init__(
        self,
        in_channels: int,
        head_dim: int,
        # partitioning parameters
        partition_size: int,
        partition_type: str,
        # grid size needs to be known at initialization time
        # because we need to know how many relative offsets there are in the grid
        grid_size: Tuple[int, int],
        mlp_ratio: int,
        activation_layer: Callable[..., nn.Module],
        norm_layer: Callable[..., nn.Module],
        attention_dropout: float,
        mlp_dropout: float,
        p_stochastic_dropout: float,
    ) -> None:
        super().__init__()

        self.n_heads = in_channels // head_dim
        self.head_dim = head_dim
        self.n_partitions = grid_size[0] // partition_size
        self.partition_type = partition_type
        self.grid_size = grid_size

        if partition_type not in ["grid", "window"]:
            raise ValueError("partition_type must be either 'grid' or 'window'")

        if partition_type == "window":
            self.p, self.g = partition_size, self.n_partitions
        else:
            self.p, self.g = self.n_partitions, partition_size

        self.partition_op = WindowPartition()
        self.departition_op = WindowDepartition()
        self.partition_swap = SwapAxes(-2, -3) if partition_type == "grid" else nn.Identity()
        self.departition_swap = SwapAxes(-2, -3) if partition_type == "grid" else nn.Identity()

        self.attn_layer = nn.Sequential(
            norm_layer(in_channels),
            # it's always going to be partition_size ** 2 because
            # of the axis swap in the case of grid partitioning
            RelativePositionalMultiHeadAttention(in_channels, head_dim, partition_size**2),
            nn.Dropout(attention_dropout),
        )

        # pre-normalization similar to transformer layers
        self.mlp_layer = nn.Sequential(
            nn.LayerNorm(in_channels),
            nn.Linear(in_channels, in_channels * mlp_ratio),
            activation_layer(),
            nn.Linear(in_channels * mlp_ratio, in_channels),
            nn.Dropout(mlp_dropout),
        )

        # stochastic depth applied to the residual branches
        self.stochastic_dropout = StochasticDepth(p_stochastic_dropout, mode="row")

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, C, H, W].
        Returns:
            Tensor: Output tensor with expected layout of [B, C, H, W].
        """

        # Undefined behavior if H or W are not divisible by p
        # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766
        gh, gw = self.grid_size[0] // self.p, self.grid_size[1] // self.p
        torch._assert(
            self.grid_size[0] % self.p == 0 and self.grid_size[1] % self.p == 0,
            "Grid size must be divisible by partition size. Got grid size of {} and partition size of {}".format(
                self.grid_size, self.p
            ),
        )

        x = self.partition_op(x, self.p)
        x = self.partition_swap(x)
        x = x + self.stochastic_dropout(self.attn_layer(x))
        x = x + self.stochastic_dropout(self.mlp_layer(x))
        x = self.departition_swap(x)
        x = self.departition_op(x, self.p, gh, gw)

        return x
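

# Illustrative sketch of the two partition modes (the numbers are assumptions chosen for
# this example, not values from the original module). For a 28x28 grid and partition
# size 7, "window" attends within 7x7 local windows, while "grid" swaps the window and
# token axes so attention runs across a sparse 7x7 global grid:
#
#     layer = PartitionAttentionLayer(
#         in_channels=64, head_dim=32, partition_size=7, partition_type="grid",
#         grid_size=(28, 28), mlp_ratio=4, activation_layer=nn.GELU,
#         norm_layer=nn.LayerNorm, attention_dropout=0.0, mlp_dropout=0.0,
#         p_stochastic_dropout=0.0,
#     )
#     y = layer(torch.randn(1, 64, 28, 28))   # -> [1, 64, 28, 28]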


class MaxVitLayer(nn.Module):
    """
    MaxVit layer consisting of an MBConv layer followed by a PartitionAttentionLayer with `window`
    and a PartitionAttentionLayer with `grid`.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        expansion_ratio (float): Expansion ratio in the bottleneck.
        squeeze_ratio (float): Squeeze ratio in the SE Layer.
        stride (int): Stride of the depthwise convolution.
        activation_layer (Callable[..., nn.Module]): Activation function.
        norm_layer (Callable[..., nn.Module]): Normalization function.
        head_dim (int): Dimension of the attention heads.
        mlp_ratio (int): Ratio of the MLP layer.
        mlp_dropout (float): Dropout probability for the MLP layer.
        attention_dropout (float): Dropout probability for the attention layer.
        p_stochastic_dropout (float): Probability of stochastic depth.
        partition_size (int): Size of the partitions.
        grid_size (Tuple[int, int]): Size of the input feature grid.
    """

    def __init__(
        self,
        # conv parameters
        in_channels: int,
        out_channels: int,
        squeeze_ratio: float,
        expansion_ratio: float,
        stride: int,
        # conv + transformer parameters
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        # transformer parameters
        head_dim: int,
        mlp_ratio: int,
        mlp_dropout: float,
        attention_dropout: float,
        p_stochastic_dropout: float,
        # partitioning parameters
        partition_size: int,
        grid_size: Tuple[int, int],
    ) -> None:
        super().__init__()

        layers: OrderedDict = OrderedDict()

        # convolutional layer
        layers["MBconv"] = MBConv(
            in_channels=in_channels,
            out_channels=out_channels,
            expansion_ratio=expansion_ratio,
            squeeze_ratio=squeeze_ratio,
            stride=stride,
            activation_layer=activation_layer,
            norm_layer=norm_layer,
            p_stochastic_dropout=p_stochastic_dropout,
        )
        # attention layers, block -> grid
        layers["window_attention"] = PartitionAttentionLayer(
            in_channels=out_channels,
            head_dim=head_dim,
            partition_size=partition_size,
            partition_type="window",
            grid_size=grid_size,
            mlp_ratio=mlp_ratio,
            activation_layer=activation_layer,
            norm_layer=nn.LayerNorm,
            attention_dropout=attention_dropout,
            mlp_dropout=mlp_dropout,
            p_stochastic_dropout=p_stochastic_dropout,
        )
        layers["grid_attention"] = PartitionAttentionLayer(
            in_channels=out_channels,
            head_dim=head_dim,
            partition_size=partition_size,
            partition_type="grid",
            grid_size=grid_size,
            mlp_ratio=mlp_ratio,
            activation_layer=activation_layer,
            norm_layer=nn.LayerNorm,
            attention_dropout=attention_dropout,
            mlp_dropout=mlp_dropout,
            p_stochastic_dropout=p_stochastic_dropout,
        )
        self.layers = nn.Sequential(layers)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, C, H, W).
        Returns:
            Tensor: Output tensor of shape (B, C, H, W).
        """
        x = self.layers(x)
        return x


class MaxVitBlock(nn.Module):
    """
    A MaxVit block consisting of `n_layers` MaxVit layers.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        expansion_ratio (float): Expansion ratio in the bottleneck.
        squeeze_ratio (float): Squeeze ratio in the SE Layer.
        activation_layer (Callable[..., nn.Module]): Activation function.
        norm_layer (Callable[..., nn.Module]): Normalization function.
        head_dim (int): Dimension of the attention heads.
        mlp_ratio (int): Ratio of the MLP layer.
        mlp_dropout (float): Dropout probability for the MLP layer.
        attention_dropout (float): Dropout probability for the attention layer.
        partition_size (int): Size of the partitions.
        input_grid_size (Tuple[int, int]): Size of the input feature grid.
        n_layers (int): Number of layers in the block.
        p_stochastic (List[float]): List of probabilities for stochastic depth for each layer.
    """

    def __init__(
        self,
        # conv parameters
        in_channels: int,
        out_channels: int,
        squeeze_ratio: float,
        expansion_ratio: float,
        # conv + transformer parameters
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        # transformer parameters
        head_dim: int,
        mlp_ratio: int,
        mlp_dropout: float,
        attention_dropout: float,
        # partitioning parameters
        partition_size: int,
        input_grid_size: Tuple[int, int],
        # number of layers
        n_layers: int,
        p_stochastic: List[float],
    ) -> None:
        super().__init__()
        if not len(p_stochastic) == n_layers:
            raise ValueError(f"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}.")

        self.layers = nn.ModuleList()
        # account for the first stride of the first layer
        self.grid_size = _get_conv_output_shape(input_grid_size, kernel_size=3, stride=2, padding=1)

        for idx, p in enumerate(p_stochastic):
            stride = 2 if idx == 0 else 1
            self.layers += [
                MaxVitLayer(
                    in_channels=in_channels if idx == 0 else out_channels,
                    out_channels=out_channels,
                    squeeze_ratio=squeeze_ratio,
                    expansion_ratio=expansion_ratio,
                    stride=stride,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                    head_dim=head_dim,
                    mlp_ratio=mlp_ratio,
                    mlp_dropout=mlp_dropout,
                    attention_dropout=attention_dropout,
                    partition_size=partition_size,
                    grid_size=self.grid_size,
                    p_stochastic_dropout=p,
                ),
            ]

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, C, H, W).
        Returns:
            Tensor: Output tensor of shape (B, C, H, W).
        """
        for layer in self.layers:
            x = layer(x)
        return x


class MaxVit(nn.Module):
    """
    Implements MaxVit Transformer from the `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`_ paper.

    Args:
        input_size (Tuple[int, int]): Size of the input image.
        stem_channels (int): Number of channels in the stem.
        partition_size (int): Size of the partitions.
        block_channels (List[int]): Number of channels in each block.
        block_layers (List[int]): Number of layers in each block.
        stochastic_depth_prob (float): Probability of stochastic depth. Expands to a list of probabilities for each layer that scales linearly to the specified value.
        squeeze_ratio (float): Squeeze ratio in the SE Layer. Default: 0.25.
        expansion_ratio (float): Expansion ratio in the MBConv bottleneck. Default: 4.
        norm_layer (Callable[..., nn.Module]): Normalization function. Default: None (setting to None will produce a `BatchNorm2d(eps=1e-3, momentum=0.99)`).
        activation_layer (Callable[..., nn.Module]): Activation function. Default: nn.GELU.
        head_dim (int): Dimension of the attention heads.
        mlp_ratio (int): Expansion ratio of the MLP layer. Default: 4.
        mlp_dropout (float): Dropout probability for the MLP layer. Default: 0.0.
        attention_dropout (float): Dropout probability for the attention layer. Default: 0.0.
        num_classes (int): Number of classes. Default: 1000.
    """

    def __init__(
        self,
        # input size parameters
        input_size: Tuple[int, int],
        # stem and task parameters
        stem_channels: int,
        # partitioning parameters
        partition_size: int,
        # block parameters
        block_channels: List[int],
        block_layers: List[int],
        # attention head dimensions
        head_dim: int,
        stochastic_depth_prob: float,
        # conv + transformer parameters
        # norm_layer is applied only to the conv layers
        # activation_layer is applied both to conv and transformer layers
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        activation_layer: Callable[..., nn.Module] = nn.GELU,
        # conv parameters
        squeeze_ratio: float = 0.25,
        expansion_ratio: float = 4,
        # transformer parameters
        mlp_ratio: int = 4,
        mlp_dropout: float = 0.0,
        attention_dropout: float = 0.0,
        # task parameters
        num_classes: int = 1000,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        input_channels = 3

        # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L1029-L1030
        # for the exact parameters used in batchnorm
        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.99)

        # Make sure input size will be divisible by the partition size in all blocks
        # Undefined behavior if H or W are not divisible by p
        # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766
        block_input_sizes = _make_block_input_shapes(input_size, len(block_channels))
        for idx, block_input_size in enumerate(block_input_sizes):
            if block_input_size[0] % partition_size != 0 or block_input_size[1] % partition_size != 0:
                raise ValueError(
                    f"Input size {block_input_size} of block {idx} is not divisible by partition size {partition_size}. "
                    f"Consider changing the partition size or the input size.\n"
                    f"Current configuration yields the following block input sizes: {block_input_sizes}."
                )

        # stem
        self.stem = nn.Sequential(
            Conv2dNormActivation(
                input_channels,
                stem_channels,
                3,
                stride=2,
                norm_layer=norm_layer,
                activation_layer=activation_layer,
                bias=False,
                inplace=None,
            ),
            Conv2dNormActivation(
                stem_channels, stem_channels, 3, stride=1, norm_layer=None, activation_layer=None, bias=True
            ),
        )

        # account for stem stride
        input_size = _get_conv_output_shape(input_size, kernel_size=3, stride=2, padding=1)
        self.partition_size = partition_size

        # blocks
        self.blocks = nn.ModuleList()
        in_channels = [stem_channels] + block_channels[:-1]
        out_channels = block_channels

        # precompute the stochastic depth probabilities from 0 to stochastic_depth_prob
        # since we have N blocks with L layers, we will have N * L probabilities uniformly distributed
        # over the range [0, stochastic_depth_prob]
        p_stochastic = np.linspace(0, stochastic_depth_prob, sum(block_layers)).tolist()

        p_idx = 0
        for in_channel, out_channel, num_layers in zip(in_channels, out_channels, block_layers):
            self.blocks.append(
                MaxVitBlock(
                    in_channels=in_channel,
                    out_channels=out_channel,
                    squeeze_ratio=squeeze_ratio,
                    expansion_ratio=expansion_ratio,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                    head_dim=head_dim,
                    mlp_ratio=mlp_ratio,
                    mlp_dropout=mlp_dropout,
                    attention_dropout=attention_dropout,
                    partition_size=partition_size,
                    input_grid_size=input_size,
                    n_layers=num_layers,
                    p_stochastic=p_stochastic[p_idx : p_idx + num_layers],
                ),
            )
            input_size = self.blocks[-1].grid_size
            p_idx += num_layers

        # see https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L1137-L1158
        # for why there is Linear -> Tanh -> Linear
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.LayerNorm(block_channels[-1]),
            nn.Linear(block_channels[-1], block_channels[-1]),
            nn.Tanh(),
            nn.Linear(block_channels[-1], num_classes, bias=False),
        )

        self._init_weights()

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)
        for block in self.blocks:
            x = block(x)
        x = self.classifier(x)
        return x

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
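

# Worked example of the stochastic depth schedule computed in MaxVit.__init__ above,
# using the MaxViT-T configuration defined at the bottom of this file:
# block_layers = [2, 2, 5, 2] gives 11 layers in total, so with stochastic_depth_prob=0.2
# the per-layer drop probabilities are np.linspace(0, 0.2, 11), i.e.
# [0.0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.14, 0.16, 0.18, 0.2],
# sliced per block as [0.0, 0.02], [0.04, 0.06], [0.08, 0.1, 0.12, 0.14, 0.16], [0.18, 0.2].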


def _maxvit(
    # stem parameters
    stem_channels: int,
    # block parameters
    block_channels: List[int],
    block_layers: List[int],
    stochastic_depth_prob: float,
    # partitioning parameters
    partition_size: int,
    # transformer parameters
    head_dim: int,
    # Weights API
    weights: Optional[WeightsEnum] = None,
    progress: bool = False,
    # kwargs,
    **kwargs: Any,
) -> MaxVit:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
        _ovewrite_named_param(kwargs, "input_size", weights.meta["min_size"])

    input_size = kwargs.pop("input_size", (224, 224))

    model = MaxVit(
        stem_channels=stem_channels,
        block_channels=block_channels,
        block_layers=block_layers,
        stochastic_depth_prob=stochastic_depth_prob,
        head_dim=head_dim,
        partition_size=partition_size,
        input_size=input_size,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


class MaxVit_T_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/maxvit_t-bc5ab103.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            "categories": _IMAGENET_CATEGORIES,
            "num_params": 30919624,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#maxvit",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.700,
                    "acc@5": 96.722,
                }
            },
            "_ops": 5.558,
            "_file_size": 118.769,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


@register_model()
@handle_legacy_interface(weights=("pretrained", MaxVit_T_Weights.IMAGENET1K_V1))
def maxvit_t(*, weights: Optional[MaxVit_T_Weights] = None, progress: bool = True, **kwargs: Any) -> MaxVit:
    """
    Constructs a maxvit_t architecture from
    `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`_.

    Args:
        weights (:class:`~torchvision.models.MaxVit_T_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MaxVit_T_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.maxvit.MaxVit``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/maxvit.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MaxVit_T_Weights
        :members:
    """
    weights = MaxVit_T_Weights.verify(weights)

    return _maxvit(
        stem_channels=64,
        block_channels=[64, 128, 256, 512],
        block_layers=[2, 2, 5, 2],
        head_dim=32,
        stochastic_depth_prob=0.2,
        partition_size=7,
        weights=weights,
        progress=progress,
        **kwargs,
    )
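

# Usage sketch (an illustrative addition, not part of the original torchvision module):
# building MaxViT-T with the pretrained weights and classifying a single 224x224 image.
#
#     model = maxvit_t(weights=MaxVit_T_Weights.IMAGENET1K_V1)
#     model.eval()
#     with torch.no_grad():
#         logits = model(torch.randn(1, 3, 224, 224))   # -> [1, 1000]
#         top5 = logits.topk(5).indices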