# mobilenetv3.py

from functools import partial
from typing import Any, List, Optional, Union

import torch
from torch import nn, Tensor
from torch.ao.quantization import DeQuantStub, QuantStub

from ...ops.misc import Conv2dNormActivation, SqueezeExcitation
from ...transforms._presets import ImageClassification
from .._api import register_model, Weights, WeightsEnum
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface
from ..mobilenetv3 import (
    _mobilenet_v3_conf,
    InvertedResidual,
    InvertedResidualConfig,
    MobileNet_V3_Large_Weights,
    MobileNetV3,
)
from .utils import _fuse_modules, _replace_relu


__all__ = [
    "QuantizableMobileNetV3",
    "MobileNet_V3_Large_QuantizedWeights",
    "mobilenet_v3_large",
]


class QuantizableSqueezeExcitation(SqueezeExcitation):
    _version = 2

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        kwargs["scale_activation"] = nn.Hardsigmoid
        super().__init__(*args, **kwargs)
        self.skip_mul = nn.quantized.FloatFunctional()

    def forward(self, input: Tensor) -> Tensor:
        return self.skip_mul.mul(self._scale(input), input)

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        _fuse_modules(self, ["fc1", "activation"], is_qat, inplace=True)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if hasattr(self, "qconfig") and (version is None or version < 2):
            default_state_dict = {
                "scale_activation.activation_post_process.scale": torch.tensor([1.0]),
                "scale_activation.activation_post_process.activation_post_process.scale": torch.tensor([1.0]),
                "scale_activation.activation_post_process.zero_point": torch.tensor([0], dtype=torch.int32),
                "scale_activation.activation_post_process.activation_post_process.zero_point": torch.tensor(
                    [0], dtype=torch.int32
                ),
                "scale_activation.activation_post_process.fake_quant_enabled": torch.tensor([1]),
                "scale_activation.activation_post_process.observer_enabled": torch.tensor([1]),
            }
            for k, v in default_state_dict.items():
                full_key = prefix + k
                if full_key not in state_dict:
                    state_dict[full_key] = v

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
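

# Note on the _load_from_state_dict override above: checkpoints saved before
# _version = 2 predate the observer entries that eager-mode quantization attaches
# to scale_activation. When the module carries a qconfig, the missing observer
# keys are filled with neutral defaults (scale=1.0, zero_point=0, fake-quant and
# observer enabled) so that strict state-dict loading still succeeds on prepared
# models.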


class QuantizableInvertedResidual(InvertedResidual):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, se_layer=QuantizableSqueezeExcitation, **kwargs)  # type: ignore[misc]
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        if self.use_res_connect:
            return self.skip_add.add(x, self.block(x))
        else:
            return self.block(x)
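

# Why FloatFunctional: in eager-mode quantization, a bare `x + y` or `x * y` on
# Tensors has no module to hang an observer on, so prepare()/convert() cannot
# learn an output scale for it. Wrapping the op in nn.quantized.FloatFunctional
# turns it into a module that convert() swaps for its quantized counterpart.
# A minimal sketch (hypothetical module, not part of this file):
#
#     class Residual(nn.Module):
#         def __init__(self, block: nn.Module) -> None:
#             super().__init__()
#             self.block = block
#             self.skip_add = nn.quantized.FloatFunctional()
#
#         def forward(self, x: Tensor) -> Tensor:
#             # observed addition; a plain `x + self.block(x)` would not quantize
#             return self.skip_add.add(x, self.block(x))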


class QuantizableMobileNetV3(MobileNetV3):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        MobileNet V3 main class

        Args:
            Inherits args from floating point MobileNetV3
        """
        super().__init__(*args, **kwargs)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        x = self.quant(x)
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        for m in self.modules():
            if type(m) is Conv2dNormActivation:
                modules_to_fuse = ["0", "1"]
                if len(m) == 3 and type(m[2]) is nn.ReLU:
                    modules_to_fuse.append("2")
                _fuse_modules(m, modules_to_fuse, is_qat, inplace=True)
            elif type(m) is QuantizableSqueezeExcitation:
                m.fuse_model(is_qat)
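

# Fusion note for fuse_model above: Conv2dNormActivation is an nn.Sequential of
# (conv, norm[, activation]), so the string indices "0", "1", "2" address its
# children. Fusing Conv+BN(+ReLU) lets convert() emit a single quantized conv.
# Roughly what the _fuse_modules helper dispatches to (hedged sketch; the real
# helper lives in .utils):
#
#     if is_qat:
#         torch.ao.quantization.fuse_modules_qat(m, modules_to_fuse, inplace=True)
#     else:
#         torch.ao.quantization.fuse_modules(m, modules_to_fuse, inplace=True)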


def _mobilenet_v3_model(
    inverted_residual_setting: List[InvertedResidualConfig],
    last_channel: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        if "backend" in weights.meta:
            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
    backend = kwargs.pop("backend", "qnnpack")

    model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
    _replace_relu(model)

    if quantize:
        # Instead of quantizing the model and then loading the quantized weights we take a different approach.
        # We prepare the QAT model, load the QAT weights from training and then convert it.
        # This is done to avoid extremely low accuracies observed on the specific model. This is rather a workaround
        # for an unresolved bug on the eager quantization API detailed at: https://github.com/pytorch/vision/issues/5890
        model.fuse_model(is_qat=True)
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
        torch.ao.quantization.prepare_qat(model, inplace=True)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    if quantize:
        torch.ao.quantization.convert(model, inplace=True)
        model.eval()

    return model
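

# For contrast, the standard eager post-training quantization flow looks like
# the sketch below; this builder deliberately uses the QAT prepare/convert path
# above because of the accuracy bug linked in the comment:
#
#     model.fuse_model(is_qat=False)
#     model.qconfig = torch.ao.quantization.get_default_qconfig(backend)
#     torch.ao.quantization.prepare(model, inplace=True)
#     # ... run representative calibration batches through the model ...
#     torch.ao.quantization.convert(model, inplace=True)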


class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
    IMAGENET1K_QNNPACK_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            "num_params": 5483032,
            "min_size": (1, 1),
            "categories": _IMAGENET_CATEGORIES,
            "backend": "qnnpack",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3",
            "unquantized": MobileNet_V3_Large_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 73.004,
                    "acc@5": 90.858,
                }
            },
            "_ops": 0.217,
            "_file_size": 21.554,
            "_docs": """
                These weights were produced by doing Quantization Aware Training (eager mode) on top of the unquantized
                weights listed below.
            """,
        },
    )
    DEFAULT = IMAGENET1K_QNNPACK_V1
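

# Meta conventions (as used across torchvision weights enums): "_ops" is the
# approximate GFLOPs of one forward pass and "_file_size" is the checkpoint size
# in MB. The checkpoint here is float-sized because, per the workaround in
# _mobilenet_v3_model, it stores QAT (fake-quantized) weights that are converted
# to int8 only after loading.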


@register_model(name="quantized_mobilenet_v3_large")
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1
        if kwargs.get("quantize", False)
        else MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    )
)
def mobilenet_v3_large(
    *,
    weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    """
    MobileNetV3 (Large) model from
    `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_.

    .. note::
        Note that ``quantize = True`` returns a quantized model with 8 bit
        weights. Quantized models only support inference and run on CPUs.
        GPU inference is not yet supported.

    Args:
        weights (:class:`~torchvision.models.quantization.MobileNet_V3_Large_QuantizedWeights` or :class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.quantization.MobileNet_V3_Large_QuantizedWeights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool): If True, displays a progress bar of the
            download to stderr. Default is True.
        quantize (bool): If True, return a quantized version of the model. Default is False.
        **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableMobileNetV3``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/mobilenetv3.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.quantization.MobileNet_V3_Large_QuantizedWeights
        :members:

    .. autoclass:: torchvision.models.MobileNet_V3_Large_Weights
        :members:
        :noindex:
    """
    weights = (MobileNet_V3_Large_QuantizedWeights if quantize else MobileNet_V3_Large_Weights).verify(weights)

    inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
    return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs)
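

# Minimal usage sketch (illustrative; assumes this module is consumed as
# torchvision.models.quantization and that the qnnpack engine is available):
#
#     import torch
#     from torchvision.models.quantization import (
#         mobilenet_v3_large,
#         MobileNet_V3_Large_QuantizedWeights,
#     )
#
#     torch.backends.quantized.engine = "qnnpack"  # match the weights' backend
#     weights = MobileNet_V3_Large_QuantizedWeights.DEFAULT
#     model = mobilenet_v3_large(weights=weights, quantize=True)
#     model.eval()
#     with torch.inference_mode():
#         batch = torch.rand(1, 3, 224, 224)  # float input; QuantStub quantizes it
#         logits = model(batch)
#     print(weights.meta["categories"][int(logits.argmax())])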