# convnext.py (15 KB)
  1. from functools import partial
  2. from typing import Any, Callable, List, Optional, Sequence
  3. import torch
  4. from torch import nn, Tensor
  5. from torch.nn import functional as F
  6. from ..ops.misc import Conv2dNormActivation, Permute
  7. from ..ops.stochastic_depth import StochasticDepth
  8. from ..transforms._presets import ImageClassification
  9. from ..utils import _log_api_usage_once
  10. from ._api import register_model, Weights, WeightsEnum
  11. from ._meta import _IMAGENET_CATEGORIES
  12. from ._utils import _ovewrite_named_param, handle_legacy_interface
# Public API of this module: the ConvNeXt model class, one weights enum per
# model size, and one builder function per model size.
__all__ = [
    "ConvNeXt",
    "ConvNeXt_Tiny_Weights",
    "ConvNeXt_Small_Weights",
    "ConvNeXt_Base_Weights",
    "ConvNeXt_Large_Weights",
    "convnext_tiny",
    "convnext_small",
    "convnext_base",
    "convnext_large",
]
  24. class LayerNorm2d(nn.LayerNorm):
  25. def forward(self, x: Tensor) -> Tensor:
  26. x = x.permute(0, 2, 3, 1)
  27. x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
  28. x = x.permute(0, 3, 1, 2)
  29. return x
  30. class CNBlock(nn.Module):
  31. def __init__(
  32. self,
  33. dim,
  34. layer_scale: float,
  35. stochastic_depth_prob: float,
  36. norm_layer: Optional[Callable[..., nn.Module]] = None,
  37. ) -> None:
  38. super().__init__()
  39. if norm_layer is None:
  40. norm_layer = partial(nn.LayerNorm, eps=1e-6)
  41. self.block = nn.Sequential(
  42. nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
  43. Permute([0, 2, 3, 1]),
  44. norm_layer(dim),
  45. nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
  46. nn.GELU(),
  47. nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
  48. Permute([0, 3, 1, 2]),
  49. )
  50. self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
  51. self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
  52. def forward(self, input: Tensor) -> Tensor:
  53. result = self.layer_scale * self.block(input)
  54. result = self.stochastic_depth(result)
  55. result += input
  56. return result
  57. class CNBlockConfig:
  58. # Stores information listed at Section 3 of the ConvNeXt paper
  59. def __init__(
  60. self,
  61. input_channels: int,
  62. out_channels: Optional[int],
  63. num_layers: int,
  64. ) -> None:
  65. self.input_channels = input_channels
  66. self.out_channels = out_channels
  67. self.num_layers = num_layers
  68. def __repr__(self) -> str:
  69. s = self.__class__.__name__ + "("
  70. s += "input_channels={input_channels}"
  71. s += ", out_channels={out_channels}"
  72. s += ", num_layers={num_layers}"
  73. s += ")"
  74. return s.format(**self.__dict__)
  75. class ConvNeXt(nn.Module):
  76. def __init__(
  77. self,
  78. block_setting: List[CNBlockConfig],
  79. stochastic_depth_prob: float = 0.0,
  80. layer_scale: float = 1e-6,
  81. num_classes: int = 1000,
  82. block: Optional[Callable[..., nn.Module]] = None,
  83. norm_layer: Optional[Callable[..., nn.Module]] = None,
  84. **kwargs: Any,
  85. ) -> None:
  86. super().__init__()
  87. _log_api_usage_once(self)
  88. if not block_setting:
  89. raise ValueError("The block_setting should not be empty")
  90. elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
  91. raise TypeError("The block_setting should be List[CNBlockConfig]")
  92. if block is None:
  93. block = CNBlock
  94. if norm_layer is None:
  95. norm_layer = partial(LayerNorm2d, eps=1e-6)
  96. layers: List[nn.Module] = []
  97. # Stem
  98. firstconv_output_channels = block_setting[0].input_channels
  99. layers.append(
  100. Conv2dNormActivation(
  101. 3,
  102. firstconv_output_channels,
  103. kernel_size=4,
  104. stride=4,
  105. padding=0,
  106. norm_layer=norm_layer,
  107. activation_layer=None,
  108. bias=True,
  109. )
  110. )
  111. total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
  112. stage_block_id = 0
  113. for cnf in block_setting:
  114. # Bottlenecks
  115. stage: List[nn.Module] = []
  116. for _ in range(cnf.num_layers):
  117. # adjust stochastic depth probability based on the depth of the stage block
  118. sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
  119. stage.append(block(cnf.input_channels, layer_scale, sd_prob))
  120. stage_block_id += 1
  121. layers.append(nn.Sequential(*stage))
  122. if cnf.out_channels is not None:
  123. # Downsampling
  124. layers.append(
  125. nn.Sequential(
  126. norm_layer(cnf.input_channels),
  127. nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
  128. )
  129. )
  130. self.features = nn.Sequential(*layers)
  131. self.avgpool = nn.AdaptiveAvgPool2d(1)
  132. lastblock = block_setting[-1]
  133. lastconv_output_channels = (
  134. lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
  135. )
  136. self.classifier = nn.Sequential(
  137. norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
  138. )
  139. for m in self.modules():
  140. if isinstance(m, (nn.Conv2d, nn.Linear)):
  141. nn.init.trunc_normal_(m.weight, std=0.02)
  142. if m.bias is not None:
  143. nn.init.zeros_(m.bias)
  144. def _forward_impl(self, x: Tensor) -> Tensor:
  145. x = self.features(x)
  146. x = self.avgpool(x)
  147. x = self.classifier(x)
  148. return x
  149. def forward(self, x: Tensor) -> Tensor:
  150. return self._forward_impl(x)
  151. def _convnext(
  152. block_setting: List[CNBlockConfig],
  153. stochastic_depth_prob: float,
  154. weights: Optional[WeightsEnum],
  155. progress: bool,
  156. **kwargs: Any,
  157. ) -> ConvNeXt:
  158. if weights is not None:
  159. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  160. model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
  161. if weights is not None:
  162. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  163. return model
# Metadata shared by every ConvNeXt weight entry in this module; each entry
# spreads it with ``**_COMMON_META`` and adds its own size/metric fields.
_COMMON_META = {
    # presumably the minimum supported input (H, W) — TODO confirm with Weights docs
    "min_size": (32, 32),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
    "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
""",
}
class ConvNeXt_Tiny_Weights(WeightsEnum):
    # ImageNet-1K checkpoint for ``convnext_tiny``. crop_size=224 and
    # resize_size=236 are bound into the ImageClassification preset.
    # NOTE(review): "_ops"/"_file_size" units are not stated here (presumably
    # GFLOPs and MB) — confirm before relying on them.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=236),
        meta={
            **_COMMON_META,
            "num_params": 28589128,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.520,
                    "acc@5": 96.146,
                }
            },
            "_ops": 4.456,
            "_file_size": 109.119,
        },
    )
    # Alias for the default weight entry.
    DEFAULT = IMAGENET1K_V1
class ConvNeXt_Small_Weights(WeightsEnum):
    # ImageNet-1K checkpoint for ``convnext_small``. crop_size=224 and
    # resize_size=230 are bound into the ImageClassification preset.
    # NOTE(review): "_ops"/"_file_size" units are not stated here (presumably
    # GFLOPs and MB) — confirm before relying on them.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=230),
        meta={
            **_COMMON_META,
            "num_params": 50223688,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.616,
                    "acc@5": 96.650,
                }
            },
            "_ops": 8.684,
            "_file_size": 191.703,
        },
    )
    # Alias for the default weight entry.
    DEFAULT = IMAGENET1K_V1
class ConvNeXt_Base_Weights(WeightsEnum):
    # ImageNet-1K checkpoint for ``convnext_base``. crop_size=224 and
    # resize_size=232 are bound into the ImageClassification preset.
    # NOTE(review): "_ops"/"_file_size" units are not stated here (presumably
    # GFLOPs and MB) — confirm before relying on them.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88591464,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.062,
                    "acc@5": 96.870,
                }
            },
            "_ops": 15.355,
            "_file_size": 338.064,
        },
    )
    # Alias for the default weight entry.
    DEFAULT = IMAGENET1K_V1
class ConvNeXt_Large_Weights(WeightsEnum):
    # ImageNet-1K checkpoint for ``convnext_large``. crop_size=224 and
    # resize_size=232 are bound into the ImageClassification preset.
    # NOTE(review): "_ops"/"_file_size" units are not stated here (presumably
    # GFLOPs and MB) — confirm before relying on them.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 197767336,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.414,
                    "acc@5": 96.976,
                }
            },
            "_ops": 34.361,
            "_file_size": 754.537,
        },
    )
    # Alias for the default weight entry.
    DEFAULT = IMAGENET1K_V1
  246. @register_model()
  247. @handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
  248. def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
  249. """ConvNeXt Tiny model architecture from the
  250. `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.
  251. Args:
  252. weights (:class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`, optional): The pretrained
  253. weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`
  254. below for more details and possible values. By default, no pre-trained weights are used.
  255. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
  256. **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNext``
  257. base class. Please refer to the `source code
  258. <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
  259. for more details about this class.
  260. .. autoclass:: torchvision.models.ConvNeXt_Tiny_Weights
  261. :members:
  262. """
  263. weights = ConvNeXt_Tiny_Weights.verify(weights)
  264. block_setting = [
  265. CNBlockConfig(96, 192, 3),
  266. CNBlockConfig(192, 384, 3),
  267. CNBlockConfig(384, 768, 9),
  268. CNBlockConfig(768, None, 3),
  269. ]
  270. stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
  271. return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
  272. @register_model()
  273. @handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
  274. def convnext_small(
  275. *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
  276. ) -> ConvNeXt:
  277. """ConvNeXt Small model architecture from the
  278. `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.
  279. Args:
  280. weights (:class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`, optional): The pretrained
  281. weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`
  282. below for more details and possible values. By default, no pre-trained weights are used.
  283. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
  284. **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNext``
  285. base class. Please refer to the `source code
  286. <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
  287. for more details about this class.
  288. .. autoclass:: torchvision.models.ConvNeXt_Small_Weights
  289. :members:
  290. """
  291. weights = ConvNeXt_Small_Weights.verify(weights)
  292. block_setting = [
  293. CNBlockConfig(96, 192, 3),
  294. CNBlockConfig(192, 384, 3),
  295. CNBlockConfig(384, 768, 27),
  296. CNBlockConfig(768, None, 3),
  297. ]
  298. stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
  299. return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
  300. @register_model()
  301. @handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
  302. def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
  303. """ConvNeXt Base model architecture from the
  304. `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.
  305. Args:
  306. weights (:class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`, optional): The pretrained
  307. weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`
  308. below for more details and possible values. By default, no pre-trained weights are used.
  309. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
  310. **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNext``
  311. base class. Please refer to the `source code
  312. <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
  313. for more details about this class.
  314. .. autoclass:: torchvision.models.ConvNeXt_Base_Weights
  315. :members:
  316. """
  317. weights = ConvNeXt_Base_Weights.verify(weights)
  318. block_setting = [
  319. CNBlockConfig(128, 256, 3),
  320. CNBlockConfig(256, 512, 3),
  321. CNBlockConfig(512, 1024, 27),
  322. CNBlockConfig(1024, None, 3),
  323. ]
  324. stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
  325. return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
  326. @register_model()
  327. @handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
  328. def convnext_large(
  329. *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
  330. ) -> ConvNeXt:
  331. """ConvNeXt Large model architecture from the
  332. `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.
  333. Args:
  334. weights (:class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`, optional): The pretrained
  335. weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`
  336. below for more details and possible values. By default, no pre-trained weights are used.
  337. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
  338. **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNext``
  339. base class. Please refer to the `source code
  340. <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
  341. for more details about this class.
  342. .. autoclass:: torchvision.models.ConvNeXt_Large_Weights
  343. :members:
  344. """
  345. weights = ConvNeXt_Large_Weights.verify(weights)
  346. block_setting = [
  347. CNBlockConfig(192, 384, 3),
  348. CNBlockConfig(384, 768, 3),
  349. CNBlockConfig(768, 1536, 27),
  350. CNBlockConfig(1536, None, 3),
  351. ]
  352. stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
  353. return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)