vgg.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. from functools import partial
  2. from typing import Any, cast, Dict, List, Optional, Union
  3. import torch
  4. import torch.nn as nn
  5. from ..transforms._presets import ImageClassification
  6. from ..utils import _log_api_usage_once
  7. from ._api import register_model, Weights, WeightsEnum
  8. from ._meta import _IMAGENET_CATEGORIES
  9. from ._utils import _ovewrite_named_param, handle_legacy_interface
  10. __all__ = [
  11. "VGG",
  12. "VGG11_Weights",
  13. "VGG11_BN_Weights",
  14. "VGG13_Weights",
  15. "VGG13_BN_Weights",
  16. "VGG16_Weights",
  17. "VGG16_BN_Weights",
  18. "VGG19_Weights",
  19. "VGG19_BN_Weights",
  20. "vgg11",
  21. "vgg11_bn",
  22. "vgg13",
  23. "vgg13_bn",
  24. "vgg16",
  25. "vgg16_bn",
  26. "vgg19",
  27. "vgg19_bn",
  28. ]
  29. class VGG(nn.Module):
  30. def __init__(
  31. self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
  32. ) -> None:
  33. super().__init__()
  34. _log_api_usage_once(self)
  35. self.features = features
  36. self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
  37. self.classifier = nn.Sequential(
  38. nn.Linear(512 * 7 * 7, 4096),
  39. nn.ReLU(True),
  40. nn.Dropout(p=dropout),
  41. nn.Linear(4096, 4096),
  42. nn.ReLU(True),
  43. nn.Dropout(p=dropout),
  44. nn.Linear(4096, num_classes),
  45. )
  46. if init_weights:
  47. for m in self.modules():
  48. if isinstance(m, nn.Conv2d):
  49. nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
  50. if m.bias is not None:
  51. nn.init.constant_(m.bias, 0)
  52. elif isinstance(m, nn.BatchNorm2d):
  53. nn.init.constant_(m.weight, 1)
  54. nn.init.constant_(m.bias, 0)
  55. elif isinstance(m, nn.Linear):
  56. nn.init.normal_(m.weight, 0, 0.01)
  57. nn.init.constant_(m.bias, 0)
  58. def forward(self, x: torch.Tensor) -> torch.Tensor:
  59. x = self.features(x)
  60. x = self.avgpool(x)
  61. x = torch.flatten(x, 1)
  62. x = self.classifier(x)
  63. return x
  64. def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
  65. layers: List[nn.Module] = []
  66. in_channels = 3
  67. for v in cfg:
  68. if v == "M":
  69. layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
  70. else:
  71. v = cast(int, v)
  72. conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
  73. if batch_norm:
  74. layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
  75. else:
  76. layers += [conv2d, nn.ReLU(inplace=True)]
  77. in_channels = v
  78. return nn.Sequential(*layers)
  79. cfgs: Dict[str, List[Union[str, int]]] = {
  80. "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
  81. "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
  82. "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
  83. "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
  84. }
  85. def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG:
  86. if weights is not None:
  87. kwargs["init_weights"] = False
  88. if weights.meta["categories"] is not None:
  89. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  90. model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
  91. if weights is not None:
  92. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  93. return model
  94. _COMMON_META = {
  95. "min_size": (32, 32),
  96. "categories": _IMAGENET_CATEGORIES,
  97. "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
  98. "_docs": """These weights were trained from scratch by using a simplified training recipe.""",
  99. }
  100. class VGG11_Weights(WeightsEnum):
  101. IMAGENET1K_V1 = Weights(
  102. url="https://download.pytorch.org/models/vgg11-8a719046.pth",
  103. transforms=partial(ImageClassification, crop_size=224),
  104. meta={
  105. **_COMMON_META,
  106. "num_params": 132863336,
  107. "_metrics": {
  108. "ImageNet-1K": {
  109. "acc@1": 69.020,
  110. "acc@5": 88.628,
  111. }
  112. },
  113. "_ops": 7.609,
  114. "_file_size": 506.84,
  115. },
  116. )
  117. DEFAULT = IMAGENET1K_V1
  118. class VGG11_BN_Weights(WeightsEnum):
  119. IMAGENET1K_V1 = Weights(
  120. url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
  121. transforms=partial(ImageClassification, crop_size=224),
  122. meta={
  123. **_COMMON_META,
  124. "num_params": 132868840,
  125. "_metrics": {
  126. "ImageNet-1K": {
  127. "acc@1": 70.370,
  128. "acc@5": 89.810,
  129. }
  130. },
  131. "_ops": 7.609,
  132. "_file_size": 506.881,
  133. },
  134. )
  135. DEFAULT = IMAGENET1K_V1
  136. class VGG13_Weights(WeightsEnum):
  137. IMAGENET1K_V1 = Weights(
  138. url="https://download.pytorch.org/models/vgg13-19584684.pth",
  139. transforms=partial(ImageClassification, crop_size=224),
  140. meta={
  141. **_COMMON_META,
  142. "num_params": 133047848,
  143. "_metrics": {
  144. "ImageNet-1K": {
  145. "acc@1": 69.928,
  146. "acc@5": 89.246,
  147. }
  148. },
  149. "_ops": 11.308,
  150. "_file_size": 507.545,
  151. },
  152. )
  153. DEFAULT = IMAGENET1K_V1
  154. class VGG13_BN_Weights(WeightsEnum):
  155. IMAGENET1K_V1 = Weights(
  156. url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
  157. transforms=partial(ImageClassification, crop_size=224),
  158. meta={
  159. **_COMMON_META,
  160. "num_params": 133053736,
  161. "_metrics": {
  162. "ImageNet-1K": {
  163. "acc@1": 71.586,
  164. "acc@5": 90.374,
  165. }
  166. },
  167. "_ops": 11.308,
  168. "_file_size": 507.59,
  169. },
  170. )
  171. DEFAULT = IMAGENET1K_V1
  172. class VGG16_Weights(WeightsEnum):
  173. IMAGENET1K_V1 = Weights(
  174. url="https://download.pytorch.org/models/vgg16-397923af.pth",
  175. transforms=partial(ImageClassification, crop_size=224),
  176. meta={
  177. **_COMMON_META,
  178. "num_params": 138357544,
  179. "_metrics": {
  180. "ImageNet-1K": {
  181. "acc@1": 71.592,
  182. "acc@5": 90.382,
  183. }
  184. },
  185. "_ops": 15.47,
  186. "_file_size": 527.796,
  187. },
  188. )
  189. IMAGENET1K_FEATURES = Weights(
  190. # Weights ported from https://github.com/amdegroot/ssd.pytorch/
  191. url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
  192. transforms=partial(
  193. ImageClassification,
  194. crop_size=224,
  195. mean=(0.48235, 0.45882, 0.40784),
  196. std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0),
  197. ),
  198. meta={
  199. **_COMMON_META,
  200. "num_params": 138357544,
  201. "categories": None,
  202. "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd",
  203. "_metrics": {
  204. "ImageNet-1K": {
  205. "acc@1": float("nan"),
  206. "acc@5": float("nan"),
  207. }
  208. },
  209. "_ops": 15.47,
  210. "_file_size": 527.802,
  211. "_docs": """
  212. These weights can't be used for classification because they are missing values in the `classifier`
  213. module. Only the `features` module has valid values and can be used for feature extraction. The weights
  214. were trained using the original input standardization method as described in the paper.
  215. """,
  216. },
  217. )
  218. DEFAULT = IMAGENET1K_V1
  219. class VGG16_BN_Weights(WeightsEnum):
  220. IMAGENET1K_V1 = Weights(
  221. url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
  222. transforms=partial(ImageClassification, crop_size=224),
  223. meta={
  224. **_COMMON_META,
  225. "num_params": 138365992,
  226. "_metrics": {
  227. "ImageNet-1K": {
  228. "acc@1": 73.360,
  229. "acc@5": 91.516,
  230. }
  231. },
  232. "_ops": 15.47,
  233. "_file_size": 527.866,
  234. },
  235. )
  236. DEFAULT = IMAGENET1K_V1
  237. class VGG19_Weights(WeightsEnum):
  238. IMAGENET1K_V1 = Weights(
  239. url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
  240. transforms=partial(ImageClassification, crop_size=224),
  241. meta={
  242. **_COMMON_META,
  243. "num_params": 143667240,
  244. "_metrics": {
  245. "ImageNet-1K": {
  246. "acc@1": 72.376,
  247. "acc@5": 90.876,
  248. }
  249. },
  250. "_ops": 19.632,
  251. "_file_size": 548.051,
  252. },
  253. )
  254. DEFAULT = IMAGENET1K_V1
  255. class VGG19_BN_Weights(WeightsEnum):
  256. IMAGENET1K_V1 = Weights(
  257. url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
  258. transforms=partial(ImageClassification, crop_size=224),
  259. meta={
  260. **_COMMON_META,
  261. "num_params": 143678248,
  262. "_metrics": {
  263. "ImageNet-1K": {
  264. "acc@1": 74.218,
  265. "acc@5": 91.842,
  266. }
  267. },
  268. "_ops": 19.632,
  269. "_file_size": 548.143,
  270. },
  271. )
  272. DEFAULT = IMAGENET1K_V1
  273. @register_model()
  274. @handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1))
  275. def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
  276. """VGG-11 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  277. Args:
  278. weights (:class:`~torchvision.models.VGG11_Weights`, optional): The
  279. pretrained weights to use. See
  280. :class:`~torchvision.models.VGG11_Weights` below for
  281. more details, and possible values. By default, no pre-trained
  282. weights are used.
  283. progress (bool, optional): If True, displays a progress bar of the
  284. download to stderr. Default is True.
  285. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  286. base class. Please refer to the `source code
  287. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  288. for more details about this class.
  289. .. autoclass:: torchvision.models.VGG11_Weights
  290. :members:
  291. """
  292. weights = VGG11_Weights.verify(weights)
  293. return _vgg("A", False, weights, progress, **kwargs)
  294. @register_model()
  295. @handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1))
  296. def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
  297. """VGG-11-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  298. Args:
  299. weights (:class:`~torchvision.models.VGG11_BN_Weights`, optional): The
  300. pretrained weights to use. See
  301. :class:`~torchvision.models.VGG11_BN_Weights` below for
  302. more details, and possible values. By default, no pre-trained
  303. weights are used.
  304. progress (bool, optional): If True, displays a progress bar of the
  305. download to stderr. Default is True.
  306. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  307. base class. Please refer to the `source code
  308. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  309. for more details about this class.
  310. .. autoclass:: torchvision.models.VGG11_BN_Weights
  311. :members:
  312. """
  313. weights = VGG11_BN_Weights.verify(weights)
  314. return _vgg("A", True, weights, progress, **kwargs)
  315. @register_model()
  316. @handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1))
  317. def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
  318. """VGG-13 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  319. Args:
  320. weights (:class:`~torchvision.models.VGG13_Weights`, optional): The
  321. pretrained weights to use. See
  322. :class:`~torchvision.models.VGG13_Weights` below for
  323. more details, and possible values. By default, no pre-trained
  324. weights are used.
  325. progress (bool, optional): If True, displays a progress bar of the
  326. download to stderr. Default is True.
  327. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  328. base class. Please refer to the `source code
  329. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  330. for more details about this class.
  331. .. autoclass:: torchvision.models.VGG13_Weights
  332. :members:
  333. """
  334. weights = VGG13_Weights.verify(weights)
  335. return _vgg("B", False, weights, progress, **kwargs)
  336. @register_model()
  337. @handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1))
  338. def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
  339. """VGG-13-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  340. Args:
  341. weights (:class:`~torchvision.models.VGG13_BN_Weights`, optional): The
  342. pretrained weights to use. See
  343. :class:`~torchvision.models.VGG13_BN_Weights` below for
  344. more details, and possible values. By default, no pre-trained
  345. weights are used.
  346. progress (bool, optional): If True, displays a progress bar of the
  347. download to stderr. Default is True.
  348. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  349. base class. Please refer to the `source code
  350. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  351. for more details about this class.
  352. .. autoclass:: torchvision.models.VGG13_BN_Weights
  353. :members:
  354. """
  355. weights = VGG13_BN_Weights.verify(weights)
  356. return _vgg("B", True, weights, progress, **kwargs)
  357. @register_model()
  358. @handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1))
  359. def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
  360. """VGG-16 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  361. Args:
  362. weights (:class:`~torchvision.models.VGG16_Weights`, optional): The
  363. pretrained weights to use. See
  364. :class:`~torchvision.models.VGG16_Weights` below for
  365. more details, and possible values. By default, no pre-trained
  366. weights are used.
  367. progress (bool, optional): If True, displays a progress bar of the
  368. download to stderr. Default is True.
  369. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  370. base class. Please refer to the `source code
  371. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  372. for more details about this class.
  373. .. autoclass:: torchvision.models.VGG16_Weights
  374. :members:
  375. """
  376. weights = VGG16_Weights.verify(weights)
  377. return _vgg("D", False, weights, progress, **kwargs)
  378. @register_model()
  379. @handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1))
  380. def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
  381. """VGG-16-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  382. Args:
  383. weights (:class:`~torchvision.models.VGG16_BN_Weights`, optional): The
  384. pretrained weights to use. See
  385. :class:`~torchvision.models.VGG16_BN_Weights` below for
  386. more details, and possible values. By default, no pre-trained
  387. weights are used.
  388. progress (bool, optional): If True, displays a progress bar of the
  389. download to stderr. Default is True.
  390. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  391. base class. Please refer to the `source code
  392. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  393. for more details about this class.
  394. .. autoclass:: torchvision.models.VGG16_BN_Weights
  395. :members:
  396. """
  397. weights = VGG16_BN_Weights.verify(weights)
  398. return _vgg("D", True, weights, progress, **kwargs)
  399. @register_model()
  400. @handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1))
  401. def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
  402. """VGG-19 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  403. Args:
  404. weights (:class:`~torchvision.models.VGG19_Weights`, optional): The
  405. pretrained weights to use. See
  406. :class:`~torchvision.models.VGG19_Weights` below for
  407. more details, and possible values. By default, no pre-trained
  408. weights are used.
  409. progress (bool, optional): If True, displays a progress bar of the
  410. download to stderr. Default is True.
  411. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  412. base class. Please refer to the `source code
  413. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  414. for more details about this class.
  415. .. autoclass:: torchvision.models.VGG19_Weights
  416. :members:
  417. """
  418. weights = VGG19_Weights.verify(weights)
  419. return _vgg("E", False, weights, progress, **kwargs)
  420. @register_model()
  421. @handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1))
  422. def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
  423. """VGG-19_BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  424. Args:
  425. weights (:class:`~torchvision.models.VGG19_BN_Weights`, optional): The
  426. pretrained weights to use. See
  427. :class:`~torchvision.models.VGG19_BN_Weights` below for
  428. more details, and possible values. By default, no pre-trained
  429. weights are used.
  430. progress (bool, optional): If True, displays a progress bar of the
  431. download to stderr. Default is True.
  432. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  433. base class. Please refer to the `source code
  434. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  435. for more details about this class.
  436. .. autoclass:: torchvision.models.VGG19_BN_Weights
  437. :members:
  438. """
  439. weights = VGG19_BN_Weights.verify(weights)
  440. return _vgg("E", True, weights, progress, **kwargs)