shufflenetv2.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. from functools import partial
  2. from typing import Any, List, Optional, Union
  3. import torch
  4. import torch.nn as nn
  5. from torch import Tensor
  6. from torchvision.models import shufflenetv2
  7. from ...transforms._presets import ImageClassification
  8. from .._api import register_model, Weights, WeightsEnum
  9. from .._meta import _IMAGENET_CATEGORIES
  10. from .._utils import _ovewrite_named_param, handle_legacy_interface
  11. from ..shufflenetv2 import (
  12. ShuffleNet_V2_X0_5_Weights,
  13. ShuffleNet_V2_X1_0_Weights,
  14. ShuffleNet_V2_X1_5_Weights,
  15. ShuffleNet_V2_X2_0_Weights,
  16. )
  17. from .utils import _fuse_modules, _replace_relu, quantize_model
  18. __all__ = [
  19. "QuantizableShuffleNetV2",
  20. "ShuffleNet_V2_X0_5_QuantizedWeights",
  21. "ShuffleNet_V2_X1_0_QuantizedWeights",
  22. "ShuffleNet_V2_X1_5_QuantizedWeights",
  23. "ShuffleNet_V2_X2_0_QuantizedWeights",
  24. "shufflenet_v2_x0_5",
  25. "shufflenet_v2_x1_0",
  26. "shufflenet_v2_x1_5",
  27. "shufflenet_v2_x2_0",
  28. ]
  29. class QuantizableInvertedResidual(shufflenetv2.InvertedResidual):
  30. def __init__(self, *args: Any, **kwargs: Any) -> None:
  31. super().__init__(*args, **kwargs)
  32. self.cat = nn.quantized.FloatFunctional()
  33. def forward(self, x: Tensor) -> Tensor:
  34. if self.stride == 1:
  35. x1, x2 = x.chunk(2, dim=1)
  36. out = self.cat.cat([x1, self.branch2(x2)], dim=1)
  37. else:
  38. out = self.cat.cat([self.branch1(x), self.branch2(x)], dim=1)
  39. out = shufflenetv2.channel_shuffle(out, 2)
  40. return out
  41. class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
  42. # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
  43. def __init__(self, *args: Any, **kwargs: Any) -> None:
  44. super().__init__(*args, inverted_residual=QuantizableInvertedResidual, **kwargs) # type: ignore[misc]
  45. self.quant = torch.ao.quantization.QuantStub()
  46. self.dequant = torch.ao.quantization.DeQuantStub()
  47. def forward(self, x: Tensor) -> Tensor:
  48. x = self.quant(x)
  49. x = self._forward_impl(x)
  50. x = self.dequant(x)
  51. return x
  52. def fuse_model(self, is_qat: Optional[bool] = None) -> None:
  53. r"""Fuse conv/bn/relu modules in shufflenetv2 model
  54. Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
  55. Model is modified in place.
  56. .. note::
  57. Note that this operation does not change numerics
  58. and the model after modification is in floating point
  59. """
  60. for name, m in self._modules.items():
  61. if name in ["conv1", "conv5"] and m is not None:
  62. _fuse_modules(m, [["0", "1", "2"]], is_qat, inplace=True)
  63. for m in self.modules():
  64. if type(m) is QuantizableInvertedResidual:
  65. if len(m.branch1._modules.items()) > 0:
  66. _fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], is_qat, inplace=True)
  67. _fuse_modules(
  68. m.branch2,
  69. [["0", "1", "2"], ["3", "4"], ["5", "6", "7"]],
  70. is_qat,
  71. inplace=True,
  72. )
  73. def _shufflenetv2(
  74. stages_repeats: List[int],
  75. stages_out_channels: List[int],
  76. *,
  77. weights: Optional[WeightsEnum],
  78. progress: bool,
  79. quantize: bool,
  80. **kwargs: Any,
  81. ) -> QuantizableShuffleNetV2:
  82. if weights is not None:
  83. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  84. if "backend" in weights.meta:
  85. _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
  86. backend = kwargs.pop("backend", "fbgemm")
  87. model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
  88. _replace_relu(model)
  89. if quantize:
  90. quantize_model(model, backend)
  91. if weights is not None:
  92. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  93. return model
  94. _COMMON_META = {
  95. "min_size": (1, 1),
  96. "categories": _IMAGENET_CATEGORIES,
  97. "backend": "fbgemm",
  98. "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
  99. "_docs": """
  100. These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized
  101. weights listed below.
  102. """,
  103. }
  104. class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
  105. IMAGENET1K_FBGEMM_V1 = Weights(
  106. url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth",
  107. transforms=partial(ImageClassification, crop_size=224),
  108. meta={
  109. **_COMMON_META,
  110. "num_params": 1366792,
  111. "unquantized": ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1,
  112. "_metrics": {
  113. "ImageNet-1K": {
  114. "acc@1": 57.972,
  115. "acc@5": 79.780,
  116. }
  117. },
  118. "_ops": 0.04,
  119. "_file_size": 1.501,
  120. },
  121. )
  122. DEFAULT = IMAGENET1K_FBGEMM_V1
  123. class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
  124. IMAGENET1K_FBGEMM_V1 = Weights(
  125. url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-1e62bb32.pth",
  126. transforms=partial(ImageClassification, crop_size=224),
  127. meta={
  128. **_COMMON_META,
  129. "num_params": 2278604,
  130. "unquantized": ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1,
  131. "_metrics": {
  132. "ImageNet-1K": {
  133. "acc@1": 68.360,
  134. "acc@5": 87.582,
  135. }
  136. },
  137. "_ops": 0.145,
  138. "_file_size": 2.334,
  139. },
  140. )
  141. DEFAULT = IMAGENET1K_FBGEMM_V1
  142. class ShuffleNet_V2_X1_5_QuantizedWeights(WeightsEnum):
  143. IMAGENET1K_FBGEMM_V1 = Weights(
  144. url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_5_fbgemm-d7401f05.pth",
  145. transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  146. meta={
  147. **_COMMON_META,
  148. "recipe": "https://github.com/pytorch/vision/pull/5906",
  149. "num_params": 3503624,
  150. "unquantized": ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1,
  151. "_metrics": {
  152. "ImageNet-1K": {
  153. "acc@1": 72.052,
  154. "acc@5": 90.700,
  155. }
  156. },
  157. "_ops": 0.296,
  158. "_file_size": 3.672,
  159. },
  160. )
  161. DEFAULT = IMAGENET1K_FBGEMM_V1
  162. class ShuffleNet_V2_X2_0_QuantizedWeights(WeightsEnum):
  163. IMAGENET1K_FBGEMM_V1 = Weights(
  164. url="https://download.pytorch.org/models/quantized/shufflenetv2_x2_0_fbgemm-5cac526c.pth",
  165. transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  166. meta={
  167. **_COMMON_META,
  168. "recipe": "https://github.com/pytorch/vision/pull/5906",
  169. "num_params": 7393996,
  170. "unquantized": ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1,
  171. "_metrics": {
  172. "ImageNet-1K": {
  173. "acc@1": 75.354,
  174. "acc@5": 92.488,
  175. }
  176. },
  177. "_ops": 0.583,
  178. "_file_size": 7.467,
  179. },
  180. )
  181. DEFAULT = IMAGENET1K_FBGEMM_V1
  182. @register_model(name="quantized_shufflenet_v2_x0_5")
  183. @handle_legacy_interface(
  184. weights=(
  185. "pretrained",
  186. lambda kwargs: ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1
  187. if kwargs.get("quantize", False)
  188. else ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1,
  189. )
  190. )
  191. def shufflenet_v2_x0_5(
  192. *,
  193. weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None,
  194. progress: bool = True,
  195. quantize: bool = False,
  196. **kwargs: Any,
  197. ) -> QuantizableShuffleNetV2:
  198. """
  199. Constructs a ShuffleNetV2 with 0.5x output channels, as described in
  200. `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design
  201. <https://arxiv.org/abs/1807.11164>`__.
  202. .. note::
  203. Note that ``quantize = True`` returns a quantized model with 8 bit
  204. weights. Quantized models only support inference and run on CPUs.
  205. GPU inference is not yet supported.
  206. Args:
  207. weights (:class:`~torchvision.models.quantization.ShuffleNet_V2_X0_5_QuantizedWeights` or :class:`~torchvision.models.ShuffleNet_V2_X0_5_Weights`, optional): The
  208. pretrained weights for the model. See
  209. :class:`~torchvision.models.quantization.ShuffleNet_V2_X0_5_QuantizedWeights` below for
  210. more details, and possible values. By default, no pre-trained
  211. weights are used.
  212. progress (bool, optional): If True, displays a progress bar of the download to stderr.
  213. Default is True.
  214. quantize (bool, optional): If True, return a quantized version of the model.
  215. Default is False.
  216. **kwargs: parameters passed to the ``torchvision.models.quantization.ShuffleNet_V2_X0_5_QuantizedWeights``
  217. base class. Please refer to the `source code
  218. <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/shufflenetv2.py>`_
  219. for more details about this class.
  220. .. autoclass:: torchvision.models.quantization.ShuffleNet_V2_X0_5_QuantizedWeights
  221. :members:
  222. .. autoclass:: torchvision.models.ShuffleNet_V2_X0_5_Weights
  223. :members:
  224. :noindex:
  225. """
  226. weights = (ShuffleNet_V2_X0_5_QuantizedWeights if quantize else ShuffleNet_V2_X0_5_Weights).verify(weights)
  227. return _shufflenetv2(
  228. [4, 8, 4], [24, 48, 96, 192, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
  229. )
  230. @register_model(name="quantized_shufflenet_v2_x1_0")
  231. @handle_legacy_interface(
  232. weights=(
  233. "pretrained",
  234. lambda kwargs: ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1
  235. if kwargs.get("quantize", False)
  236. else ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1,
  237. )
  238. )
  239. def shufflenet_v2_x1_0(
  240. *,
  241. weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None,
  242. progress: bool = True,
  243. quantize: bool = False,
  244. **kwargs: Any,
  245. ) -> QuantizableShuffleNetV2:
  246. """
  247. Constructs a ShuffleNetV2 with 1.0x output channels, as described in
  248. `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design
  249. <https://arxiv.org/abs/1807.11164>`__.
  250. .. note::
  251. Note that ``quantize = True`` returns a quantized model with 8 bit
  252. weights. Quantized models only support inference and run on CPUs.
  253. GPU inference is not yet supported.
  254. Args:
  255. weights (:class:`~torchvision.models.quantization.ShuffleNet_V2_X1_0_QuantizedWeights` or :class:`~torchvision.models.ShuffleNet_V2_X1_0_Weights`, optional): The
  256. pretrained weights for the model. See
  257. :class:`~torchvision.models.quantization.ShuffleNet_V2_X1_0_QuantizedWeights` below for
  258. more details, and possible values. By default, no pre-trained
  259. weights are used.
  260. progress (bool, optional): If True, displays a progress bar of the download to stderr.
  261. Default is True.
  262. quantize (bool, optional): If True, return a quantized version of the model.
  263. Default is False.
  264. **kwargs: parameters passed to the ``torchvision.models.quantization.ShuffleNet_V2_X1_0_QuantizedWeights``
  265. base class. Please refer to the `source code
  266. <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/shufflenetv2.py>`_
  267. for more details about this class.
  268. .. autoclass:: torchvision.models.quantization.ShuffleNet_V2_X1_0_QuantizedWeights
  269. :members:
  270. .. autoclass:: torchvision.models.ShuffleNet_V2_X1_0_Weights
  271. :members:
  272. :noindex:
  273. """
  274. weights = (ShuffleNet_V2_X1_0_QuantizedWeights if quantize else ShuffleNet_V2_X1_0_Weights).verify(weights)
  275. return _shufflenetv2(
  276. [4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
  277. )
  278. @register_model(name="quantized_shufflenet_v2_x1_5")
  279. @handle_legacy_interface(
  280. weights=(
  281. "pretrained",
  282. lambda kwargs: ShuffleNet_V2_X1_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1
  283. if kwargs.get("quantize", False)
  284. else ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1,
  285. )
  286. )
  287. def shufflenet_v2_x1_5(
  288. *,
  289. weights: Optional[Union[ShuffleNet_V2_X1_5_QuantizedWeights, ShuffleNet_V2_X1_5_Weights]] = None,
  290. progress: bool = True,
  291. quantize: bool = False,
  292. **kwargs: Any,
  293. ) -> QuantizableShuffleNetV2:
  294. """
  295. Constructs a ShuffleNetV2 with 1.5x output channels, as described in
  296. `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design
  297. <https://arxiv.org/abs/1807.11164>`__.
  298. .. note::
  299. Note that ``quantize = True`` returns a quantized model with 8 bit
  300. weights. Quantized models only support inference and run on CPUs.
  301. GPU inference is not yet supported.
  302. Args:
  303. weights (:class:`~torchvision.models.quantization.ShuffleNet_V2_X1_5_QuantizedWeights` or :class:`~torchvision.models.ShuffleNet_V2_X1_5_Weights`, optional): The
  304. pretrained weights for the model. See
  305. :class:`~torchvision.models.quantization.ShuffleNet_V2_X1_5_QuantizedWeights` below for
  306. more details, and possible values. By default, no pre-trained
  307. weights are used.
  308. progress (bool, optional): If True, displays a progress bar of the download to stderr.
  309. Default is True.
  310. quantize (bool, optional): If True, return a quantized version of the model.
  311. Default is False.
  312. **kwargs: parameters passed to the ``torchvision.models.quantization.ShuffleNet_V2_X1_5_QuantizedWeights``
  313. base class. Please refer to the `source code
  314. <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/shufflenetv2.py>`_
  315. for more details about this class.
  316. .. autoclass:: torchvision.models.quantization.ShuffleNet_V2_X1_5_QuantizedWeights
  317. :members:
  318. .. autoclass:: torchvision.models.ShuffleNet_V2_X1_5_Weights
  319. :members:
  320. :noindex:
  321. """
  322. weights = (ShuffleNet_V2_X1_5_QuantizedWeights if quantize else ShuffleNet_V2_X1_5_Weights).verify(weights)
  323. return _shufflenetv2(
  324. [4, 8, 4], [24, 176, 352, 704, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
  325. )
  326. @register_model(name="quantized_shufflenet_v2_x2_0")
  327. @handle_legacy_interface(
  328. weights=(
  329. "pretrained",
  330. lambda kwargs: ShuffleNet_V2_X2_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1
  331. if kwargs.get("quantize", False)
  332. else ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1,
  333. )
  334. )
  335. def shufflenet_v2_x2_0(
  336. *,
  337. weights: Optional[Union[ShuffleNet_V2_X2_0_QuantizedWeights, ShuffleNet_V2_X2_0_Weights]] = None,
  338. progress: bool = True,
  339. quantize: bool = False,
  340. **kwargs: Any,
  341. ) -> QuantizableShuffleNetV2:
  342. """
  343. Constructs a ShuffleNetV2 with 2.0x output channels, as described in
  344. `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design
  345. <https://arxiv.org/abs/1807.11164>`__.
  346. .. note::
  347. Note that ``quantize = True`` returns a quantized model with 8 bit
  348. weights. Quantized models only support inference and run on CPUs.
  349. GPU inference is not yet supported.
  350. Args:
  351. weights (:class:`~torchvision.models.quantization.ShuffleNet_V2_X2_0_QuantizedWeights` or :class:`~torchvision.models.ShuffleNet_V2_X2_0_Weights`, optional): The
  352. pretrained weights for the model. See
  353. :class:`~torchvision.models.quantization.ShuffleNet_V2_X2_0_QuantizedWeights` below for
  354. more details, and possible values. By default, no pre-trained
  355. weights are used.
  356. progress (bool, optional): If True, displays a progress bar of the download to stderr.
  357. Default is True.
  358. quantize (bool, optional): If True, return a quantized version of the model.
  359. Default is False.
  360. **kwargs: parameters passed to the ``torchvision.models.quantization.ShuffleNet_V2_X2_0_QuantizedWeights``
  361. base class. Please refer to the `source code
  362. <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/shufflenetv2.py>`_
  363. for more details about this class.
  364. .. autoclass:: torchvision.models.quantization.ShuffleNet_V2_X2_0_QuantizedWeights
  365. :members:
  366. .. autoclass:: torchvision.models.ShuffleNet_V2_X2_0_Weights
  367. :members:
  368. :noindex:
  369. """
  370. weights = (ShuffleNet_V2_X2_0_QuantizedWeights if quantize else ShuffleNet_V2_X2_0_Weights).verify(weights)
  371. return _shufflenetv2(
  372. [4, 8, 4], [24, 244, 488, 976, 2048], weights=weights, progress=progress, quantize=quantize, **kwargs
  373. )