googlenet.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. import warnings
  2. from functools import partial
  3. from typing import Any, Optional, Union
  4. import torch
  5. import torch.nn as nn
  6. from torch import Tensor
  7. from torch.nn import functional as F
  8. from ...transforms._presets import ImageClassification
  9. from .._api import register_model, Weights, WeightsEnum
  10. from .._meta import _IMAGENET_CATEGORIES
  11. from .._utils import _ovewrite_named_param, handle_legacy_interface
  12. from ..googlenet import BasicConv2d, GoogLeNet, GoogLeNet_Weights, GoogLeNetOutputs, Inception, InceptionAux
  13. from .utils import _fuse_modules, _replace_relu, quantize_model
  14. __all__ = [
  15. "QuantizableGoogLeNet",
  16. "GoogLeNet_QuantizedWeights",
  17. "googlenet",
  18. ]
  19. class QuantizableBasicConv2d(BasicConv2d):
  20. def __init__(self, *args: Any, **kwargs: Any) -> None:
  21. super().__init__(*args, **kwargs)
  22. self.relu = nn.ReLU()
  23. def forward(self, x: Tensor) -> Tensor:
  24. x = self.conv(x)
  25. x = self.bn(x)
  26. x = self.relu(x)
  27. return x
  28. def fuse_model(self, is_qat: Optional[bool] = None) -> None:
  29. _fuse_modules(self, ["conv", "bn", "relu"], is_qat, inplace=True)
  30. class QuantizableInception(Inception):
  31. def __init__(self, *args: Any, **kwargs: Any) -> None:
  32. super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs) # type: ignore[misc]
  33. self.cat = nn.quantized.FloatFunctional()
  34. def forward(self, x: Tensor) -> Tensor:
  35. outputs = self._forward(x)
  36. return self.cat.cat(outputs, 1)
  37. class QuantizableInceptionAux(InceptionAux):
  38. # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
  39. def __init__(self, *args: Any, **kwargs: Any) -> None:
  40. super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs) # type: ignore[misc]
  41. self.relu = nn.ReLU()
  42. def forward(self, x: Tensor) -> Tensor:
  43. # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
  44. x = F.adaptive_avg_pool2d(x, (4, 4))
  45. # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
  46. x = self.conv(x)
  47. # N x 128 x 4 x 4
  48. x = torch.flatten(x, 1)
  49. # N x 2048
  50. x = self.relu(self.fc1(x))
  51. # N x 1024
  52. x = self.dropout(x)
  53. # N x 1024
  54. x = self.fc2(x)
  55. # N x 1000 (num_classes)
  56. return x
class QuantizableGoogLeNet(GoogLeNet):
    """GoogLeNet whose building blocks are the quantizable variants defined in this module.

    ``QuantStub``/``DeQuantStub`` bracket the body of the network in ``forward``, marking
    where tensors enter and leave the quantized domain during eager-mode quantization.
    """

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Swap in the quantizable conv / inception / aux-head classes via the base
        # class's ``blocks`` argument.
        super().__init__(  # type: ignore[misc]
            *args, blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux], **kwargs
        )
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> GoogLeNetOutputs:
        # Input normalisation (base-class ``_transform_input``) runs in float, before quantization.
        x = self._transform_input(x)
        x = self.quant(x)
        x, aux1, aux2 = self._forward(x)
        # Only the main logits pass through the DeQuantStub here; aux1/aux2 are returned as
        # produced by ``_forward``.
        x = self.dequant(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            # TorchScript needs a single static return type, so the scripted model always
            # returns the full GoogLeNetOutputs tuple and warns when aux outputs are not meaningful.
            if not aux_defined:
                warnings.warn("Scripted QuantizableGoogleNet always returns GoogleNetOutputs Tuple")
            # Note the argument order (x, aux2, aux1): it differs from ``_forward``'s
            # (x, aux1, aux2) — presumably matching GoogLeNetOutputs' field order; confirm
            # against its definition.
            return GoogLeNetOutputs(x, aux2, aux1)
        else:
            # Eager mode delegates to the base class; NOTE(review): presumably returns only the
            # logits when aux outputs are not defined — see GoogLeNet.eager_outputs.
            return self.eager_outputs(x, aux2, aux1)

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        r"""Fuse conv/bn/relu modules in googlenet model

        Calls ``fuse_model`` on every submodule whose exact type is
        ``QuantizableBasicConv2d``, fusing its ``conv`` + ``bn`` + ``relu`` triple to
        prepare for quantization. Model is modified in place. Note that this operation
        does not change numerics and the model after modification is in floating point.

        Args:
            is_qat: forwarded unchanged to each block's ``fuse_model``.
        """
        for m in self.modules():
            # Exact type match (not isinstance): subclasses of QuantizableBasicConv2d are skipped.
            if type(m) is QuantizableBasicConv2d:
                m.fuse_model(is_qat)
# Registry of pretrained quantized GoogLeNet weights. Each member bundles the checkpoint
# URL, the matching preprocessing transform and metadata consumed by the builder.
class GoogLeNet_QuantizedWeights(WeightsEnum):
    # ImageNet-1K weights quantized for the "fbgemm" backend.
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c81f6644.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            "num_params": 6624904,
            "min_size": (15, 15),
            # len(categories) is used by the builder as num_classes.
            "categories": _IMAGENET_CATEGORIES,
            # Read by the `googlenet` builder to choose the quantization backend.
            "backend": "fbgemm",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
            # The float weights these quantized weights were derived from.
            "unquantized": GoogLeNet_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.826,
                    "acc@5": 89.404,
                }
            },
            # NOTE(review): units presumably GFLOPs and MB respectively — confirm.
            "_ops": 1.498,
            "_file_size": 12.618,
            "_docs": """
These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized
weights listed below.
""",
        },
    )
    # Same value as IMAGENET1K_FBGEMM_V1; assuming WeightsEnum is an Enum, this makes
    # DEFAULT an alias of that member rather than a separate entry.
    DEFAULT = IMAGENET1K_FBGEMM_V1
  112. @register_model(name="quantized_googlenet")
  113. @handle_legacy_interface(
  114. weights=(
  115. "pretrained",
  116. lambda kwargs: GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1
  117. if kwargs.get("quantize", False)
  118. else GoogLeNet_Weights.IMAGENET1K_V1,
  119. )
  120. )
  121. def googlenet(
  122. *,
  123. weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None,
  124. progress: bool = True,
  125. quantize: bool = False,
  126. **kwargs: Any,
  127. ) -> QuantizableGoogLeNet:
  128. """GoogLeNet (Inception v1) model architecture from `Going Deeper with Convolutions <http://arxiv.org/abs/1409.4842>`__.
  129. .. note::
  130. Note that ``quantize = True`` returns a quantized model with 8 bit
  131. weights. Quantized models only support inference and run on CPUs.
  132. GPU inference is not yet supported.
  133. Args:
  134. weights (:class:`~torchvision.models.quantization.GoogLeNet_QuantizedWeights` or :class:`~torchvision.models.GoogLeNet_Weights`, optional): The
  135. pretrained weights for the model. See
  136. :class:`~torchvision.models.quantization.GoogLeNet_QuantizedWeights` below for
  137. more details, and possible values. By default, no pre-trained
  138. weights are used.
  139. progress (bool, optional): If True, displays a progress bar of the
  140. download to stderr. Default is True.
  141. quantize (bool, optional): If True, return a quantized version of the model. Default is False.
  142. **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableGoogLeNet``
  143. base class. Please refer to the `source code
  144. <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/googlenet.py>`_
  145. for more details about this class.
  146. .. autoclass:: torchvision.models.quantization.GoogLeNet_QuantizedWeights
  147. :members:
  148. .. autoclass:: torchvision.models.GoogLeNet_Weights
  149. :members:
  150. :noindex:
  151. """
  152. weights = (GoogLeNet_QuantizedWeights if quantize else GoogLeNet_Weights).verify(weights)
  153. original_aux_logits = kwargs.get("aux_logits", False)
  154. if weights is not None:
  155. if "transform_input" not in kwargs:
  156. _ovewrite_named_param(kwargs, "transform_input", True)
  157. _ovewrite_named_param(kwargs, "aux_logits", True)
  158. _ovewrite_named_param(kwargs, "init_weights", False)
  159. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  160. if "backend" in weights.meta:
  161. _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
  162. backend = kwargs.pop("backend", "fbgemm")
  163. model = QuantizableGoogLeNet(**kwargs)
  164. _replace_relu(model)
  165. if quantize:
  166. quantize_model(model, backend)
  167. if weights is not None:
  168. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  169. if not original_aux_logits:
  170. model.aux_logits = False
  171. model.aux1 = None # type: ignore[assignment]
  172. model.aux2 = None # type: ignore[assignment]
  173. else:
  174. warnings.warn(
  175. "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
  176. )
  177. return model