googlenet.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. import warnings
  2. from collections import namedtuple
  3. from functools import partial
  4. from typing import Any, Callable, List, Optional, Tuple
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from torch import Tensor
  9. from ..transforms._presets import ImageClassification
  10. from ..utils import _log_api_usage_once
  11. from ._api import register_model, Weights, WeightsEnum
  12. from ._meta import _IMAGENET_CATEGORIES
  13. from ._utils import _ovewrite_named_param, handle_legacy_interface
__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"]

# Output container for GoogLeNet: the main classifier logits plus the outputs of the
# two auxiliary classifier heads. Field order is (logits, aux_logits2, aux_logits1).
GoogLeNetOutputs = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"])
# Explicit per-field types. The auxiliary fields are Optional because the aux heads may be
# absent or not evaluated. NOTE(review): presumably required so TorchScript can type this
# namedtuple -- confirm before changing.
GoogLeNetOutputs.__annotations__ = {"logits": Tensor, "aux_logits2": Optional[Tensor], "aux_logits1": Optional[Tensor]}

# TorchScript annotation failed when the private name was defined with its own
# `_GoogleNetOutputs = namedtuple(...)` call, so it is kept only as an alias of the
# public class, for backwards compatibility.
_GoogLeNetOutputs = GoogLeNetOutputs
  20. class GoogLeNet(nn.Module):
  21. __constants__ = ["aux_logits", "transform_input"]
  22. def __init__(
  23. self,
  24. num_classes: int = 1000,
  25. aux_logits: bool = True,
  26. transform_input: bool = False,
  27. init_weights: Optional[bool] = None,
  28. blocks: Optional[List[Callable[..., nn.Module]]] = None,
  29. dropout: float = 0.2,
  30. dropout_aux: float = 0.7,
  31. ) -> None:
  32. super().__init__()
  33. _log_api_usage_once(self)
  34. if blocks is None:
  35. blocks = [BasicConv2d, Inception, InceptionAux]
  36. if init_weights is None:
  37. warnings.warn(
  38. "The default weight initialization of GoogleNet will be changed in future releases of "
  39. "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
  40. " due to scipy/scipy#11299), please set init_weights=True.",
  41. FutureWarning,
  42. )
  43. init_weights = True
  44. if len(blocks) != 3:
  45. raise ValueError(f"blocks length should be 3 instead of {len(blocks)}")
  46. conv_block = blocks[0]
  47. inception_block = blocks[1]
  48. inception_aux_block = blocks[2]
  49. self.aux_logits = aux_logits
  50. self.transform_input = transform_input
  51. self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3)
  52. self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
  53. self.conv2 = conv_block(64, 64, kernel_size=1)
  54. self.conv3 = conv_block(64, 192, kernel_size=3, padding=1)
  55. self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
  56. self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32)
  57. self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64)
  58. self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
  59. self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64)
  60. self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64)
  61. self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64)
  62. self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64)
  63. self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128)
  64. self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  65. self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128)
  66. self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128)
  67. if aux_logits:
  68. self.aux1 = inception_aux_block(512, num_classes, dropout=dropout_aux)
  69. self.aux2 = inception_aux_block(528, num_classes, dropout=dropout_aux)
  70. else:
  71. self.aux1 = None # type: ignore[assignment]
  72. self.aux2 = None # type: ignore[assignment]
  73. self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
  74. self.dropout = nn.Dropout(p=dropout)
  75. self.fc = nn.Linear(1024, num_classes)
  76. if init_weights:
  77. for m in self.modules():
  78. if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
  79. torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2)
  80. elif isinstance(m, nn.BatchNorm2d):
  81. nn.init.constant_(m.weight, 1)
  82. nn.init.constant_(m.bias, 0)
  83. def _transform_input(self, x: Tensor) -> Tensor:
  84. if self.transform_input:
  85. x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
  86. x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
  87. x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
  88. x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
  89. return x
  90. def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
  91. # N x 3 x 224 x 224
  92. x = self.conv1(x)
  93. # N x 64 x 112 x 112
  94. x = self.maxpool1(x)
  95. # N x 64 x 56 x 56
  96. x = self.conv2(x)
  97. # N x 64 x 56 x 56
  98. x = self.conv3(x)
  99. # N x 192 x 56 x 56
  100. x = self.maxpool2(x)
  101. # N x 192 x 28 x 28
  102. x = self.inception3a(x)
  103. # N x 256 x 28 x 28
  104. x = self.inception3b(x)
  105. # N x 480 x 28 x 28
  106. x = self.maxpool3(x)
  107. # N x 480 x 14 x 14
  108. x = self.inception4a(x)
  109. # N x 512 x 14 x 14
  110. aux1: Optional[Tensor] = None
  111. if self.aux1 is not None:
  112. if self.training:
  113. aux1 = self.aux1(x)
  114. x = self.inception4b(x)
  115. # N x 512 x 14 x 14
  116. x = self.inception4c(x)
  117. # N x 512 x 14 x 14
  118. x = self.inception4d(x)
  119. # N x 528 x 14 x 14
  120. aux2: Optional[Tensor] = None
  121. if self.aux2 is not None:
  122. if self.training:
  123. aux2 = self.aux2(x)
  124. x = self.inception4e(x)
  125. # N x 832 x 14 x 14
  126. x = self.maxpool4(x)
  127. # N x 832 x 7 x 7
  128. x = self.inception5a(x)
  129. # N x 832 x 7 x 7
  130. x = self.inception5b(x)
  131. # N x 1024 x 7 x 7
  132. x = self.avgpool(x)
  133. # N x 1024 x 1 x 1
  134. x = torch.flatten(x, 1)
  135. # N x 1024
  136. x = self.dropout(x)
  137. x = self.fc(x)
  138. # N x 1000 (num_classes)
  139. return x, aux2, aux1
  140. @torch.jit.unused
  141. def eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> GoogLeNetOutputs:
  142. if self.training and self.aux_logits:
  143. return _GoogLeNetOutputs(x, aux2, aux1)
  144. else:
  145. return x # type: ignore[return-value]
  146. def forward(self, x: Tensor) -> GoogLeNetOutputs:
  147. x = self._transform_input(x)
  148. x, aux1, aux2 = self._forward(x)
  149. aux_defined = self.training and self.aux_logits
  150. if torch.jit.is_scripting():
  151. if not aux_defined:
  152. warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
  153. return GoogLeNetOutputs(x, aux2, aux1)
  154. else:
  155. return self.eager_outputs(x, aux2, aux1)
  156. class Inception(nn.Module):
  157. def __init__(
  158. self,
  159. in_channels: int,
  160. ch1x1: int,
  161. ch3x3red: int,
  162. ch3x3: int,
  163. ch5x5red: int,
  164. ch5x5: int,
  165. pool_proj: int,
  166. conv_block: Optional[Callable[..., nn.Module]] = None,
  167. ) -> None:
  168. super().__init__()
  169. if conv_block is None:
  170. conv_block = BasicConv2d
  171. self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1)
  172. self.branch2 = nn.Sequential(
  173. conv_block(in_channels, ch3x3red, kernel_size=1), conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1)
  174. )
  175. self.branch3 = nn.Sequential(
  176. conv_block(in_channels, ch5x5red, kernel_size=1),
  177. # Here, kernel_size=3 instead of kernel_size=5 is a known bug.
  178. # Please see https://github.com/pytorch/vision/issues/906 for details.
  179. conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1),
  180. )
  181. self.branch4 = nn.Sequential(
  182. nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
  183. conv_block(in_channels, pool_proj, kernel_size=1),
  184. )
  185. def _forward(self, x: Tensor) -> List[Tensor]:
  186. branch1 = self.branch1(x)
  187. branch2 = self.branch2(x)
  188. branch3 = self.branch3(x)
  189. branch4 = self.branch4(x)
  190. outputs = [branch1, branch2, branch3, branch4]
  191. return outputs
  192. def forward(self, x: Tensor) -> Tensor:
  193. outputs = self._forward(x)
  194. return torch.cat(outputs, 1)
  195. class InceptionAux(nn.Module):
  196. def __init__(
  197. self,
  198. in_channels: int,
  199. num_classes: int,
  200. conv_block: Optional[Callable[..., nn.Module]] = None,
  201. dropout: float = 0.7,
  202. ) -> None:
  203. super().__init__()
  204. if conv_block is None:
  205. conv_block = BasicConv2d
  206. self.conv = conv_block(in_channels, 128, kernel_size=1)
  207. self.fc1 = nn.Linear(2048, 1024)
  208. self.fc2 = nn.Linear(1024, num_classes)
  209. self.dropout = nn.Dropout(p=dropout)
  210. def forward(self, x: Tensor) -> Tensor:
  211. # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
  212. x = F.adaptive_avg_pool2d(x, (4, 4))
  213. # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
  214. x = self.conv(x)
  215. # N x 128 x 4 x 4
  216. x = torch.flatten(x, 1)
  217. # N x 2048
  218. x = F.relu(self.fc1(x), inplace=True)
  219. # N x 1024
  220. x = self.dropout(x)
  221. # N x 1024
  222. x = self.fc2(x)
  223. # N x 1000 (num_classes)
  224. return x
  225. class BasicConv2d(nn.Module):
  226. def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
  227. super().__init__()
  228. self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
  229. self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
  230. def forward(self, x: Tensor) -> Tensor:
  231. x = self.conv(x)
  232. x = self.bn(x)
  233. return F.relu(x, inplace=True)
class GoogLeNet_Weights(WeightsEnum):
    """Pretrained weights available for :func:`googlenet`.

    ``IMAGENET1K_V1`` holds ImageNet-1K weights that, per its ``_docs`` entry, were
    ported from the original paper. ``DEFAULT`` is an alias for it.
    """

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/googlenet-1378be20.pth",
        # Inference preprocessing preset with a 224-pixel crop.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            "num_params": 6624904,
            # NOTE(review): presumably the smallest (H, W) input the model supports -- confirm.
            "min_size": (15, 15),
            "categories": _IMAGENET_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#googlenet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.778,
                    "acc@5": 89.530,
                }
            },
            # NOTE(review): units are not stated here; likely GFLOPs and MB respectively -- confirm.
            "_ops": 1.498,
            "_file_size": 49.731,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    # Alias resolving to the recommended weights.
    DEFAULT = IMAGENET1K_V1
  255. @register_model()
  256. @handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.IMAGENET1K_V1))
  257. def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
  258. """GoogLeNet (Inception v1) model architecture from
  259. `Going Deeper with Convolutions <http://arxiv.org/abs/1409.4842>`_.
  260. Args:
  261. weights (:class:`~torchvision.models.GoogLeNet_Weights`, optional): The
  262. pretrained weights for the model. See
  263. :class:`~torchvision.models.GoogLeNet_Weights` below for
  264. more details, and possible values. By default, no pre-trained
  265. weights are used.
  266. progress (bool, optional): If True, displays a progress bar of the
  267. download to stderr. Default is True.
  268. **kwargs: parameters passed to the ``torchvision.models.GoogLeNet``
  269. base class. Please refer to the `source code
  270. <https://github.com/pytorch/vision/blob/main/torchvision/models/googlenet.py>`_
  271. for more details about this class.
  272. .. autoclass:: torchvision.models.GoogLeNet_Weights
  273. :members:
  274. """
  275. weights = GoogLeNet_Weights.verify(weights)
  276. original_aux_logits = kwargs.get("aux_logits", False)
  277. if weights is not None:
  278. if "transform_input" not in kwargs:
  279. _ovewrite_named_param(kwargs, "transform_input", True)
  280. _ovewrite_named_param(kwargs, "aux_logits", True)
  281. _ovewrite_named_param(kwargs, "init_weights", False)
  282. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  283. model = GoogLeNet(**kwargs)
  284. if weights is not None:
  285. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  286. if not original_aux_logits:
  287. model.aux_logits = False
  288. model.aux1 = None # type: ignore[assignment]
  289. model.aux2 = None # type: ignore[assignment]
  290. else:
  291. warnings.warn(
  292. "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
  293. )
  294. return model