# vision_transformer.py

import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, NamedTuple, Optional

import torch
import torch.nn as nn

from ..ops.misc import Conv2dNormActivation, MLP
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface

__all__ = [
    "VisionTransformer",
    "ViT_B_16_Weights",
    "ViT_B_32_Weights",
    "ViT_L_16_Weights",
    "ViT_L_32_Weights",
    "ViT_H_14_Weights",
    "vit_b_16",
    "vit_b_32",
    "vit_l_16",
    "vit_l_32",
    "vit_h_14",
]


class ConvStemConfig(NamedTuple):
    out_channels: int
    kernel_size: int
    stride: int
    norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
    activation_layer: Callable[..., nn.Module] = nn.ReLU


class MLPBlock(MLP):
    """Transformer MLP block."""

    _version = 2

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
            for i in range(2):
                for type in ["weight", "bias"]:
                    old_key = f"{prefix}linear_{i+1}.{type}"
                    new_key = f"{prefix}{3*i}.{type}"
                    if old_key in state_dict:
                        state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
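

# Illustrative sketch: MLPBlock maps the last dimension hidden_dim -> mlp_dim -> hidden_dim
# with a GELU non-linearity in between, so input and output shapes match. The helper name
# ``_demo_mlp_block`` is hypothetical and exists only for illustration.
def _demo_mlp_block() -> torch.Tensor:
    mlp = MLPBlock(in_dim=64, mlp_dim=128, dropout=0.0)
    tokens = torch.rand(2, 10, 64)  # (batch_size, seq_length, hidden_dim)
    return mlp(tokens)  # same shape: (2, 10, 64)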


class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(x, x, x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y
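

# Illustrative sketch: a single pre-norm encoder block applies LayerNorm -> self-attention
# -> residual, then LayerNorm -> MLP -> residual, preserving the input shape. The helper
# name ``_demo_encoder_block`` is hypothetical and exists only for illustration.
def _demo_encoder_block() -> torch.Tensor:
    block = EncoderBlock(num_heads=4, hidden_dim=64, mlp_dim=128, dropout=0.0, attention_dropout=0.0)
    x = torch.rand(2, 10, 64)  # (batch_size, seq_length, hidden_dim)
    return block(x)  # same shape: (2, 10, 64)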


class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Note that batch_size is on the first dim because
        # we pass batch_first=True to nn.MultiheadAttention
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        return self.ln(self.layers(self.dropout(input)))
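

# Illustrative sketch: the Encoder adds a learned positional embedding, runs the stacked
# EncoderBlocks and applies a final LayerNorm, so a (batch_size, seq_length, hidden_dim)
# input comes out with the same shape. ``_demo_encoder`` is a hypothetical helper for
# illustration only.
def _demo_encoder() -> torch.Tensor:
    enc = Encoder(
        seq_length=17,  # e.g. 16 patches + 1 class token
        num_layers=2,
        num_heads=4,
        hidden_dim=64,
        mlp_dim=128,
        dropout=0.0,
        attention_dropout=0.0,
    )
    tokens = torch.rand(8, 17, 64)
    return enc(tokens)  # same shape: (8, 17, 64)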


class VisionTransformer(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929."""

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
    ):
        super().__init__()
        _log_api_usage_once(self)
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer

        if conv_stem_configs is not None:
            # As per https://arxiv.org/abs/2106.14881
            seq_proj = nn.Sequential()
            prev_channels = 3
            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
                seq_proj.add_module(
                    f"conv_bn_relu_{i}",
                    Conv2dNormActivation(
                        in_channels=prev_channels,
                        out_channels=conv_stem_layer_config.out_channels,
                        kernel_size=conv_stem_layer_config.kernel_size,
                        stride=conv_stem_layer_config.stride,
                        norm_layer=conv_stem_layer_config.norm_layer,
                        activation_layer=conv_stem_layer_config.activation_layer,
                    ),
                )
                prev_channels = conv_stem_layer_config.out_channels
            seq_proj.add_module(
                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
            )
            self.conv_proj: nn.Module = seq_proj
        else:
            self.conv_proj = nn.Conv2d(
                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
            )

        seq_length = (image_size // patch_size) ** 2

        # Add a class token
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        if isinstance(self.conv_proj, nn.Conv2d):
            # Init the patchify stem
            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
            if self.conv_proj.bias is not None:
                nn.init.zeros_(self.conv_proj.bias)
        elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
            # Init the last 1x1 conv of the conv stem
            nn.init.normal_(
                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
            )
            if self.conv_proj.conv_last.bias is not None:
                nn.init.zeros_(self.conv_proj.conv_last.bias)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            nn.init.zeros_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
        torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

    def forward(self, x: torch.Tensor):
        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]

        x = self.heads(x)

        return x
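

# Illustrative sketch: a tiny VisionTransformer showing how an image is patchified into a
# (batch_size, num_patches, hidden_dim) sequence, prepended with the class token and reduced
# to logits by the classification head. ``_demo_vision_transformer_shapes`` is a hypothetical
# helper for illustration only.
def _demo_vision_transformer_shapes() -> torch.Tensor:
    model = VisionTransformer(
        image_size=32,
        patch_size=8,
        num_layers=2,
        num_heads=2,
        hidden_dim=64,
        mlp_dim=128,
        num_classes=10,
    )
    images = torch.rand(4, 3, 32, 32)
    patches = model._process_input(images)  # (4, 16, 64): (32 // 8) ** 2 patches of dim 64
    torch._assert(patches.shape == (4, 16, 64), "unexpected patch sequence shape")
    return model(images)  # logits of shape (4, 10)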


def _vision_transformer(
    patch_size: int,
    num_layers: int,
    num_heads: int,
    hidden_dim: int,
    mlp_dim: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> VisionTransformer:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
        _ovewrite_named_param(kwargs, "image_size", weights.meta["min_size"][0])
    image_size = kwargs.pop("image_size", 224)

    model = VisionTransformer(
        image_size=image_size,
        patch_size=patch_size,
        num_layers=num_layers,
        num_heads=num_heads,
        hidden_dim=hidden_dim,
        mlp_dim=mlp_dim,
        **kwargs,
    )

    if weights:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


_COMMON_META: Dict[str, Any] = {
    "categories": _IMAGENET_CATEGORIES,
}

_COMMON_SWAG_META = {
    **_COMMON_META,
    "recipe": "https://github.com/facebookresearch/SWAG",
    "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
}


class ViT_B_16_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_16-c867db91.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 86567656,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_16",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.072,
                    "acc@5": 95.318,
                }
            },
            "_ops": 17.564,
            "_file_size": 330.285,
            "_docs": """
                These weights were trained from scratch by using a modified version of `DeIT
                <https://arxiv.org/abs/2012.12877>`_'s training recipe.
            """,
        },
    )
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth",
        transforms=partial(
            ImageClassification,
            crop_size=384,
            resize_size=384,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 86859496,
            "min_size": (384, 384),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.304,
                    "acc@5": 97.650,
                }
            },
            "_ops": 55.484,
            "_file_size": 331.398,
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_16_lc_swag-4e70ced5.pth",
        transforms=partial(
            ImageClassification,
            crop_size=224,
            resize_size=224,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 86567656,
            "min_size": (224, 224),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.886,
                    "acc@5": 96.180,
                }
            },
            "_ops": 17.564,
            "_file_size": 330.285,
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ViT_B_32_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 88224232,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_32",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.912,
                    "acc@5": 92.466,
                }
            },
            "_ops": 4.409,
            "_file_size": 336.604,
            "_docs": """
                These weights were trained from scratch by using a modified version of `DeIT
                <https://arxiv.org/abs/2012.12877>`_'s training recipe.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ViT_L_16_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=242),
        meta={
            **_COMMON_META,
            "num_params": 304326632,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_16",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 79.662,
                    "acc@5": 94.638,
                }
            },
            "_ops": 61.555,
            "_file_size": 1161.023,
            "_docs": """
                These weights were trained from scratch by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth",
        transforms=partial(
            ImageClassification,
            crop_size=512,
            resize_size=512,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 305174504,
            "min_size": (512, 512),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 88.064,
                    "acc@5": 98.512,
                }
            },
            "_ops": 361.986,
            "_file_size": 1164.258,
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/vit_l_16_lc_swag-4d563306.pth",
        transforms=partial(
            ImageClassification,
            crop_size=224,
            resize_size=224,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 304326632,
            "min_size": (224, 224),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.146,
                    "acc@5": 97.422,
                }
            },
            "_ops": 61.555,
            "_file_size": 1161.023,
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ViT_L_32_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vit_l_32-c7638314.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 306535400,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_32",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.972,
                    "acc@5": 93.07,
                }
            },
            "_ops": 15.378,
            "_file_size": 1169.449,
            "_docs": """
                These weights were trained from scratch by using a modified version of `DeIT
                <https://arxiv.org/abs/2012.12877>`_'s training recipe.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ViT_H_14_Weights(WeightsEnum):
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/vit_h_14_swag-80465313.pth",
        transforms=partial(
            ImageClassification,
            crop_size=518,
            resize_size=518,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 633470440,
            "min_size": (518, 518),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 88.552,
                    "acc@5": 98.694,
                }
            },
            "_ops": 1016.717,
            "_file_size": 2416.643,
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/vit_h_14_lc_swag-c1eb923e.pth",
        transforms=partial(
            ImageClassification,
            crop_size=224,
            resize_size=224,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 632045800,
            "min_size": (224, 224),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.708,
                    "acc@5": 97.730,
                }
            },
            "_ops": 167.295,
            "_file_size": 2411.209,
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_SWAG_E2E_V1


@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_b_16 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_B_16_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_B_16_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_B_16_Weights
        :members:
    """
    weights = ViT_B_16_Weights.verify(weights)

    return _vision_transformer(
        patch_size=16,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        weights=weights,
        progress=progress,
        **kwargs,
    )
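

# Usage sketch: classifying a random image with the IMAGENET1K_V1 weights. Running this
# downloads the checkpoint (~330 MB), which is assumed to be acceptable here; the helper
# name ``_demo_vit_b_16_pretrained`` is hypothetical and exists only for illustration.
def _demo_vit_b_16_pretrained() -> torch.Tensor:
    weights = ViT_B_16_Weights.IMAGENET1K_V1
    model = vit_b_16(weights=weights).eval()
    preprocess = weights.transforms()  # resizes, center-crops to 224x224 and normalizes
    batch = preprocess(torch.rand(3, 500, 400)).unsqueeze(0)
    with torch.no_grad():
        return model(batch).softmax(dim=-1)  # (1, 1000) class probabilities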


@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.IMAGENET1K_V1))
def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_b_32 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_B_32_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_B_32_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_B_32_Weights
        :members:
    """
    weights = ViT_B_32_Weights.verify(weights)

    return _vision_transformer(
        patch_size=32,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        weights=weights,
        progress=progress,
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.IMAGENET1K_V1))
def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_l_16 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_L_16_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_L_16_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_L_16_Weights
        :members:
    """
    weights = ViT_L_16_Weights.verify(weights)

    return _vision_transformer(
        patch_size=16,
        num_layers=24,
        num_heads=16,
        hidden_dim=1024,
        mlp_dim=4096,
        weights=weights,
        progress=progress,
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.IMAGENET1K_V1))
def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_l_32 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_L_32_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_L_32_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_L_32_Weights
        :members:
    """
    weights = ViT_L_32_Weights.verify(weights)

    return _vision_transformer(
        patch_size=32,
        num_layers=24,
        num_heads=16,
        hidden_dim=1024,
        mlp_dim=4096,
        weights=weights,
        progress=progress,
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", None))
def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_h_14 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_H_14_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_H_14_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_H_14_Weights
        :members:
    """
    weights = ViT_H_14_Weights.verify(weights)

    return _vision_transformer(
        patch_size=14,
        num_layers=32,
        num_heads=16,
        hidden_dim=1280,
        mlp_dim=5120,
        weights=weights,
        progress=progress,
        **kwargs,
    )
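

# Usage sketch: vit_h_14 ships only SWAG-based weights; with IMAGENET1K_SWAG_E2E_V1 the
# builder overrides ``image_size`` to 518 (the weights' ``min_size``), so inputs must be
# preprocessed to 518x518. Running this downloads a ~2.4 GB checkpoint, which is assumed
# to be acceptable; ``_demo_vit_h_14_swag`` is a hypothetical helper for illustration only.
def _demo_vit_h_14_swag() -> torch.Tensor:
    weights = ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1
    model = vit_h_14(weights=weights).eval()
    preprocess = weights.transforms()  # bicubic resize and center-crop to 518x518
    batch = preprocess(torch.rand(3, 600, 600)).unsqueeze(0)
    with torch.no_grad():
        return model(batch)  # (1, 1000) logits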


def interpolate_embeddings(
    image_size: int,
    patch_size: int,
    model_state: "OrderedDict[str, torch.Tensor]",
    interpolation_mode: str = "bicubic",
    reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]":
    """This function helps interpolate positional embeddings during checkpoint loading,
    especially when you want to apply a pre-trained model on images with different resolution.

    Args:
        image_size (int): Image size of the new model.
        patch_size (int): Patch size of the new model.
        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
        interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
        reset_heads (bool): If true, the state of the classifier heads is not copied. Default: False.

    Returns:
        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.
    """
    # Shape of pos_embedding is (1, seq_length, hidden_dim)
    pos_embedding = model_state["encoder.pos_embedding"]
    n, seq_length, hidden_dim = pos_embedding.shape
    if n != 1:
        raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")

    new_seq_length = (image_size // patch_size) ** 2 + 1

    # Need to interpolate the weights for the position embedding.
    # We do this by reshaping the positions embeddings to a 2d grid, performing
    # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
    if new_seq_length != seq_length:
        # The class token embedding shouldn't be interpolated, so we split it up.
        seq_length -= 1
        new_seq_length -= 1
        pos_embedding_token = pos_embedding[:, :1, :]
        pos_embedding_img = pos_embedding[:, 1:, :]

        # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
        pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
        seq_length_1d = int(math.sqrt(seq_length))
        if seq_length_1d * seq_length_1d != seq_length:
            raise ValueError(
                f"seq_length is not a perfect square! Instead got seq_length_1d * seq_length_1d = {seq_length_1d * seq_length_1d} and seq_length = {seq_length}"
            )

        # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
        pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
        new_seq_length_1d = image_size // patch_size

        # Perform interpolation.
        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
        new_pos_embedding_img = nn.functional.interpolate(
            pos_embedding_img,
            size=new_seq_length_1d,
            mode=interpolation_mode,
            align_corners=True,
        )

        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
        new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)

        # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
        new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
        new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)
        model_state["encoder.pos_embedding"] = new_pos_embedding

        if reset_heads:
            # Drop the classifier heads from the returned state dict so that the new
            # model keeps its freshly initialized heads.
            model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict()
            for k, v in model_state.items():
                if not k.startswith("heads"):
                    model_state_copy[k] = v
            model_state = model_state_copy

    return model_state
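

# Usage sketch: adapting a 224x224 vit_b_16 checkpoint to 384x384 inputs by interpolating
# its positional embeddings and loading the result into a freshly built model. Running this
# downloads the IMAGENET1K_V1 checkpoint; ``_demo_interpolate_to_384`` is a hypothetical
# helper for illustration only.
def _demo_interpolate_to_384() -> VisionTransformer:
    pretrained = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
    new_state = interpolate_embeddings(
        image_size=384,
        patch_size=16,
        model_state=pretrained.state_dict(),
    )
    model = VisionTransformer(
        image_size=384,
        patch_size=16,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
    )
    model.load_state_dict(new_state)  # pos_embedding is now (1, 577, 768)
    return model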