import warnings
from functools import partial
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface

__all__ = [
    "MNASNet",
    "MNASNet0_5_Weights",
    "MNASNet0_75_Weights",
    "MNASNet1_0_Weights",
    "MNASNet1_3_Weights",
    "mnasnet0_5",
    "mnasnet0_75",
    "mnasnet1_0",
    "mnasnet1_3",
]

# The paper suggests 0.9997 momentum, in TensorFlow's convention. The equivalent
# PyTorch momentum is 1 minus the TensorFlow value, i.e. 1 - 0.9997 = 0.0003.
_BN_MOMENTUM = 1 - 0.9997
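# Illustrative aside (not from the original file): the two frameworks apply the
# momentum/decay to opposite sides of the running-statistics update,
#     TensorFlow: running = decay * running + (1 - decay) * batch        (decay = 0.9997)
#     PyTorch:    running = (1 - momentum) * running + momentum * batch  (momentum = 0.0003)
# so setting momentum = 1 - decay yields identical statistics.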


class _InvertedResidual(nn.Module):
    def __init__(
        self, in_ch: int, out_ch: int, kernel_size: int, stride: int, expansion_factor: int, bn_momentum: float = 0.1
    ) -> None:
        super().__init__()
        if stride not in [1, 2]:
            raise ValueError(f"stride should be 1 or 2 instead of {stride}")
        if kernel_size not in [3, 5]:
            raise ValueError(f"kernel_size should be 3 or 5 instead of {kernel_size}")
        mid_ch = in_ch * expansion_factor
        self.apply_residual = in_ch == out_ch and stride == 1
        self.layers = nn.Sequential(
            # Pointwise
            nn.Conv2d(in_ch, mid_ch, 1, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Depthwise
            nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, stride=stride, groups=mid_ch, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Linear pointwise. Note that there's no activation.
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch, momentum=bn_momentum),
        )

    def forward(self, input: Tensor) -> Tensor:
        if self.apply_residual:
            return self.layers(input) + input
        else:
            return self.layers(input)


def _stack(
    in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, bn_momentum: float
) -> nn.Sequential:
    """Creates a stack of inverted residuals."""
    if repeats < 1:
        raise ValueError(f"repeats should be >= 1, instead got {repeats}")
    # First one has no skip, because feature map size changes.
    first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum)
    remaining = []
    for _ in range(1, repeats):
        remaining.append(_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum))
    return nn.Sequential(first, *remaining)


def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
    """Asymmetric rounding to make `val` divisible by `divisor`. With default
    bias, will round up, unless the number is no more than 10% greater than the
    smaller divisible value, e.g. (83, 8) -> 80, but (84, 8) -> 88."""
    if not 0.0 < round_up_bias < 1.0:
        raise ValueError(f"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}")
    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
    return new_val if new_val >= round_up_bias * val else new_val + divisor
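# Worked examples (illustrative, mirroring the docstring's cases):
#   _round_to_multiple_of(83, 8) -> 80  # nearest multiple; 80 >= 0.9 * 83, so keep it
#   _round_to_multiple_of(84, 8) -> 88  # the halfway tie int(84 + 4) breaks upward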


def _get_depths(alpha: float) -> List[int]:
    """Scales tensor depths as in reference MobileNet code, prefers rounding up
    rather than down."""
    depths = [32, 16, 24, 40, 80, 96, 192, 320]
    return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
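# For instance (illustrative): _get_depths(1.0) leaves the base list unchanged,
# while _get_depths(0.5) -> [16, 8, 16, 24, 40, 48, 96, 160]; every entry stays
# a multiple of 8.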


class MNASNet(torch.nn.Module):
    """MNASNet, as described in https://arxiv.org/abs/1807.11626. This
    implements the B1 variant of the model.
    >>> model = MNASNet(1.0, num_classes=1000)
    >>> x = torch.rand(1, 3, 224, 224)
    >>> y = model(x)
    >>> y.dim()
    2
    >>> y.nelement()
    1000
    """

    # Version 2 adds depth scaling in the initial stages of the network.
    _version = 2

    def __init__(self, alpha: float, num_classes: int = 1000, dropout: float = 0.2) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if alpha <= 0.0:
            raise ValueError(f"alpha should be greater than 0.0 instead of {alpha}")
        self.alpha = alpha
        self.num_classes = num_classes
        depths = _get_depths(alpha)
        layers = [
            # First layer: regular conv.
            nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            # Depthwise separable, no skip.
            nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, groups=depths[0], bias=False),
            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM),
            # MNASNet blocks: stacks of inverted residuals.
            _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),
            _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),
            _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),
            _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),
            # Final mapping to classifier input.
            nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
        ]
        self.layers = nn.Sequential(*layers)
        self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid")
                nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        x = self.layers(x)
        # Equivalent to global avgpool and removing H and W dimensions.
        x = x.mean([2, 3])
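        # (Illustrative equivalence, not in the original file: the mean above
        # matches nn.functional.adaptive_avg_pool2d(x, 1).flatten(1).)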
        return self.classifier(x)

    def _load_from_state_dict(
        self,
        state_dict: Dict,
        prefix: str,
        local_metadata: Dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        version = local_metadata.get("version", None)
        if version not in [1, 2]:
            raise ValueError(f"version should be set to 1 or 2 instead of {version}")
        if version == 1 and not self.alpha == 1.0:
            # In the initial version of the model (v1), stem was fixed-size.
            # All other layer configurations were the same. This will patch
            # the model so that it's identical to v1. Model with alpha 1.0 is
            # unaffected.
            depths = _get_depths(self.alpha)
            v1_stem = [
                nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
                nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
                _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
            ]
            for idx, layer in enumerate(v1_stem):
                self.layers[idx] = layer

            # The model is now identical to v1, and must be saved as such.
            self._version = 1
            warnings.warn(
                "A new version of MNASNet model has been implemented. "
                "Your checkpoint was saved using the previous version. "
                "This checkpoint will load and work as before, but "
                "you may want to upgrade by training a newer model or "
                "transfer learning from an updated ImageNet checkpoint.",
                UserWarning,
            )

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )


_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/1e100/mnasnet_trainer",
}


class MNASNet0_5_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 2218512,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 67.734,
                    "acc@5": 87.490,
                }
            },
            "_ops": 0.104,
            "_file_size": 8.591,
            "_docs": """These weights reproduce closely the results of the paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class MNASNet0_75_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet0_75-7090bc5f.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/pull/6019",
            "num_params": 3170208,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.180,
                    "acc@5": 90.496,
                }
            },
            "_ops": 0.215,
            "_file_size": 12.303,
            "_docs": """
                These weights were trained from scratch by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


class MNASNet1_0_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 4383312,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 73.456,
                    "acc@5": 91.510,
                }
            },
            "_ops": 0.314,
            "_file_size": 16.915,
            "_docs": """These weights reproduce closely the results of the paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class MNASNet1_3_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet1_3-a4c69d6f.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/pull/6019",
            "num_params": 6282256,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.506,
                    "acc@5": 93.522,
                }
            },
            "_ops": 0.526,
            "_file_size": 24.246,
            "_docs": """
                These weights were trained from scratch by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = MNASNet(alpha, **kwargs)

    if weights:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1))
def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 0.5 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/abs/1807.11626>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet0_5_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet0_5_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet0_5_Weights
        :members:
    """
    weights = MNASNet0_5_Weights.verify(weights)

    return _mnasnet(0.5, weights, progress, **kwargs)
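

# A minimal usage sketch (illustrative; it relies on the public torchvision
# weights API, where `weights.transforms()` instantiates the preprocessing):
#
#     from torchvision.models import mnasnet0_5, MNASNet0_5_Weights
#
#     weights = MNASNet0_5_Weights.IMAGENET1K_V1
#     model = mnasnet0_5(weights=weights).eval()
#     preprocess = weights.transforms()
#     # batch = preprocess(img).unsqueeze(0); logits = model(batch)
#
# The same pattern applies to the remaining mnasnet builders below.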


@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet0_75_Weights.IMAGENET1K_V1))
def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 0.75 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/abs/1807.11626>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet0_75_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet0_75_Weights
        :members:
    """
    weights = MNASNet0_75_Weights.verify(weights)

    return _mnasnet(0.75, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1))
def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 1.0 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/abs/1807.11626>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet1_0_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet1_0_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet1_0_Weights
        :members:
    """
    weights = MNASNet1_0_Weights.verify(weights)

    return _mnasnet(1.0, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet1_3_Weights.IMAGENET1K_V1))
def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 1.3 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/abs/1807.11626>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet1_3_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet1_3_Weights
        :members:
    """
    weights = MNASNet1_3_Weights.verify(weights)

    return _mnasnet(1.3, weights, progress, **kwargs)