import warnings
from functools import partial
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface

__all__ = [
    "MNASNet",
    "MNASNet0_5_Weights",
    "MNASNet0_75_Weights",
    "MNASNet1_0_Weights",
    "MNASNet1_3_Weights",
    "mnasnet0_5",
    "mnasnet0_75",
    "mnasnet1_0",
    "mnasnet1_3",
]

# Paper suggests 0.9997 momentum for TensorFlow. PyTorch defines BatchNorm momentum
# in the opposite sense, so the equivalent PyTorch momentum is 1 - 0.9997.
_BN_MOMENTUM = 1 - 0.9997

class _InvertedResidual(nn.Module):
    def __init__(
        self, in_ch: int, out_ch: int, kernel_size: int, stride: int, expansion_factor: int, bn_momentum: float = 0.1
    ) -> None:
        super().__init__()
        if stride not in [1, 2]:
            raise ValueError(f"stride should be 1 or 2 instead of {stride}")
        if kernel_size not in [3, 5]:
            raise ValueError(f"kernel_size should be 3 or 5 instead of {kernel_size}")
        mid_ch = in_ch * expansion_factor
        self.apply_residual = in_ch == out_ch and stride == 1
        self.layers = nn.Sequential(
            # Pointwise
            nn.Conv2d(in_ch, mid_ch, 1, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Depthwise
            nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, stride=stride, groups=mid_ch, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Linear pointwise. Note that there's no activation.
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch, momentum=bn_momentum),
        )

    def forward(self, input: Tensor) -> Tensor:
        if self.apply_residual:
            return self.layers(input) + input
        else:
            return self.layers(input)
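
# Illustrative shape check (an added sketch, not part of the original file). A stride-2
# block cannot use the residual connection and halves the spatial resolution:
# >>> block = _InvertedResidual(16, 24, kernel_size=3, stride=2, expansion_factor=3)
# >>> block.apply_residual
# False
# >>> block(torch.rand(1, 16, 56, 56)).shape
# torch.Size([1, 24, 28, 28])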

def _stack(
    in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, bn_momentum: float
) -> nn.Sequential:
    """Creates a stack of inverted residuals."""
    if repeats < 1:
        raise ValueError(f"repeats should be >= 1, instead got {repeats}")
    # First one has no skip, because feature map size changes.
    first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum)
    remaining = []
    for _ in range(1, repeats):
        remaining.append(_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum))
    return nn.Sequential(first, *remaining)
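
# Illustrative sketch (added, not part of the original file): in a 3-block stack only the
# first block downsamples, and only the repeated blocks qualify for the residual connection:
# >>> s = _stack(16, 24, kernel_size=3, stride=2, exp_factor=3, repeats=3, bn_momentum=0.1)
# >>> [b.apply_residual for b in s]
# [False, True, True]
# >>> s(torch.rand(1, 16, 56, 56)).shape
# torch.Size([1, 24, 28, 28])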

def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
    """Asymmetric rounding to make `val` divisible by `divisor`. With default
    bias, will round up, unless the number is no more than 10% greater than the
    smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88."""
    if not 0.0 < round_up_bias < 1.0:
        raise ValueError(f"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}")
    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
    return new_val if new_val >= round_up_bias * val else new_val + divisor
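
# Illustrative values (added, not part of the original file): the input is first rounded
# half-up to the nearest multiple, then bumped up one more step only when the result fell
# more than 10% below the input:
# >>> _round_to_multiple_of(83, 8), _round_to_multiple_of(84, 8), _round_to_multiple_of(10, 8)
# (80, 88, 16)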

def _get_depths(alpha: float) -> List[int]:
    """Scales tensor depths as in reference MobileNet code, prefers rounding up
    rather than down."""
    depths = [32, 16, 24, 40, 80, 96, 192, 320]
    return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
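
# Illustrative values (added, not part of the original file): with alpha=1.0 the base depths
# are already multiples of 8 and come back unchanged; smaller multipliers are rounded as above:
# >>> _get_depths(1.0)
# [32, 16, 24, 40, 80, 96, 192, 320]
# >>> _get_depths(0.5)
# [16, 8, 16, 24, 40, 48, 96, 160]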

class MNASNet(torch.nn.Module):
    """MNASNet, as described in https://arxiv.org/abs/1807.11626. This
    implements the B1 variant of the model.
    >>> model = MNASNet(1.0, num_classes=1000)
    >>> x = torch.rand(1, 3, 224, 224)
    >>> y = model(x)
    >>> y.dim()
    2
    >>> y.nelement()
    1000
    """

    # Version 2 adds depth scaling in the initial stages of the network.
    _version = 2

    def __init__(self, alpha: float, num_classes: int = 1000, dropout: float = 0.2) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if alpha <= 0.0:
            raise ValueError(f"alpha should be greater than 0.0 instead of {alpha}")
        self.alpha = alpha
        self.num_classes = num_classes
        depths = _get_depths(alpha)
        layers = [
            # First layer: regular conv.
            nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            # Depthwise separable, no skip.
            nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, groups=depths[0], bias=False),
            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM),
            # MNASNet blocks: stacks of inverted residuals.
            _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),
            _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),
            _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),
            _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),
            # Final mapping to classifier input.
            nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
        ]
        self.layers = nn.Sequential(*layers)
        self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid")
                nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        x = self.layers(x)
        # Equivalent to global avgpool and removing H and W dimensions.
        x = x.mean([2, 3])
        return self.classifier(x)

    def _load_from_state_dict(
        self,
        state_dict: Dict,
        prefix: str,
        local_metadata: Dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        version = local_metadata.get("version", None)
        if version not in [1, 2]:
            raise ValueError(f"version should be set to 1 or 2 instead of {version}")
        if version == 1 and not self.alpha == 1.0:
            # In the initial version of the model (v1), stem was fixed-size.
            # All other layer configurations were the same. This will patch
            # the model so that it's identical to v1. Model with alpha 1.0 is
            # unaffected.
            depths = _get_depths(self.alpha)
            v1_stem = [
                nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
                nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
                _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
            ]
            for idx, layer in enumerate(v1_stem):
                self.layers[idx] = layer
            # The model is now identical to v1, and must be saved as such.
            self._version = 1
            warnings.warn(
                "A new version of MNASNet model has been implemented. "
                "Your checkpoint was saved using the previous version. "
                "This checkpoint will load and work as before, but "
                "you may want to upgrade by training a newer model or "
                "transfer learning from an updated ImageNet checkpoint.",
                UserWarning,
            )

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/1e100/mnasnet_trainer",
}

class MNASNet0_5_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 2218512,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 67.734,
                    "acc@5": 87.490,
                }
            },
            "_ops": 0.104,
            "_file_size": 8.591,
            "_docs": """These weights reproduce closely the results of the paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1

class MNASNet0_75_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet0_75-7090bc5f.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/pull/6019",
            "num_params": 3170208,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.180,
                    "acc@5": 90.496,
                }
            },
            "_ops": 0.215,
            "_file_size": 12.303,
            "_docs": """
                These weights were trained from scratch by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1

class MNASNet1_0_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 4383312,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 73.456,
                    "acc@5": 91.510,
                }
            },
            "_ops": 0.314,
            "_file_size": 16.915,
            "_docs": """These weights reproduce closely the results of the paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1

class MNASNet1_3_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet1_3-a4c69d6f.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/pull/6019",
            "num_params": 6282256,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.506,
                    "acc@5": 93.522,
                }
            },
            "_ops": 0.526,
            "_file_size": 24.246,
            "_docs": """
                These weights were trained from scratch by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1

def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = MNASNet(alpha, **kwargs)

    if weights:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model

@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1))
def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 0.5 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/abs/1807.11626>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet0_5_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet0_5_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet0_5_Weights
        :members:
    """
    weights = MNASNet0_5_Weights.verify(weights)

    return _mnasnet(0.5, weights, progress, **kwargs)
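
# Example usage (an added sketch, not part of the original file), going through the public
# torchvision API that this module registers and exports:
# >>> from torchvision.models import mnasnet0_5, MNASNet0_5_Weights
# >>> m = mnasnet0_5(weights=MNASNet0_5_Weights.IMAGENET1K_V1).eval()
# >>> with torch.inference_mode():
# ...     logits = m(torch.rand(1, 3, 224, 224))
# >>> logits.shape
# torch.Size([1, 1000])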

@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet0_75_Weights.IMAGENET1K_V1))
def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 0.75 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/abs/1807.11626>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet0_75_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet0_75_Weights
        :members:
    """
    weights = MNASNet0_75_Weights.verify(weights)

    return _mnasnet(0.75, weights, progress, **kwargs)

@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1))
def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 1.0 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/abs/1807.11626>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet1_0_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet1_0_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet1_0_Weights
        :members:
    """
    weights = MNASNet1_0_Weights.verify(weights)

    return _mnasnet(1.0, weights, progress, **kwargs)

@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet1_3_Weights.IMAGENET1K_V1))
def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 1.3 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/abs/1807.11626>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet1_3_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet1_3_Weights
        :members:
    """
    weights = MNASNet1_3_Weights.verify(weights)

    return _mnasnet(1.3, weights, progress, **kwargs)