123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511 |
- from functools import partial
- from typing import Any, cast, Dict, List, Optional, Union
- import torch
- import torch.nn as nn
- from ..transforms._presets import ImageClassification
- from ..utils import _log_api_usage_once
- from ._api import register_model, Weights, WeightsEnum
- from ._meta import _IMAGENET_CATEGORIES
- from ._utils import _ovewrite_named_param, handle_legacy_interface
# Public API of this module: the model class, one weights enum per variant, one builder per variant.
__all__ = ["VGG"]
# Weight enums: a plain and a batch-normalized (``_BN``) entry for each depth (11, 13, 16, 19).
__all__ += [
    "VGG11_Weights",
    "VGG11_BN_Weights",
    "VGG13_Weights",
    "VGG13_BN_Weights",
    "VGG16_Weights",
    "VGG16_BN_Weights",
    "VGG19_Weights",
    "VGG19_BN_Weights",
]
# Model builder functions, listed in the same order as the weight enums above.
__all__ += [
    "vgg11",
    "vgg11_bn",
    "vgg13",
    "vgg13_bn",
    "vgg16",
    "vgg16_bn",
    "vgg19",
    "vgg19_bn",
]
class VGG(nn.Module):
    """VGG classification network: ``features`` -> 7x7 adaptive average pool -> 3-layer MLP head.

    Args:
        features: Convolutional trunk (the builders in this module create it with ``make_layers``).
            Its output must have 512 channels, because the first classifier layer takes
            ``512 * 7 * 7`` inputs.
        num_classes: Output size of the last ``nn.Linear`` in the classifier.
        init_weights: When True, re-initialize every ``Conv2d``, ``BatchNorm2d`` and ``Linear``
            submodule, including the ones inside ``features``.
        dropout: Drop probability shared by both ``nn.Dropout`` layers of the classifier.
    """

    def __init__(
        self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.features = features
        # A fixed 7x7 output decouples the classifier's input size from the input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        width = 4096
        # Layer order is load-bearing: it fixes the ``classifier.<index>`` state_dict keys.
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, width),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(width, width),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(width, num_classes),
        )
        if not init_weights:
            return

        # Visit submodules in ``self.modules()`` order; the random initializers draw from the
        # global RNG, so this order determines which values each layer receives.
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.BatchNorm2d):
                # Start batch norm as an identity affine transform.
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``features`` -> ``avgpool`` -> flatten -> ``classifier`` and return the raw scores."""
        pooled = self.avgpool(self.features(x))
        # Keep dim 0 (batch); merge all remaining dims into one feature vector per sample.
        return self.classifier(torch.flatten(pooled, 1))
def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False, in_channels: int = 3) -> nn.Sequential:
    """Build the convolutional ``features`` trunk of a VGG network from a layer specification.

    Args:
        cfg: Layer specification, read left to right.
            - Each integer ``v`` adds a 3x3 convolution (padding 1) with ``v`` output channels.
              An optional ``BatchNorm2d`` follows it, then an in-place ``ReLU``.
            - The string ``"M"`` adds a 2x2 max-pool with stride 2.
        batch_norm: Insert a ``BatchNorm2d`` between each convolution and its ReLU.
        in_channels: Channel count of the input to the first convolution. Defaults to 3, which is
            the value used by ``_vgg``; other values allow e.g. single-channel inputs.

    Returns:
        An ``nn.Sequential`` holding the layers in ``cfg`` order. Positions in this container
        determine the ``features.<index>`` state_dict keys.
    """
    layers: List[nn.Module] = []
    for v in cfg:
        if v == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        out_channels = cast(int, v)
        # The convolution is constructed before BatchNorm2d, matching the original RNG draw order.
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        # The next convolution consumes this one's output channels.
        in_channels = out_channels
    return nn.Sequential(*layers)
# Layer specifications consumed by ``make_layers``:
# - an int is the output-channel count of a 3x3 convolution (followed by optional BatchNorm and ReLU);
# - "M" is a 2x2 / stride-2 max-pool.
# Builder mapping (see the ``_vgg`` calls in the builders): "A" = vgg11, "B" = vgg13, "D" = vgg16,
# "E" = vgg19. The depth counts the convolutions here plus the 3 linear layers of ``VGG.classifier``.
# Every spec ends at 512 channels, which matches the classifier's ``512 * 7 * 7`` input size.
# Entry order determines the ``features.<index>`` state_dict keys that pretrained checkpoints rely on.
cfgs: Dict[str, List[Union[str, int]]] = {
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],  # 8 convs
    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],  # 10 convs
    "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],  # 13 convs
    "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],  # 16 convs
}
def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG:
    """Shared factory behind the public ``vgg*`` builders.

    Args:
        cfg: Key into ``cfgs`` that selects the convolutional layer layout.
        batch_norm: Whether ``make_layers`` inserts ``BatchNorm2d`` after each convolution.
        weights: Pretrained weights to load, or None for a freshly initialized model.
        progress: Forwarded to ``weights.get_state_dict`` to toggle the download progress bar.
        **kwargs: Extra keyword arguments for the ``VGG`` constructor.

    Returns:
        The constructed ``VGG`` model, with ``weights`` loaded into it when given.
    """
    if weights is None:
        return VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)

    # Parameters are overwritten by load_state_dict below, so skip the random initialization.
    kwargs["init_weights"] = False
    categories = weights.meta["categories"]
    if categories is not None:
        # Make the classifier head's output size agree with the checkpoint's label set.
        _ovewrite_named_param(kwargs, "num_classes", len(categories))

    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
    return model
# Metadata shared by every ``Weights`` entry in this module. Each entry spreads it with
# ``**_COMMON_META`` and may override keys afterwards; for example, VGG16_Weights.IMAGENET1K_FEATURES
# replaces "categories" and "recipe".
_COMMON_META = {
    "min_size": (32, 32),
    # Label names; ``_vgg`` sizes ``num_classes`` from len(categories) when this is not None.
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
    "_docs": """These weights were trained from scratch by using a simplified training recipe.""",
}
# Pretrained checkpoints accepted by ``vgg11`` (layout "A", no batch normalization).
class VGG11_Weights(WeightsEnum):
    # ImageNet-1K classification checkpoint; the accuracy figures below are recorded under "_metrics".
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg11-8a719046.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 132863336,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.020,
                    "acc@5": 88.628,
                }
            },
            "_ops": 7.609,
            "_file_size": 506.84,
        },
    )
    # Alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
# Pretrained checkpoints accepted by ``vgg11_bn`` (layout "A", with batch normalization).
class VGG11_BN_Weights(WeightsEnum):
    # ImageNet-1K classification checkpoint; the accuracy figures below are recorded under "_metrics".
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 132868840,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 70.370,
                    "acc@5": 89.810,
                }
            },
            "_ops": 7.609,
            "_file_size": 506.881,
        },
    )
    # Alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
# Pretrained checkpoints accepted by ``vgg13`` (layout "B", no batch normalization).
class VGG13_Weights(WeightsEnum):
    # ImageNet-1K classification checkpoint; the accuracy figures below are recorded under "_metrics".
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg13-19584684.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 133047848,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.928,
                    "acc@5": 89.246,
                }
            },
            "_ops": 11.308,
            "_file_size": 507.545,
        },
    )
    # Alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
# Pretrained checkpoints accepted by ``vgg13_bn`` (layout "B", with batch normalization).
class VGG13_BN_Weights(WeightsEnum):
    # ImageNet-1K classification checkpoint; the accuracy figures below are recorded under "_metrics".
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 133053736,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.586,
                    "acc@5": 90.374,
                }
            },
            "_ops": 11.308,
            "_file_size": 507.59,
        },
    )
    # Alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
# Pretrained checkpoints accepted by ``vgg16`` (layout "D", no batch normalization).
class VGG16_Weights(WeightsEnum):
    # ImageNet-1K classification checkpoint; the accuracy figures below are recorded under "_metrics".
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg16-397923af.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 138357544,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.592,
                    "acc@5": 90.382,
                }
            },
            "_ops": 15.47,
            "_file_size": 527.796,
        },
    )
    # Feature-extraction-only checkpoint: per its "_docs", only the ``features`` module holds valid values.
    IMAGENET1K_FEATURES = Weights(
        # Weights ported from https://github.com/amdegroot/ssd.pytorch/
        url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
        # NOTE(review): with std = 1/255, ``(x - mean) / std`` equals ``(x - mean) * 255``. Assuming
        # ImageClassification normalizes [0, 1] inputs this way, inputs end up mean-subtracted on a
        # 0-255 scale, consistent with the "original input standardization" mentioned in "_docs".
        # Confirm against ImageClassification.
        transforms=partial(
            ImageClassification,
            crop_size=224,
            mean=(0.48235, 0.45882, 0.40784),
            std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0),
        ),
        meta={
            **_COMMON_META,
            "num_params": 138357544,
            # No label set; ``_vgg`` therefore leaves ``num_classes`` as passed (or its default).
            "categories": None,
            "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd",
            # NaN accuracies: the classifier values are not meaningful for this checkpoint.
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": float("nan"),
                    "acc@5": float("nan"),
                }
            },
            "_ops": 15.47,
            "_file_size": 527.802,
            "_docs": """
These weights can't be used for classification because they are missing values in the `classifier`
module. Only the `features` module has valid values and can be used for feature extraction. The weights
were trained using the original input standardization method as described in the paper.
""",
        },
    )
    # Alias of IMAGENET1K_V1 (the classification checkpoint, not the features-only one).
    DEFAULT = IMAGENET1K_V1
# Pretrained checkpoints accepted by ``vgg16_bn`` (layout "D", with batch normalization).
class VGG16_BN_Weights(WeightsEnum):
    # ImageNet-1K classification checkpoint; the accuracy figures below are recorded under "_metrics".
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 138365992,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 73.360,
                    "acc@5": 91.516,
                }
            },
            "_ops": 15.47,
            "_file_size": 527.866,
        },
    )
    # Alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
# Pretrained checkpoints accepted by ``vgg19`` (layout "E", no batch normalization).
class VGG19_Weights(WeightsEnum):
    # ImageNet-1K classification checkpoint; the accuracy figures below are recorded under "_metrics".
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 143667240,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 72.376,
                    "acc@5": 90.876,
                }
            },
            "_ops": 19.632,
            "_file_size": 548.051,
        },
    )
    # Alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
# Pretrained checkpoints accepted by ``vgg19_bn`` (layout "E", with batch normalization).
class VGG19_BN_Weights(WeightsEnum):
    # ImageNet-1K classification checkpoint; the accuracy figures below are recorded under "_metrics".
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 143678248,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 74.218,
                    "acc@5": 91.842,
                }
            },
            "_ops": 19.632,
            "_file_size": 548.143,
        },
    )
    # Alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1))
def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-11 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG11_Weights`, optional): Pretrained weights to
            load. See :class:`~torchvision.models.VGG11_Weights` below for details and the
            accepted values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr.
            Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG`` base class. See the
            `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG11_Weights
        :members:
    """
    # Layout "A" from ``cfgs``, without batch normalization.
    return _vgg("A", batch_norm=False, weights=VGG11_Weights.verify(weights), progress=progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1))
def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-11-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG11_BN_Weights`, optional): Pretrained weights to
            load. See :class:`~torchvision.models.VGG11_BN_Weights` below for details and the
            accepted values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr.
            Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG`` base class. See the
            `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG11_BN_Weights
        :members:
    """
    # Layout "A" from ``cfgs``, with batch normalization.
    return _vgg("A", batch_norm=True, weights=VGG11_BN_Weights.verify(weights), progress=progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1))
def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-13 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG13_Weights`, optional): Pretrained weights to
            load. See :class:`~torchvision.models.VGG13_Weights` below for details and the
            accepted values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr.
            Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG`` base class. See the
            `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG13_Weights
        :members:
    """
    # Layout "B" from ``cfgs``, without batch normalization.
    return _vgg("B", batch_norm=False, weights=VGG13_Weights.verify(weights), progress=progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1))
def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-13-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG13_BN_Weights`, optional): Pretrained weights to
            load. See :class:`~torchvision.models.VGG13_BN_Weights` below for details and the
            accepted values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr.
            Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG`` base class. See the
            `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG13_BN_Weights
        :members:
    """
    # Layout "B" from ``cfgs``, with batch normalization.
    return _vgg("B", batch_norm=True, weights=VGG13_BN_Weights.verify(weights), progress=progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1))
def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-16 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG16_Weights`, optional): Pretrained weights to
            load. See :class:`~torchvision.models.VGG16_Weights` below for details and the
            accepted values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr.
            Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG`` base class. See the
            `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG16_Weights
        :members:
    """
    # Layout "D" from ``cfgs``, without batch normalization.
    return _vgg("D", batch_norm=False, weights=VGG16_Weights.verify(weights), progress=progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1))
def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-16-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG16_BN_Weights`, optional): Pretrained weights to
            load. See :class:`~torchvision.models.VGG16_BN_Weights` below for details and the
            accepted values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr.
            Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG`` base class. See the
            `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG16_BN_Weights
        :members:
    """
    # Layout "D" from ``cfgs``, with batch normalization.
    return _vgg("D", batch_norm=True, weights=VGG16_BN_Weights.verify(weights), progress=progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1))
def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-19 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG19_Weights`, optional): Pretrained weights to
            load. See :class:`~torchvision.models.VGG19_Weights` below for details and the
            accepted values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr.
            Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG`` base class. See the
            `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG19_Weights
        :members:
    """
    # Layout "E" from ``cfgs``, without batch normalization.
    return _vgg("E", batch_norm=False, weights=VGG19_Weights.verify(weights), progress=progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1))
def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-19-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG19_BN_Weights`, optional): Pretrained weights to
            load. See :class:`~torchvision.models.VGG19_BN_Weights` below for details and the
            accepted values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr.
            Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG`` base class. See the
            `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG19_BN_Weights
        :members:
    """
    # Layout "E" from ``cfgs``, with batch normalization.
    return _vgg("E", batch_norm=True, weights=VGG19_BN_Weights.verify(weights), progress=progress, **kwargs)
|