# inception.py (18 KB)
  1. import warnings
  2. from collections import namedtuple
  3. from functools import partial
  4. from typing import Any, Callable, List, Optional, Tuple
  5. import torch
  6. import torch.nn.functional as F
  7. from torch import nn, Tensor
  8. from ..transforms._presets import ImageClassification
  9. from ..utils import _log_api_usage_once
  10. from ._api import register_model, Weights, WeightsEnum
  11. from ._meta import _IMAGENET_CATEGORIES
  12. from ._utils import _ovewrite_named_param, handle_legacy_interface
  13. __all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"]
  14. InceptionOutputs = namedtuple("InceptionOutputs", ["logits", "aux_logits"])
  15. InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Tensor]}
  16. # Script annotations failed with _GoogleNetOutputs = namedtuple ...
  17. # _InceptionOutputs set here for backwards compat
  18. _InceptionOutputs = InceptionOutputs
  19. class Inception3(nn.Module):
  20. def __init__(
  21. self,
  22. num_classes: int = 1000,
  23. aux_logits: bool = True,
  24. transform_input: bool = False,
  25. inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
  26. init_weights: Optional[bool] = None,
  27. dropout: float = 0.5,
  28. ) -> None:
  29. super().__init__()
  30. _log_api_usage_once(self)
  31. if inception_blocks is None:
  32. inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux]
  33. if init_weights is None:
  34. warnings.warn(
  35. "The default weight initialization of inception_v3 will be changed in future releases of "
  36. "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
  37. " due to scipy/scipy#11299), please set init_weights=True.",
  38. FutureWarning,
  39. )
  40. init_weights = True
  41. if len(inception_blocks) != 7:
  42. raise ValueError(f"length of inception_blocks should be 7 instead of {len(inception_blocks)}")
  43. conv_block = inception_blocks[0]
  44. inception_a = inception_blocks[1]
  45. inception_b = inception_blocks[2]
  46. inception_c = inception_blocks[3]
  47. inception_d = inception_blocks[4]
  48. inception_e = inception_blocks[5]
  49. inception_aux = inception_blocks[6]
  50. self.aux_logits = aux_logits
  51. self.transform_input = transform_input
  52. self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
  53. self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
  54. self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
  55. self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
  56. self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
  57. self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
  58. self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
  59. self.Mixed_5b = inception_a(192, pool_features=32)
  60. self.Mixed_5c = inception_a(256, pool_features=64)
  61. self.Mixed_5d = inception_a(288, pool_features=64)
  62. self.Mixed_6a = inception_b(288)
  63. self.Mixed_6b = inception_c(768, channels_7x7=128)
  64. self.Mixed_6c = inception_c(768, channels_7x7=160)
  65. self.Mixed_6d = inception_c(768, channels_7x7=160)
  66. self.Mixed_6e = inception_c(768, channels_7x7=192)
  67. self.AuxLogits: Optional[nn.Module] = None
  68. if aux_logits:
  69. self.AuxLogits = inception_aux(768, num_classes)
  70. self.Mixed_7a = inception_d(768)
  71. self.Mixed_7b = inception_e(1280)
  72. self.Mixed_7c = inception_e(2048)
  73. self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
  74. self.dropout = nn.Dropout(p=dropout)
  75. self.fc = nn.Linear(2048, num_classes)
  76. if init_weights:
  77. for m in self.modules():
  78. if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
  79. stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1 # type: ignore
  80. torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2)
  81. elif isinstance(m, nn.BatchNorm2d):
  82. nn.init.constant_(m.weight, 1)
  83. nn.init.constant_(m.bias, 0)
  84. def _transform_input(self, x: Tensor) -> Tensor:
  85. if self.transform_input:
  86. x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
  87. x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
  88. x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
  89. x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
  90. return x
  91. def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
  92. # N x 3 x 299 x 299
  93. x = self.Conv2d_1a_3x3(x)
  94. # N x 32 x 149 x 149
  95. x = self.Conv2d_2a_3x3(x)
  96. # N x 32 x 147 x 147
  97. x = self.Conv2d_2b_3x3(x)
  98. # N x 64 x 147 x 147
  99. x = self.maxpool1(x)
  100. # N x 64 x 73 x 73
  101. x = self.Conv2d_3b_1x1(x)
  102. # N x 80 x 73 x 73
  103. x = self.Conv2d_4a_3x3(x)
  104. # N x 192 x 71 x 71
  105. x = self.maxpool2(x)
  106. # N x 192 x 35 x 35
  107. x = self.Mixed_5b(x)
  108. # N x 256 x 35 x 35
  109. x = self.Mixed_5c(x)
  110. # N x 288 x 35 x 35
  111. x = self.Mixed_5d(x)
  112. # N x 288 x 35 x 35
  113. x = self.Mixed_6a(x)
  114. # N x 768 x 17 x 17
  115. x = self.Mixed_6b(x)
  116. # N x 768 x 17 x 17
  117. x = self.Mixed_6c(x)
  118. # N x 768 x 17 x 17
  119. x = self.Mixed_6d(x)
  120. # N x 768 x 17 x 17
  121. x = self.Mixed_6e(x)
  122. # N x 768 x 17 x 17
  123. aux: Optional[Tensor] = None
  124. if self.AuxLogits is not None:
  125. if self.training:
  126. aux = self.AuxLogits(x)
  127. # N x 768 x 17 x 17
  128. x = self.Mixed_7a(x)
  129. # N x 1280 x 8 x 8
  130. x = self.Mixed_7b(x)
  131. # N x 2048 x 8 x 8
  132. x = self.Mixed_7c(x)
  133. # N x 2048 x 8 x 8
  134. # Adaptive average pooling
  135. x = self.avgpool(x)
  136. # N x 2048 x 1 x 1
  137. x = self.dropout(x)
  138. # N x 2048 x 1 x 1
  139. x = torch.flatten(x, 1)
  140. # N x 2048
  141. x = self.fc(x)
  142. # N x 1000 (num_classes)
  143. return x, aux
  144. @torch.jit.unused
  145. def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
  146. if self.training and self.aux_logits:
  147. return InceptionOutputs(x, aux)
  148. else:
  149. return x # type: ignore[return-value]
  150. def forward(self, x: Tensor) -> InceptionOutputs:
  151. x = self._transform_input(x)
  152. x, aux = self._forward(x)
  153. aux_defined = self.training and self.aux_logits
  154. if torch.jit.is_scripting():
  155. if not aux_defined:
  156. warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
  157. return InceptionOutputs(x, aux)
  158. else:
  159. return self.eager_outputs(x, aux)
  160. class InceptionA(nn.Module):
  161. def __init__(
  162. self, in_channels: int, pool_features: int, conv_block: Optional[Callable[..., nn.Module]] = None
  163. ) -> None:
  164. super().__init__()
  165. if conv_block is None:
  166. conv_block = BasicConv2d
  167. self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
  168. self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
  169. self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)
  170. self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
  171. self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
  172. self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)
  173. self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
  174. def _forward(self, x: Tensor) -> List[Tensor]:
  175. branch1x1 = self.branch1x1(x)
  176. branch5x5 = self.branch5x5_1(x)
  177. branch5x5 = self.branch5x5_2(branch5x5)
  178. branch3x3dbl = self.branch3x3dbl_1(x)
  179. branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
  180. branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
  181. branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
  182. branch_pool = self.branch_pool(branch_pool)
  183. outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
  184. return outputs
  185. def forward(self, x: Tensor) -> Tensor:
  186. outputs = self._forward(x)
  187. return torch.cat(outputs, 1)
  188. class InceptionB(nn.Module):
  189. def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
  190. super().__init__()
  191. if conv_block is None:
  192. conv_block = BasicConv2d
  193. self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
  194. self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
  195. self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
  196. self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
  197. def _forward(self, x: Tensor) -> List[Tensor]:
  198. branch3x3 = self.branch3x3(x)
  199. branch3x3dbl = self.branch3x3dbl_1(x)
  200. branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
  201. branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
  202. branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
  203. outputs = [branch3x3, branch3x3dbl, branch_pool]
  204. return outputs
  205. def forward(self, x: Tensor) -> Tensor:
  206. outputs = self._forward(x)
  207. return torch.cat(outputs, 1)
  208. class InceptionC(nn.Module):
  209. def __init__(
  210. self, in_channels: int, channels_7x7: int, conv_block: Optional[Callable[..., nn.Module]] = None
  211. ) -> None:
  212. super().__init__()
  213. if conv_block is None:
  214. conv_block = BasicConv2d
  215. self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
  216. c7 = channels_7x7
  217. self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
  218. self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
  219. self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))
  220. self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
  221. self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
  222. self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
  223. self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
  224. self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))
  225. self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
  226. def _forward(self, x: Tensor) -> List[Tensor]:
  227. branch1x1 = self.branch1x1(x)
  228. branch7x7 = self.branch7x7_1(x)
  229. branch7x7 = self.branch7x7_2(branch7x7)
  230. branch7x7 = self.branch7x7_3(branch7x7)
  231. branch7x7dbl = self.branch7x7dbl_1(x)
  232. branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
  233. branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
  234. branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
  235. branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
  236. branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
  237. branch_pool = self.branch_pool(branch_pool)
  238. outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
  239. return outputs
  240. def forward(self, x: Tensor) -> Tensor:
  241. outputs = self._forward(x)
  242. return torch.cat(outputs, 1)
  243. class InceptionD(nn.Module):
  244. def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
  245. super().__init__()
  246. if conv_block is None:
  247. conv_block = BasicConv2d
  248. self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
  249. self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
  250. self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
  251. self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
  252. self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
  253. self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
  254. def _forward(self, x: Tensor) -> List[Tensor]:
  255. branch3x3 = self.branch3x3_1(x)
  256. branch3x3 = self.branch3x3_2(branch3x3)
  257. branch7x7x3 = self.branch7x7x3_1(x)
  258. branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
  259. branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
  260. branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
  261. branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
  262. outputs = [branch3x3, branch7x7x3, branch_pool]
  263. return outputs
  264. def forward(self, x: Tensor) -> Tensor:
  265. outputs = self._forward(x)
  266. return torch.cat(outputs, 1)
  267. class InceptionE(nn.Module):
  268. def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
  269. super().__init__()
  270. if conv_block is None:
  271. conv_block = BasicConv2d
  272. self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
  273. self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
  274. self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
  275. self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
  276. self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
  277. self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
  278. self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
  279. self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
  280. self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
  281. def _forward(self, x: Tensor) -> List[Tensor]:
  282. branch1x1 = self.branch1x1(x)
  283. branch3x3 = self.branch3x3_1(x)
  284. branch3x3 = [
  285. self.branch3x3_2a(branch3x3),
  286. self.branch3x3_2b(branch3x3),
  287. ]
  288. branch3x3 = torch.cat(branch3x3, 1)
  289. branch3x3dbl = self.branch3x3dbl_1(x)
  290. branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
  291. branch3x3dbl = [
  292. self.branch3x3dbl_3a(branch3x3dbl),
  293. self.branch3x3dbl_3b(branch3x3dbl),
  294. ]
  295. branch3x3dbl = torch.cat(branch3x3dbl, 1)
  296. branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
  297. branch_pool = self.branch_pool(branch_pool)
  298. outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
  299. return outputs
  300. def forward(self, x: Tensor) -> Tensor:
  301. outputs = self._forward(x)
  302. return torch.cat(outputs, 1)
  303. class InceptionAux(nn.Module):
  304. def __init__(
  305. self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None
  306. ) -> None:
  307. super().__init__()
  308. if conv_block is None:
  309. conv_block = BasicConv2d
  310. self.conv0 = conv_block(in_channels, 128, kernel_size=1)
  311. self.conv1 = conv_block(128, 768, kernel_size=5)
  312. self.conv1.stddev = 0.01 # type: ignore[assignment]
  313. self.fc = nn.Linear(768, num_classes)
  314. self.fc.stddev = 0.001 # type: ignore[assignment]
  315. def forward(self, x: Tensor) -> Tensor:
  316. # N x 768 x 17 x 17
  317. x = F.avg_pool2d(x, kernel_size=5, stride=3)
  318. # N x 768 x 5 x 5
  319. x = self.conv0(x)
  320. # N x 128 x 5 x 5
  321. x = self.conv1(x)
  322. # N x 768 x 1 x 1
  323. # Adaptive average pooling
  324. x = F.adaptive_avg_pool2d(x, (1, 1))
  325. # N x 768 x 1 x 1
  326. x = torch.flatten(x, 1)
  327. # N x 768
  328. x = self.fc(x)
  329. # N x 1000
  330. return x
  331. class BasicConv2d(nn.Module):
  332. def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
  333. super().__init__()
  334. self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
  335. self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
  336. def forward(self, x: Tensor) -> Tensor:
  337. x = self.conv(x)
  338. x = self.bn(x)
  339. return F.relu(x, inplace=True)
  340. class Inception_V3_Weights(WeightsEnum):
  341. IMAGENET1K_V1 = Weights(
  342. url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
  343. transforms=partial(ImageClassification, crop_size=299, resize_size=342),
  344. meta={
  345. "num_params": 27161264,
  346. "min_size": (75, 75),
  347. "categories": _IMAGENET_CATEGORIES,
  348. "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3",
  349. "_metrics": {
  350. "ImageNet-1K": {
  351. "acc@1": 77.294,
  352. "acc@5": 93.450,
  353. }
  354. },
  355. "_ops": 5.713,
  356. "_file_size": 103.903,
  357. "_docs": """These weights are ported from the original paper.""",
  358. },
  359. )
  360. DEFAULT = IMAGENET1K_V1
  361. @register_model()
  362. @handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.IMAGENET1K_V1))
  363. def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
  364. """
  365. Inception v3 model architecture from
  366. `Rethinking the Inception Architecture for Computer Vision <http://arxiv.org/abs/1512.00567>`_.
  367. .. note::
  368. **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
  369. N x 3 x 299 x 299, so ensure your images are sized accordingly.
  370. Args:
  371. weights (:class:`~torchvision.models.Inception_V3_Weights`, optional): The
  372. pretrained weights for the model. See
  373. :class:`~torchvision.models.Inception_V3_Weights` below for
  374. more details, and possible values. By default, no pre-trained
  375. weights are used.
  376. progress (bool, optional): If True, displays a progress bar of the
  377. download to stderr. Default is True.
  378. **kwargs: parameters passed to the ``torchvision.models.Inception3``
  379. base class. Please refer to the `source code
  380. <https://github.com/pytorch/vision/blob/main/torchvision/models/inception.py>`_
  381. for more details about this class.
  382. .. autoclass:: torchvision.models.Inception_V3_Weights
  383. :members:
  384. """
  385. weights = Inception_V3_Weights.verify(weights)
  386. original_aux_logits = kwargs.get("aux_logits", True)
  387. if weights is not None:
  388. if "transform_input" not in kwargs:
  389. _ovewrite_named_param(kwargs, "transform_input", True)
  390. _ovewrite_named_param(kwargs, "aux_logits", True)
  391. _ovewrite_named_param(kwargs, "init_weights", False)
  392. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  393. model = Inception3(**kwargs)
  394. if weights is not None:
  395. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  396. if not original_aux_logits:
  397. model.aux_logits = False
  398. model.AuxLogits = None
  399. return model