resnet.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985
  1. from functools import partial
  2. from typing import Any, Callable, List, Optional, Type, Union
  3. import torch
  4. import torch.nn as nn
  5. from torch import Tensor
  6. from ..transforms._presets import ImageClassification
  7. from ..utils import _log_api_usage_once
  8. from ._api import register_model, Weights, WeightsEnum
  9. from ._meta import _IMAGENET_CATEGORIES
  10. from ._utils import _ovewrite_named_param, handle_legacy_interface
# Names exported by ``from <this module> import *``: the ``ResNet`` class,
# the per-architecture weights enums, and the per-architecture builder functions.
__all__ = [
    "ResNet",
    "ResNet18_Weights",
    "ResNet34_Weights",
    "ResNet50_Weights",
    "ResNet101_Weights",
    "ResNet152_Weights",
    "ResNeXt50_32X4D_Weights",
    "ResNeXt101_32X8D_Weights",
    "ResNeXt101_64X4D_Weights",
    "Wide_ResNet50_2_Weights",
    "Wide_ResNet101_2_Weights",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "resnext50_32x4d",
    "resnext101_32x8d",
    "resnext101_64x4d",
    "wide_resnet50_2",
    "wide_resnet101_2",
]
  34. def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
  35. """3x3 convolution with padding"""
  36. return nn.Conv2d(
  37. in_planes,
  38. out_planes,
  39. kernel_size=3,
  40. stride=stride,
  41. padding=dilation,
  42. groups=groups,
  43. bias=False,
  44. dilation=dilation,
  45. )
  46. def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
  47. """1x1 convolution"""
  48. return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
  49. class BasicBlock(nn.Module):
  50. expansion: int = 1
  51. def __init__(
  52. self,
  53. inplanes: int,
  54. planes: int,
  55. stride: int = 1,
  56. downsample: Optional[nn.Module] = None,
  57. groups: int = 1,
  58. base_width: int = 64,
  59. dilation: int = 1,
  60. norm_layer: Optional[Callable[..., nn.Module]] = None,
  61. ) -> None:
  62. super().__init__()
  63. if norm_layer is None:
  64. norm_layer = nn.BatchNorm2d
  65. if groups != 1 or base_width != 64:
  66. raise ValueError("BasicBlock only supports groups=1 and base_width=64")
  67. if dilation > 1:
  68. raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
  69. # Both self.conv1 and self.downsample layers downsample the input when stride != 1
  70. self.conv1 = conv3x3(inplanes, planes, stride)
  71. self.bn1 = norm_layer(planes)
  72. self.relu = nn.ReLU(inplace=True)
  73. self.conv2 = conv3x3(planes, planes)
  74. self.bn2 = norm_layer(planes)
  75. self.downsample = downsample
  76. self.stride = stride
  77. def forward(self, x: Tensor) -> Tensor:
  78. identity = x
  79. out = self.conv1(x)
  80. out = self.bn1(out)
  81. out = self.relu(out)
  82. out = self.conv2(out)
  83. out = self.bn2(out)
  84. if self.downsample is not None:
  85. identity = self.downsample(x)
  86. out += identity
  87. out = self.relu(out)
  88. return out
  89. class Bottleneck(nn.Module):
  90. # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
  91. # while original implementation places the stride at the first 1x1 convolution(self.conv1)
  92. # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
  93. # This variant is also known as ResNet V1.5 and improves accuracy according to
  94. # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
  95. expansion: int = 4
  96. def __init__(
  97. self,
  98. inplanes: int,
  99. planes: int,
  100. stride: int = 1,
  101. downsample: Optional[nn.Module] = None,
  102. groups: int = 1,
  103. base_width: int = 64,
  104. dilation: int = 1,
  105. norm_layer: Optional[Callable[..., nn.Module]] = None,
  106. ) -> None:
  107. super().__init__()
  108. if norm_layer is None:
  109. norm_layer = nn.BatchNorm2d
  110. width = int(planes * (base_width / 64.0)) * groups
  111. # Both self.conv2 and self.downsample layers downsample the input when stride != 1
  112. self.conv1 = conv1x1(inplanes, width)
  113. self.bn1 = norm_layer(width)
  114. self.conv2 = conv3x3(width, width, stride, groups, dilation)
  115. self.bn2 = norm_layer(width)
  116. self.conv3 = conv1x1(width, planes * self.expansion)
  117. self.bn3 = norm_layer(planes * self.expansion)
  118. self.relu = nn.ReLU(inplace=True)
  119. self.downsample = downsample
  120. self.stride = stride
  121. def forward(self, x: Tensor) -> Tensor:
  122. identity = x
  123. out = self.conv1(x)
  124. out = self.bn1(out)
  125. out = self.relu(out)
  126. out = self.conv2(out)
  127. out = self.bn2(out)
  128. out = self.relu(out)
  129. out = self.conv3(out)
  130. out = self.bn3(out)
  131. if self.downsample is not None:
  132. identity = self.downsample(x)
  133. out += identity
  134. out = self.relu(out)
  135. return out
  136. class ResNet(nn.Module):
  137. def __init__(
  138. self,
  139. block: Type[Union[BasicBlock, Bottleneck]],
  140. layers: List[int],
  141. num_classes: int = 1000,
  142. zero_init_residual: bool = False,
  143. groups: int = 1,
  144. width_per_group: int = 64,
  145. replace_stride_with_dilation: Optional[List[bool]] = None,
  146. norm_layer: Optional[Callable[..., nn.Module]] = None,
  147. ) -> None:
  148. super().__init__()
  149. _log_api_usage_once(self)
  150. if norm_layer is None:
  151. norm_layer = nn.BatchNorm2d
  152. self._norm_layer = norm_layer
  153. self.inplanes = 64
  154. self.dilation = 1
  155. if replace_stride_with_dilation is None:
  156. # each element in the tuple indicates if we should replace
  157. # the 2x2 stride with a dilated convolution instead
  158. replace_stride_with_dilation = [False, False, False]
  159. if len(replace_stride_with_dilation) != 3:
  160. raise ValueError(
  161. "replace_stride_with_dilation should be None "
  162. f"or a 3-element tuple, got {replace_stride_with_dilation}"
  163. )
  164. self.groups = groups
  165. self.base_width = width_per_group
  166. self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
  167. self.bn1 = norm_layer(self.inplanes)
  168. self.relu = nn.ReLU(inplace=True)
  169. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  170. self.layer1 = self._make_layer(block, 64, layers[0])
  171. self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
  172. self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
  173. self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
  174. self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
  175. self.fc = nn.Linear(512 * block.expansion, num_classes)
  176. for m in self.modules():
  177. if isinstance(m, nn.Conv2d):
  178. nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
  179. elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
  180. nn.init.constant_(m.weight, 1)
  181. nn.init.constant_(m.bias, 0)
  182. # Zero-initialize the last BN in each residual branch,
  183. # so that the residual branch starts with zeros, and each residual block behaves like an identity.
  184. # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
  185. if zero_init_residual:
  186. for m in self.modules():
  187. if isinstance(m, Bottleneck) and m.bn3.weight is not None:
  188. nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
  189. elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
  190. nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
  191. def _make_layer(
  192. self,
  193. block: Type[Union[BasicBlock, Bottleneck]],
  194. planes: int,
  195. blocks: int,
  196. stride: int = 1,
  197. dilate: bool = False,
  198. ) -> nn.Sequential:
  199. norm_layer = self._norm_layer
  200. downsample = None
  201. previous_dilation = self.dilation
  202. if dilate:
  203. self.dilation *= stride
  204. stride = 1
  205. if stride != 1 or self.inplanes != planes * block.expansion:
  206. downsample = nn.Sequential(
  207. conv1x1(self.inplanes, planes * block.expansion, stride),
  208. norm_layer(planes * block.expansion),
  209. )
  210. layers = []
  211. layers.append(
  212. block(
  213. self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
  214. )
  215. )
  216. self.inplanes = planes * block.expansion
  217. for _ in range(1, blocks):
  218. layers.append(
  219. block(
  220. self.inplanes,
  221. planes,
  222. groups=self.groups,
  223. base_width=self.base_width,
  224. dilation=self.dilation,
  225. norm_layer=norm_layer,
  226. )
  227. )
  228. return nn.Sequential(*layers)
  229. def _forward_impl(self, x: Tensor) -> Tensor:
  230. # See note [TorchScript super()]
  231. x = self.conv1(x)
  232. x = self.bn1(x)
  233. x = self.relu(x)
  234. x = self.maxpool(x)
  235. x = self.layer1(x)
  236. x = self.layer2(x)
  237. x = self.layer3(x)
  238. x = self.layer4(x)
  239. x = self.avgpool(x)
  240. x = torch.flatten(x, 1)
  241. x = self.fc(x)
  242. return x
  243. def forward(self, x: Tensor) -> Tensor:
  244. return self._forward_impl(x)
  245. def _resnet(
  246. block: Type[Union[BasicBlock, Bottleneck]],
  247. layers: List[int],
  248. weights: Optional[WeightsEnum],
  249. progress: bool,
  250. **kwargs: Any,
  251. ) -> ResNet:
  252. if weights is not None:
  253. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  254. model = ResNet(block, layers, **kwargs)
  255. if weights is not None:
  256. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  257. return model
  258. _COMMON_META = {
  259. "min_size": (1, 1),
  260. "categories": _IMAGENET_CATEGORIES,
  261. }
# Weights entries accepted by the ``resnet18`` builder.
class ResNet18_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
        # ImageClassification preset bound to crop_size=224.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            # Shared entries first, then the entries specific to this checkpoint.
            **_COMMON_META,
            "num_params": 11689512,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.758,
                    "acc@5": 89.078,
                }
            },
            # NOTE(review): units are not stated here; presumably GFLOPs and MB. Confirm.
            "_ops": 1.814,
            "_file_size": 44.661,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
# Weights entries accepted by the ``resnet34`` builder.
class ResNet34_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet34-b627a593.pth",
        # ImageClassification preset bound to crop_size=224.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            # Shared entries first, then the entries specific to this checkpoint.
            **_COMMON_META,
            "num_params": 21797672,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 73.314,
                    "acc@5": 91.420,
                }
            },
            # NOTE(review): units are not stated here; presumably GFLOPs and MB. Confirm.
            "_ops": 3.664,
            "_file_size": 83.275,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
# Weights entries accepted by the ``resnet50`` builder.
class ResNet50_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
        # ImageClassification preset bound to crop_size=224.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            # Shared entries first, then the entries specific to this checkpoint.
            **_COMMON_META,
            "num_params": 25557032,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.130,
                    "acc@5": 92.862,
                }
            },
            # NOTE(review): units are not stated here; presumably GFLOPs and MB. Confirm.
            "_ops": 4.089,
            "_file_size": 97.781,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth",
        # Unlike V1, this preset also binds resize_size=232.
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.858,
                    "acc@5": 95.434,
                }
            },
            "_ops": 4.089,
            "_file_size": 97.79,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V2.
    DEFAULT = IMAGENET1K_V2
# Weights entries accepted by the ``resnet101`` builder.
class ResNet101_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
        # ImageClassification preset bound to crop_size=224.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            # Shared entries first, then the entries specific to this checkpoint.
            **_COMMON_META,
            "num_params": 44549160,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.374,
                    "acc@5": 93.546,
                }
            },
            # NOTE(review): units are not stated here; presumably GFLOPs and MB. Confirm.
            "_ops": 7.801,
            "_file_size": 170.511,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
        # Unlike V1, this preset also binds resize_size=232.
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 44549160,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.886,
                    "acc@5": 95.780,
                }
            },
            "_ops": 7.801,
            "_file_size": 170.53,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V2.
    DEFAULT = IMAGENET1K_V2
# Weights entries accepted by the ``resnet152`` builder.
class ResNet152_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
        # ImageClassification preset bound to crop_size=224.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            # Shared entries first, then the entries specific to this checkpoint.
            **_COMMON_META,
            "num_params": 60192808,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.312,
                    "acc@5": 94.046,
                }
            },
            # NOTE(review): units are not stated here; presumably GFLOPs and MB. Confirm.
            "_ops": 11.514,
            "_file_size": 230.434,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
        # Unlike V1, this preset also binds resize_size=232.
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 60192808,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.284,
                    "acc@5": 96.002,
                }
            },
            "_ops": 11.514,
            "_file_size": 230.474,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V2.
    DEFAULT = IMAGENET1K_V2
# Weights entries accepted by the ``resnext50_32x4d`` builder.
class ResNeXt50_32X4D_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
        # ImageClassification preset bound to crop_size=224.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            # Shared entries first, then the entries specific to this checkpoint.
            **_COMMON_META,
            "num_params": 25028904,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.618,
                    "acc@5": 93.698,
                }
            },
            # NOTE(review): units are not stated here; presumably GFLOPs and MB. Confirm.
            "_ops": 4.23,
            "_file_size": 95.789,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
        # Unlike V1, this preset also binds resize_size=232.
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 25028904,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.198,
                    "acc@5": 95.340,
                }
            },
            "_ops": 4.23,
            "_file_size": 95.833,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V2.
    DEFAULT = IMAGENET1K_V2
# Weights entries for the ResNeXt-101 32x8d architecture.
class ResNeXt101_32X8D_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
        # ImageClassification preset bound to crop_size=224.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            # Shared entries first, then the entries specific to this checkpoint.
            **_COMMON_META,
            "num_params": 88791336,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 79.312,
                    "acc@5": 94.526,
                }
            },
            # NOTE(review): units are not stated here; presumably GFLOPs and MB. Confirm.
            "_ops": 16.414,
            "_file_size": 339.586,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
        # Unlike V1, this preset also binds resize_size=232.
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.834,
                    "acc@5": 96.228,
                }
            },
            "_ops": 16.414,
            "_file_size": 339.673,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V2.
    DEFAULT = IMAGENET1K_V2
# Weights entries for the ResNeXt-101 64x4d architecture; only one checkpoint is defined.
class ResNeXt101_64X4D_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnext101_64x4d-173b62eb.pth",
        # ImageClassification preset bound to crop_size=224 and resize_size=232.
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            # Shared entries first, then the entries specific to this checkpoint.
            **_COMMON_META,
            "num_params": 83455272,
            "recipe": "https://github.com/pytorch/vision/pull/5935",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.246,
                    "acc@5": 96.454,
                }
            },
            # NOTE(review): units are not stated here; presumably GFLOPs and MB. Confirm.
            "_ops": 15.46,
            "_file_size": 319.318,
            "_docs": """
                These weights were trained from scratch by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
# Weights entries for the Wide ResNet-50-2 architecture.
class Wide_ResNet50_2_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
        # ImageClassification preset bound to crop_size=224.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            # Shared entries first, then the entries specific to this checkpoint.
            **_COMMON_META,
            "num_params": 68883240,
            "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.468,
                    "acc@5": 94.086,
                }
            },
            # NOTE(review): units are not stated here; presumably GFLOPs and MB. Confirm.
            "_ops": 11.398,
            "_file_size": 131.82,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
        # Unlike V1, this preset also binds resize_size=232.
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 68883240,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.602,
                    "acc@5": 95.758,
                }
            },
            "_ops": 11.398,
            "_file_size": 263.124,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V2.
    DEFAULT = IMAGENET1K_V2
# Weights entries for the Wide ResNet-101-2 architecture.
class Wide_ResNet101_2_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
        # ImageClassification preset bound to crop_size=224.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            # Shared entries first, then the entries specific to this checkpoint.
            **_COMMON_META,
            "num_params": 126886696,
            "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.848,
                    "acc@5": 94.284,
                }
            },
            # NOTE(review): units are not stated here; presumably GFLOPs and MB. Confirm.
            "_ops": 22.753,
            "_file_size": 242.896,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
        # Unlike V1, this preset also binds resize_size=232.
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 126886696,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.510,
                    "acc@5": 96.020,
                }
            },
            "_ops": 22.753,
            "_file_size": 484.747,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V2.
    DEFAULT = IMAGENET1K_V2
  612. @register_model()
  613. @handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
  614. def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
  615. """ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.
  616. Args:
  617. weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The
  618. pretrained weights to use. See
  619. :class:`~torchvision.models.ResNet18_Weights` below for
  620. more details, and possible values. By default, no pre-trained
  621. weights are used.
  622. progress (bool, optional): If True, displays a progress bar of the
  623. download to stderr. Default is True.
  624. **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
  625. base class. Please refer to the `source code
  626. <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
  627. for more details about this class.
  628. .. autoclass:: torchvision.models.ResNet18_Weights
  629. :members:
  630. """
  631. weights = ResNet18_Weights.verify(weights)
  632. return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
  633. @register_model()
  634. @handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1))
  635. def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
  636. """ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.
  637. Args:
  638. weights (:class:`~torchvision.models.ResNet34_Weights`, optional): The
  639. pretrained weights to use. See
  640. :class:`~torchvision.models.ResNet34_Weights` below for
  641. more details, and possible values. By default, no pre-trained
  642. weights are used.
  643. progress (bool, optional): If True, displays a progress bar of the
  644. download to stderr. Default is True.
  645. **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
  646. base class. Please refer to the `source code
  647. <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
  648. for more details about this class.
  649. .. autoclass:: torchvision.models.ResNet34_Weights
  650. :members:
  651. """
  652. weights = ResNet34_Weights.verify(weights)
  653. return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)
  654. @register_model()
  655. @handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
  656. def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
  657. """ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.
  658. .. note::
  659. The bottleneck of TorchVision places the stride for downsampling to the second 3x3
  660. convolution while the original paper places it to the first 1x1 convolution.
  661. This variant improves the accuracy and is known as `ResNet V1.5
  662. <https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch>`_.
  663. Args:
  664. weights (:class:`~torchvision.models.ResNet50_Weights`, optional): The
  665. pretrained weights to use. See
  666. :class:`~torchvision.models.ResNet50_Weights` below for
  667. more details, and possible values. By default, no pre-trained
  668. weights are used.
  669. progress (bool, optional): If True, displays a progress bar of the
  670. download to stderr. Default is True.
  671. **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
  672. base class. Please refer to the `source code
  673. <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
  674. for more details about this class.
  675. .. autoclass:: torchvision.models.ResNet50_Weights
  676. :members:
  677. """
  678. weights = ResNet50_Weights.verify(weights)
  679. return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
  680. @register_model()
  681. @handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1))
  682. def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
  683. """ResNet-101 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.
  684. .. note::
  685. The bottleneck of TorchVision places the stride for downsampling to the second 3x3
  686. convolution while the original paper places it to the first 1x1 convolution.
  687. This variant improves the accuracy and is known as `ResNet V1.5
  688. <https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch>`_.
  689. Args:
  690. weights (:class:`~torchvision.models.ResNet101_Weights`, optional): The
  691. pretrained weights to use. See
  692. :class:`~torchvision.models.ResNet101_Weights` below for
  693. more details, and possible values. By default, no pre-trained
  694. weights are used.
  695. progress (bool, optional): If True, displays a progress bar of the
  696. download to stderr. Default is True.
  697. **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
  698. base class. Please refer to the `source code
  699. <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
  700. for more details about this class.
  701. .. autoclass:: torchvision.models.ResNet101_Weights
  702. :members:
  703. """
  704. weights = ResNet101_Weights.verify(weights)
  705. return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
  706. @register_model()
  707. @handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1))
  708. def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
  709. """ResNet-152 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.
  710. .. note::
  711. The bottleneck of TorchVision places the stride for downsampling to the second 3x3
  712. convolution while the original paper places it to the first 1x1 convolution.
  713. This variant improves the accuracy and is known as `ResNet V1.5
  714. <https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch>`_.
  715. Args:
  716. weights (:class:`~torchvision.models.ResNet152_Weights`, optional): The
  717. pretrained weights to use. See
  718. :class:`~torchvision.models.ResNet152_Weights` below for
  719. more details, and possible values. By default, no pre-trained
  720. weights are used.
  721. progress (bool, optional): If True, displays a progress bar of the
  722. download to stderr. Default is True.
  723. **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
  724. base class. Please refer to the `source code
  725. <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
  726. for more details about this class.
  727. .. autoclass:: torchvision.models.ResNet152_Weights
  728. :members:
  729. """
  730. weights = ResNet152_Weights.verify(weights)
  731. return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)
  732. @register_model()
  733. @handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1))
  734. def resnext50_32x4d(
  735. *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any
  736. ) -> ResNet:
  737. """ResNeXt-50 32x4d model from
  738. `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_.
  739. Args:
  740. weights (:class:`~torchvision.models.ResNeXt50_32X4D_Weights`, optional): The
  741. pretrained weights to use. See
  742. :class:`~torchvision.models.ResNext50_32X4D_Weights` below for
  743. more details, and possible values. By default, no pre-trained
  744. weights are used.
  745. progress (bool, optional): If True, displays a progress bar of the
  746. download to stderr. Default is True.
  747. **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
  748. base class. Please refer to the `source code
  749. <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
  750. for more details about this class.
  751. .. autoclass:: torchvision.models.ResNeXt50_32X4D_Weights
  752. :members:
  753. """
  754. weights = ResNeXt50_32X4D_Weights.verify(weights)
  755. _ovewrite_named_param(kwargs, "groups", 32)
  756. _ovewrite_named_param(kwargs, "width_per_group", 4)
  757. return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
  758. @register_model()
  759. @handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1))
  760. def resnext101_32x8d(
  761. *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any
  762. ) -> ResNet:
  763. """ResNeXt-101 32x8d model from
  764. `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_.
  765. Args:
  766. weights (:class:`~torchvision.models.ResNeXt101_32X8D_Weights`, optional): The
  767. pretrained weights to use. See
  768. :class:`~torchvision.models.ResNeXt101_32X8D_Weights` below for
  769. more details, and possible values. By default, no pre-trained
  770. weights are used.
  771. progress (bool, optional): If True, displays a progress bar of the
  772. download to stderr. Default is True.
  773. **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
  774. base class. Please refer to the `source code
  775. <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
  776. for more details about this class.
  777. .. autoclass:: torchvision.models.ResNeXt101_32X8D_Weights
  778. :members:
  779. """
  780. weights = ResNeXt101_32X8D_Weights.verify(weights)
  781. _ovewrite_named_param(kwargs, "groups", 32)
  782. _ovewrite_named_param(kwargs, "width_per_group", 8)
  783. return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
  784. @register_model()
  785. @handle_legacy_interface(weights=("pretrained", ResNeXt101_64X4D_Weights.IMAGENET1K_V1))
  786. def resnext101_64x4d(
  787. *, weights: Optional[ResNeXt101_64X4D_Weights] = None, progress: bool = True, **kwargs: Any
  788. ) -> ResNet:
  789. """ResNeXt-101 64x4d model from
  790. `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_.
  791. Args:
  792. weights (:class:`~torchvision.models.ResNeXt101_64X4D_Weights`, optional): The
  793. pretrained weights to use. See
  794. :class:`~torchvision.models.ResNeXt101_64X4D_Weights` below for
  795. more details, and possible values. By default, no pre-trained
  796. weights are used.
  797. progress (bool, optional): If True, displays a progress bar of the
  798. download to stderr. Default is True.
  799. **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
  800. base class. Please refer to the `source code
  801. <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
  802. for more details about this class.
  803. .. autoclass:: torchvision.models.ResNeXt101_64X4D_Weights
  804. :members:
  805. """
  806. weights = ResNeXt101_64X4D_Weights.verify(weights)
  807. _ovewrite_named_param(kwargs, "groups", 64)
  808. _ovewrite_named_param(kwargs, "width_per_group", 4)
  809. return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
  810. @register_model()
  811. @handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1))
  812. def wide_resnet50_2(
  813. *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any
  814. ) -> ResNet:
  815. """Wide ResNet-50-2 model from
  816. `Wide Residual Networks <https://arxiv.org/abs/1605.07146>`_.
  817. The model is the same as ResNet except for the bottleneck number of channels
  818. which is twice larger in every block. The number of channels in outer 1x1
  819. convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
  820. channels, and in Wide ResNet-50-2 has 2048-1024-2048.
  821. Args:
  822. weights (:class:`~torchvision.models.Wide_ResNet50_2_Weights`, optional): The
  823. pretrained weights to use. See
  824. :class:`~torchvision.models.Wide_ResNet50_2_Weights` below for
  825. more details, and possible values. By default, no pre-trained
  826. weights are used.
  827. progress (bool, optional): If True, displays a progress bar of the
  828. download to stderr. Default is True.
  829. **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
  830. base class. Please refer to the `source code
  831. <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
  832. for more details about this class.
  833. .. autoclass:: torchvision.models.Wide_ResNet50_2_Weights
  834. :members:
  835. """
  836. weights = Wide_ResNet50_2_Weights.verify(weights)
  837. _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
  838. return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
  839. @register_model()
  840. @handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1))
  841. def wide_resnet101_2(
  842. *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any
  843. ) -> ResNet:
  844. """Wide ResNet-101-2 model from
  845. `Wide Residual Networks <https://arxiv.org/abs/1605.07146>`_.
  846. The model is the same as ResNet except for the bottleneck number of channels
  847. which is twice larger in every block. The number of channels in outer 1x1
  848. convolutions is the same, e.g. last block in ResNet-101 has 2048-512-2048
  849. channels, and in Wide ResNet-101-2 has 2048-1024-2048.
  850. Args:
  851. weights (:class:`~torchvision.models.Wide_ResNet101_2_Weights`, optional): The
  852. pretrained weights to use. See
  853. :class:`~torchvision.models.Wide_ResNet101_2_Weights` below for
  854. more details, and possible values. By default, no pre-trained
  855. weights are used.
  856. progress (bool, optional): If True, displays a progress bar of the
  857. download to stderr. Default is True.
  858. **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
  859. base class. Please refer to the `source code
  860. <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
  861. for more details about this class.
  862. .. autoclass:: torchvision.models.Wide_ResNet101_2_Weights
  863. :members:
  864. """
  865. weights = Wide_ResNet101_2_Weights.verify(weights)
  866. _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
  867. return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)