resnet.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. from functools import partial
  2. from typing import Any, List, Optional, Type, Union
  3. import torch
  4. import torch.nn as nn
  5. from torch import Tensor
  6. from torchvision.models.resnet import (
  7. BasicBlock,
  8. Bottleneck,
  9. ResNet,
  10. ResNet18_Weights,
  11. ResNet50_Weights,
  12. ResNeXt101_32X8D_Weights,
  13. ResNeXt101_64X4D_Weights,
  14. )
  15. from ...transforms._presets import ImageClassification
  16. from .._api import register_model, Weights, WeightsEnum
  17. from .._meta import _IMAGENET_CATEGORIES
  18. from .._utils import _ovewrite_named_param, handle_legacy_interface
  19. from .utils import _fuse_modules, _replace_relu, quantize_model
# Public API of this module: the quantizable model class, the quantized weight
# enums, and the model builder functions defined below.
__all__ = [
    "QuantizableResNet",
    "ResNet18_QuantizedWeights",
    "ResNet50_QuantizedWeights",
    "ResNeXt101_32X8D_QuantizedWeights",
    "ResNeXt101_64X4D_QuantizedWeights",
    "resnet18",
    "resnet50",
    "resnext101_32x8d",
    "resnext101_64x4d",
]
  31. class QuantizableBasicBlock(BasicBlock):
  32. def __init__(self, *args: Any, **kwargs: Any) -> None:
  33. super().__init__(*args, **kwargs)
  34. self.add_relu = torch.nn.quantized.FloatFunctional()
  35. def forward(self, x: Tensor) -> Tensor:
  36. identity = x
  37. out = self.conv1(x)
  38. out = self.bn1(out)
  39. out = self.relu(out)
  40. out = self.conv2(out)
  41. out = self.bn2(out)
  42. if self.downsample is not None:
  43. identity = self.downsample(x)
  44. out = self.add_relu.add_relu(out, identity)
  45. return out
  46. def fuse_model(self, is_qat: Optional[bool] = None) -> None:
  47. _fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], is_qat, inplace=True)
  48. if self.downsample:
  49. _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
  50. class QuantizableBottleneck(Bottleneck):
  51. def __init__(self, *args: Any, **kwargs: Any) -> None:
  52. super().__init__(*args, **kwargs)
  53. self.skip_add_relu = nn.quantized.FloatFunctional()
  54. self.relu1 = nn.ReLU(inplace=False)
  55. self.relu2 = nn.ReLU(inplace=False)
  56. def forward(self, x: Tensor) -> Tensor:
  57. identity = x
  58. out = self.conv1(x)
  59. out = self.bn1(out)
  60. out = self.relu1(out)
  61. out = self.conv2(out)
  62. out = self.bn2(out)
  63. out = self.relu2(out)
  64. out = self.conv3(out)
  65. out = self.bn3(out)
  66. if self.downsample is not None:
  67. identity = self.downsample(x)
  68. out = self.skip_add_relu.add_relu(out, identity)
  69. return out
  70. def fuse_model(self, is_qat: Optional[bool] = None) -> None:
  71. _fuse_modules(
  72. self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], is_qat, inplace=True
  73. )
  74. if self.downsample:
  75. _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
  76. class QuantizableResNet(ResNet):
  77. def __init__(self, *args: Any, **kwargs: Any) -> None:
  78. super().__init__(*args, **kwargs)
  79. self.quant = torch.ao.quantization.QuantStub()
  80. self.dequant = torch.ao.quantization.DeQuantStub()
  81. def forward(self, x: Tensor) -> Tensor:
  82. x = self.quant(x)
  83. # Ensure scriptability
  84. # super(QuantizableResNet,self).forward(x)
  85. # is not scriptable
  86. x = self._forward_impl(x)
  87. x = self.dequant(x)
  88. return x
  89. def fuse_model(self, is_qat: Optional[bool] = None) -> None:
  90. r"""Fuse conv/bn/relu modules in resnet models
  91. Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization.
  92. Model is modified in place. Note that this operation does not change numerics
  93. and the model after modification is in floating point
  94. """
  95. _fuse_modules(self, ["conv1", "bn1", "relu"], is_qat, inplace=True)
  96. for m in self.modules():
  97. if type(m) is QuantizableBottleneck or type(m) is QuantizableBasicBlock:
  98. m.fuse_model(is_qat)
  99. def _resnet(
  100. block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
  101. layers: List[int],
  102. weights: Optional[WeightsEnum],
  103. progress: bool,
  104. quantize: bool,
  105. **kwargs: Any,
  106. ) -> QuantizableResNet:
  107. if weights is not None:
  108. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  109. if "backend" in weights.meta:
  110. _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
  111. backend = kwargs.pop("backend", "fbgemm")
  112. model = QuantizableResNet(block, layers, **kwargs)
  113. _replace_relu(model)
  114. if quantize:
  115. quantize_model(model, backend)
  116. if weights is not None:
  117. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  118. return model
# Metadata shared by every quantized weights entry below. Each entry spreads it
# with ``**_COMMON_META`` and may override individual keys (later keys win).
_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
    # Read by _resnet() to pick the quantization backend when these weights are used.
    "backend": "fbgemm",
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
    # NOTE(review): presumably rendered into the generated weights documentation.
    "_docs": """
These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized
weights listed below.
""",
}
class ResNet18_QuantizedWeights(WeightsEnum):
    # Quantized ResNet-18 checkpoints. The URL, metrics and sizes below are pinned
    # metadata for a published artifact and must be kept in sync with it.
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11689512,
            # Float weights these were post-training quantized from.
            "unquantized": ResNet18_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.494,
                    "acc@5": 88.882,
                }
            },
            "_ops": 1.814,
            "_file_size": 11.238,
        },
    )
    # Enum alias: DEFAULT is the same member as IMAGENET1K_FBGEMM_V1.
    DEFAULT = IMAGENET1K_FBGEMM_V1
class ResNet50_QuantizedWeights(WeightsEnum):
    # Quantized ResNet-50 checkpoints. The URLs, metrics and sizes below are pinned
    # metadata for published artifacts and must be kept in sync with them.
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            # Float weights these were post-training quantized from.
            "unquantized": ResNet50_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.920,
                    "acc@5": 92.814,
                }
            },
            "_ops": 4.089,
            "_file_size": 24.759,
        },
    )
    # V2 is quantized from the float IMAGENET1K_V2 weights and, unlike V1, is
    # evaluated with resize_size=232.
    IMAGENET1K_FBGEMM_V2 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            "unquantized": ResNet50_Weights.IMAGENET1K_V2,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.282,
                    "acc@5": 94.976,
                }
            },
            "_ops": 4.089,
            "_file_size": 24.953,
        },
    )
    # Enum alias: DEFAULT is the same member as IMAGENET1K_FBGEMM_V2.
    DEFAULT = IMAGENET1K_FBGEMM_V2
class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
    # Quantized ResNeXt-101 32x8d checkpoints. The URLs, metrics and sizes below are
    # pinned metadata for published artifacts and must be kept in sync with them.
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            # Float weights these were post-training quantized from.
            "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.986,
                    "acc@5": 94.480,
                }
            },
            "_ops": 16.414,
            "_file_size": 86.034,
        },
    )
    # V2 is quantized from the float IMAGENET1K_V2 weights and, unlike V1, is
    # evaluated with resize_size=232.
    IMAGENET1K_FBGEMM_V2 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.574,
                    "acc@5": 96.132,
                }
            },
            "_ops": 16.414,
            "_file_size": 86.645,
        },
    )
    # Enum alias: DEFAULT is the same member as IMAGENET1K_FBGEMM_V2.
    DEFAULT = IMAGENET1K_FBGEMM_V2
class ResNeXt101_64X4D_QuantizedWeights(WeightsEnum):
    # Quantized ResNeXt-101 64x4d checkpoint. The URL, metrics and sizes below are
    # pinned metadata for a published artifact and must be kept in sync with it.
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_64x4d_fbgemm-605a1cb3.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 83455272,
            # Overrides the "recipe" inherited from _COMMON_META (later keys win).
            "recipe": "https://github.com/pytorch/vision/pull/5935",
            # Float weights these were post-training quantized from.
            "unquantized": ResNeXt101_64X4D_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.898,
                    "acc@5": 96.326,
                }
            },
            "_ops": 15.46,
            "_file_size": 81.556,
        },
    )
    # Enum alias: DEFAULT is the same member as IMAGENET1K_FBGEMM_V1.
    DEFAULT = IMAGENET1K_FBGEMM_V1
  240. @register_model(name="quantized_resnet18")
  241. @handle_legacy_interface(
  242. weights=(
  243. "pretrained",
  244. lambda kwargs: ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1
  245. if kwargs.get("quantize", False)
  246. else ResNet18_Weights.IMAGENET1K_V1,
  247. )
  248. )
  249. def resnet18(
  250. *,
  251. weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None,
  252. progress: bool = True,
  253. quantize: bool = False,
  254. **kwargs: Any,
  255. ) -> QuantizableResNet:
  256. """ResNet-18 model from
  257. `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_
  258. .. note::
  259. Note that ``quantize = True`` returns a quantized model with 8 bit
  260. weights. Quantized models only support inference and run on CPUs.
  261. GPU inference is not yet supported.
  262. Args:
  263. weights (:class:`~torchvision.models.quantization.ResNet18_QuantizedWeights` or :class:`~torchvision.models.ResNet18_Weights`, optional): The
  264. pretrained weights for the model. See
  265. :class:`~torchvision.models.quantization.ResNet18_QuantizedWeights` below for
  266. more details, and possible values. By default, no pre-trained
  267. weights are used.
  268. progress (bool, optional): If True, displays a progress bar of the
  269. download to stderr. Default is True.
  270. quantize (bool, optional): If True, return a quantized version of the model. Default is False.
  271. **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
  272. base class. Please refer to the `source code
  273. <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
  274. for more details about this class.
  275. .. autoclass:: torchvision.models.quantization.ResNet18_QuantizedWeights
  276. :members:
  277. .. autoclass:: torchvision.models.ResNet18_Weights
  278. :members:
  279. :noindex:
  280. """
  281. weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights)
  282. return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)
  283. @register_model(name="quantized_resnet50")
  284. @handle_legacy_interface(
  285. weights=(
  286. "pretrained",
  287. lambda kwargs: ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1
  288. if kwargs.get("quantize", False)
  289. else ResNet50_Weights.IMAGENET1K_V1,
  290. )
  291. )
  292. def resnet50(
  293. *,
  294. weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None,
  295. progress: bool = True,
  296. quantize: bool = False,
  297. **kwargs: Any,
  298. ) -> QuantizableResNet:
  299. """ResNet-50 model from
  300. `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_
  301. .. note::
  302. Note that ``quantize = True`` returns a quantized model with 8 bit
  303. weights. Quantized models only support inference and run on CPUs.
  304. GPU inference is not yet supported.
  305. Args:
  306. weights (:class:`~torchvision.models.quantization.ResNet50_QuantizedWeights` or :class:`~torchvision.models.ResNet50_Weights`, optional): The
  307. pretrained weights for the model. See
  308. :class:`~torchvision.models.quantization.ResNet50_QuantizedWeights` below for
  309. more details, and possible values. By default, no pre-trained
  310. weights are used.
  311. progress (bool, optional): If True, displays a progress bar of the
  312. download to stderr. Default is True.
  313. quantize (bool, optional): If True, return a quantized version of the model. Default is False.
  314. **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
  315. base class. Please refer to the `source code
  316. <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
  317. for more details about this class.
  318. .. autoclass:: torchvision.models.quantization.ResNet50_QuantizedWeights
  319. :members:
  320. .. autoclass:: torchvision.models.ResNet50_Weights
  321. :members:
  322. :noindex:
  323. """
  324. weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights)
  325. return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)
  326. @register_model(name="quantized_resnext101_32x8d")
  327. @handle_legacy_interface(
  328. weights=(
  329. "pretrained",
  330. lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1
  331. if kwargs.get("quantize", False)
  332. else ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
  333. )
  334. )
  335. def resnext101_32x8d(
  336. *,
  337. weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None,
  338. progress: bool = True,
  339. quantize: bool = False,
  340. **kwargs: Any,
  341. ) -> QuantizableResNet:
  342. """ResNeXt-101 32x8d model from
  343. `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_
  344. .. note::
  345. Note that ``quantize = True`` returns a quantized model with 8 bit
  346. weights. Quantized models only support inference and run on CPUs.
  347. GPU inference is not yet supported.
  348. Args:
  349. weights (:class:`~torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights` or :class:`~torchvision.models.ResNeXt101_32X8D_Weights`, optional): The
  350. pretrained weights for the model. See
  351. :class:`~torchvision.models.quantization.ResNet101_32X8D_QuantizedWeights` below for
  352. more details, and possible values. By default, no pre-trained
  353. weights are used.
  354. progress (bool, optional): If True, displays a progress bar of the
  355. download to stderr. Default is True.
  356. quantize (bool, optional): If True, return a quantized version of the model. Default is False.
  357. **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
  358. base class. Please refer to the `source code
  359. <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
  360. for more details about this class.
  361. .. autoclass:: torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights
  362. :members:
  363. .. autoclass:: torchvision.models.ResNeXt101_32X8D_Weights
  364. :members:
  365. :noindex:
  366. """
  367. weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights)
  368. _ovewrite_named_param(kwargs, "groups", 32)
  369. _ovewrite_named_param(kwargs, "width_per_group", 8)
  370. return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)
  371. @register_model(name="quantized_resnext101_64x4d")
  372. @handle_legacy_interface(
  373. weights=(
  374. "pretrained",
  375. lambda kwargs: ResNeXt101_64X4D_QuantizedWeights.IMAGENET1K_FBGEMM_V1
  376. if kwargs.get("quantize", False)
  377. else ResNeXt101_64X4D_Weights.IMAGENET1K_V1,
  378. )
  379. )
  380. def resnext101_64x4d(
  381. *,
  382. weights: Optional[Union[ResNeXt101_64X4D_QuantizedWeights, ResNeXt101_64X4D_Weights]] = None,
  383. progress: bool = True,
  384. quantize: bool = False,
  385. **kwargs: Any,
  386. ) -> QuantizableResNet:
  387. """ResNeXt-101 64x4d model from
  388. `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_
  389. .. note::
  390. Note that ``quantize = True`` returns a quantized model with 8 bit
  391. weights. Quantized models only support inference and run on CPUs.
  392. GPU inference is not yet supported.
  393. Args:
  394. weights (:class:`~torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights` or :class:`~torchvision.models.ResNeXt101_64X4D_Weights`, optional): The
  395. pretrained weights for the model. See
  396. :class:`~torchvision.models.quantization.ResNet101_64X4D_QuantizedWeights` below for
  397. more details, and possible values. By default, no pre-trained
  398. weights are used.
  399. progress (bool, optional): If True, displays a progress bar of the
  400. download to stderr. Default is True.
  401. quantize (bool, optional): If True, return a quantized version of the model. Default is False.
  402. **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
  403. base class. Please refer to the `source code
  404. <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
  405. for more details about this class.
  406. .. autoclass:: torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights
  407. :members:
  408. .. autoclass:: torchvision.models.ResNeXt101_64X4D_Weights
  409. :members:
  410. :noindex:
  411. """
  412. weights = (ResNeXt101_64X4D_QuantizedWeights if quantize else ResNeXt101_64X4D_Weights).verify(weights)
  413. _ovewrite_named_param(kwargs, "groups", 64)
  414. _ovewrite_named_param(kwargs, "width_per_group", 4)
  415. return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)