resnet.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. from functools import partial
  2. from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union
  3. import torch.nn as nn
  4. from torch import Tensor
  5. from ...transforms._presets import VideoClassification
  6. from ...utils import _log_api_usage_once
  7. from .._api import register_model, Weights, WeightsEnum
  8. from .._meta import _KINETICS400_CATEGORIES
  9. from .._utils import _ovewrite_named_param, handle_legacy_interface
  10. __all__ = [
  11. "VideoResNet",
  12. "R3D_18_Weights",
  13. "MC3_18_Weights",
  14. "R2Plus1D_18_Weights",
  15. "r3d_18",
  16. "mc3_18",
  17. "r2plus1d_18",
  18. ]
  19. class Conv3DSimple(nn.Conv3d):
  20. def __init__(
  21. self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1
  22. ) -> None:
  23. super().__init__(
  24. in_channels=in_planes,
  25. out_channels=out_planes,
  26. kernel_size=(3, 3, 3),
  27. stride=stride,
  28. padding=padding,
  29. bias=False,
  30. )
  31. @staticmethod
  32. def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
  33. return stride, stride, stride
  34. class Conv2Plus1D(nn.Sequential):
  35. def __init__(self, in_planes: int, out_planes: int, midplanes: int, stride: int = 1, padding: int = 1) -> None:
  36. super().__init__(
  37. nn.Conv3d(
  38. in_planes,
  39. midplanes,
  40. kernel_size=(1, 3, 3),
  41. stride=(1, stride, stride),
  42. padding=(0, padding, padding),
  43. bias=False,
  44. ),
  45. nn.BatchNorm3d(midplanes),
  46. nn.ReLU(inplace=True),
  47. nn.Conv3d(
  48. midplanes, out_planes, kernel_size=(3, 1, 1), stride=(stride, 1, 1), padding=(padding, 0, 0), bias=False
  49. ),
  50. )
  51. @staticmethod
  52. def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
  53. return stride, stride, stride
  54. class Conv3DNoTemporal(nn.Conv3d):
  55. def __init__(
  56. self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1
  57. ) -> None:
  58. super().__init__(
  59. in_channels=in_planes,
  60. out_channels=out_planes,
  61. kernel_size=(1, 3, 3),
  62. stride=(1, stride, stride),
  63. padding=(0, padding, padding),
  64. bias=False,
  65. )
  66. @staticmethod
  67. def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
  68. return 1, stride, stride
  69. class BasicBlock(nn.Module):
  70. expansion = 1
  71. def __init__(
  72. self,
  73. inplanes: int,
  74. planes: int,
  75. conv_builder: Callable[..., nn.Module],
  76. stride: int = 1,
  77. downsample: Optional[nn.Module] = None,
  78. ) -> None:
  79. midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
  80. super().__init__()
  81. self.conv1 = nn.Sequential(
  82. conv_builder(inplanes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)
  83. )
  84. self.conv2 = nn.Sequential(conv_builder(planes, planes, midplanes), nn.BatchNorm3d(planes))
  85. self.relu = nn.ReLU(inplace=True)
  86. self.downsample = downsample
  87. self.stride = stride
  88. def forward(self, x: Tensor) -> Tensor:
  89. residual = x
  90. out = self.conv1(x)
  91. out = self.conv2(out)
  92. if self.downsample is not None:
  93. residual = self.downsample(x)
  94. out += residual
  95. out = self.relu(out)
  96. return out
  97. class Bottleneck(nn.Module):
  98. expansion = 4
  99. def __init__(
  100. self,
  101. inplanes: int,
  102. planes: int,
  103. conv_builder: Callable[..., nn.Module],
  104. stride: int = 1,
  105. downsample: Optional[nn.Module] = None,
  106. ) -> None:
  107. super().__init__()
  108. midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
  109. # 1x1x1
  110. self.conv1 = nn.Sequential(
  111. nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)
  112. )
  113. # Second kernel
  114. self.conv2 = nn.Sequential(
  115. conv_builder(planes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)
  116. )
  117. # 1x1x1
  118. self.conv3 = nn.Sequential(
  119. nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
  120. nn.BatchNorm3d(planes * self.expansion),
  121. )
  122. self.relu = nn.ReLU(inplace=True)
  123. self.downsample = downsample
  124. self.stride = stride
  125. def forward(self, x: Tensor) -> Tensor:
  126. residual = x
  127. out = self.conv1(x)
  128. out = self.conv2(out)
  129. out = self.conv3(out)
  130. if self.downsample is not None:
  131. residual = self.downsample(x)
  132. out += residual
  133. out = self.relu(out)
  134. return out
  135. class BasicStem(nn.Sequential):
  136. """The default conv-batchnorm-relu stem"""
  137. def __init__(self) -> None:
  138. super().__init__(
  139. nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False),
  140. nn.BatchNorm3d(64),
  141. nn.ReLU(inplace=True),
  142. )
  143. class R2Plus1dStem(nn.Sequential):
  144. """R(2+1)D stem is different than the default one as it uses separated 3D convolution"""
  145. def __init__(self) -> None:
  146. super().__init__(
  147. nn.Conv3d(3, 45, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False),
  148. nn.BatchNorm3d(45),
  149. nn.ReLU(inplace=True),
  150. nn.Conv3d(45, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False),
  151. nn.BatchNorm3d(64),
  152. nn.ReLU(inplace=True),
  153. )
  154. class VideoResNet(nn.Module):
  155. def __init__(
  156. self,
  157. block: Type[Union[BasicBlock, Bottleneck]],
  158. conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
  159. layers: List[int],
  160. stem: Callable[..., nn.Module],
  161. num_classes: int = 400,
  162. zero_init_residual: bool = False,
  163. ) -> None:
  164. """Generic resnet video generator.
  165. Args:
  166. block (Type[Union[BasicBlock, Bottleneck]]): resnet building block
  167. conv_makers (List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]]): generator
  168. function for each layer
  169. layers (List[int]): number of blocks per layer
  170. stem (Callable[..., nn.Module]): module specifying the ResNet stem.
  171. num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
  172. zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
  173. """
  174. super().__init__()
  175. _log_api_usage_once(self)
  176. self.inplanes = 64
  177. self.stem = stem()
  178. self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
  179. self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
  180. self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2)
  181. self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2)
  182. self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
  183. self.fc = nn.Linear(512 * block.expansion, num_classes)
  184. # init weights
  185. for m in self.modules():
  186. if isinstance(m, nn.Conv3d):
  187. nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
  188. if m.bias is not None:
  189. nn.init.constant_(m.bias, 0)
  190. elif isinstance(m, nn.BatchNorm3d):
  191. nn.init.constant_(m.weight, 1)
  192. nn.init.constant_(m.bias, 0)
  193. elif isinstance(m, nn.Linear):
  194. nn.init.normal_(m.weight, 0, 0.01)
  195. nn.init.constant_(m.bias, 0)
  196. if zero_init_residual:
  197. for m in self.modules():
  198. if isinstance(m, Bottleneck):
  199. nn.init.constant_(m.bn3.weight, 0) # type: ignore[union-attr, arg-type]
  200. def forward(self, x: Tensor) -> Tensor:
  201. x = self.stem(x)
  202. x = self.layer1(x)
  203. x = self.layer2(x)
  204. x = self.layer3(x)
  205. x = self.layer4(x)
  206. x = self.avgpool(x)
  207. # Flatten the layer to fc
  208. x = x.flatten(1)
  209. x = self.fc(x)
  210. return x
  211. def _make_layer(
  212. self,
  213. block: Type[Union[BasicBlock, Bottleneck]],
  214. conv_builder: Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]],
  215. planes: int,
  216. blocks: int,
  217. stride: int = 1,
  218. ) -> nn.Sequential:
  219. downsample = None
  220. if stride != 1 or self.inplanes != planes * block.expansion:
  221. ds_stride = conv_builder.get_downsample_stride(stride)
  222. downsample = nn.Sequential(
  223. nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=ds_stride, bias=False),
  224. nn.BatchNorm3d(planes * block.expansion),
  225. )
  226. layers = []
  227. layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))
  228. self.inplanes = planes * block.expansion
  229. for i in range(1, blocks):
  230. layers.append(block(self.inplanes, planes, conv_builder))
  231. return nn.Sequential(*layers)
  232. def _video_resnet(
  233. block: Type[Union[BasicBlock, Bottleneck]],
  234. conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
  235. layers: List[int],
  236. stem: Callable[..., nn.Module],
  237. weights: Optional[WeightsEnum],
  238. progress: bool,
  239. **kwargs: Any,
  240. ) -> VideoResNet:
  241. if weights is not None:
  242. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  243. model = VideoResNet(block, conv_makers, layers, stem, **kwargs)
  244. if weights is not None:
  245. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  246. return model
  247. _COMMON_META = {
  248. "min_size": (1, 1),
  249. "categories": _KINETICS400_CATEGORIES,
  250. "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
  251. "_docs": (
  252. "The weights reproduce closely the accuracy of the paper. The accuracies are estimated on video-level "
  253. "with parameters `frame_rate=15`, `clips_per_video=5`, and `clip_len=16`."
  254. ),
  255. }
  256. class R3D_18_Weights(WeightsEnum):
  257. KINETICS400_V1 = Weights(
  258. url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
  259. transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
  260. meta={
  261. **_COMMON_META,
  262. "num_params": 33371472,
  263. "_metrics": {
  264. "Kinetics-400": {
  265. "acc@1": 63.200,
  266. "acc@5": 83.479,
  267. }
  268. },
  269. "_ops": 40.697,
  270. "_file_size": 127.359,
  271. },
  272. )
  273. DEFAULT = KINETICS400_V1
  274. class MC3_18_Weights(WeightsEnum):
  275. KINETICS400_V1 = Weights(
  276. url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
  277. transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
  278. meta={
  279. **_COMMON_META,
  280. "num_params": 11695440,
  281. "_metrics": {
  282. "Kinetics-400": {
  283. "acc@1": 63.960,
  284. "acc@5": 84.130,
  285. }
  286. },
  287. "_ops": 43.343,
  288. "_file_size": 44.672,
  289. },
  290. )
  291. DEFAULT = KINETICS400_V1
  292. class R2Plus1D_18_Weights(WeightsEnum):
  293. KINETICS400_V1 = Weights(
  294. url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
  295. transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
  296. meta={
  297. **_COMMON_META,
  298. "num_params": 31505325,
  299. "_metrics": {
  300. "Kinetics-400": {
  301. "acc@1": 67.463,
  302. "acc@5": 86.175,
  303. }
  304. },
  305. "_ops": 40.519,
  306. "_file_size": 120.318,
  307. },
  308. )
  309. DEFAULT = KINETICS400_V1
  310. @register_model()
  311. @handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1))
  312. def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
  313. """Construct 18 layer Resnet3D model.
  314. .. betastatus:: video module
  315. Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.
  316. Args:
  317. weights (:class:`~torchvision.models.video.R3D_18_Weights`, optional): The
  318. pretrained weights to use. See
  319. :class:`~torchvision.models.video.R3D_18_Weights`
  320. below for more details, and possible values. By default, no
  321. pre-trained weights are used.
  322. progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
  323. **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
  324. Please refer to the `source code
  325. <https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
  326. for more details about this class.
  327. .. autoclass:: torchvision.models.video.R3D_18_Weights
  328. :members:
  329. """
  330. weights = R3D_18_Weights.verify(weights)
  331. return _video_resnet(
  332. BasicBlock,
  333. [Conv3DSimple] * 4,
  334. [2, 2, 2, 2],
  335. BasicStem,
  336. weights,
  337. progress,
  338. **kwargs,
  339. )
  340. @register_model()
  341. @handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1))
  342. def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
  343. """Construct 18 layer Mixed Convolution network as in
  344. .. betastatus:: video module
  345. Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.
  346. Args:
  347. weights (:class:`~torchvision.models.video.MC3_18_Weights`, optional): The
  348. pretrained weights to use. See
  349. :class:`~torchvision.models.video.MC3_18_Weights`
  350. below for more details, and possible values. By default, no
  351. pre-trained weights are used.
  352. progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
  353. **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
  354. Please refer to the `source code
  355. <https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
  356. for more details about this class.
  357. .. autoclass:: torchvision.models.video.MC3_18_Weights
  358. :members:
  359. """
  360. weights = MC3_18_Weights.verify(weights)
  361. return _video_resnet(
  362. BasicBlock,
  363. [Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item]
  364. [2, 2, 2, 2],
  365. BasicStem,
  366. weights,
  367. progress,
  368. **kwargs,
  369. )
  370. @register_model()
  371. @handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1))
  372. def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
  373. """Construct 18 layer deep R(2+1)D network as in
  374. .. betastatus:: video module
  375. Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.
  376. Args:
  377. weights (:class:`~torchvision.models.video.R2Plus1D_18_Weights`, optional): The
  378. pretrained weights to use. See
  379. :class:`~torchvision.models.video.R2Plus1D_18_Weights`
  380. below for more details, and possible values. By default, no
  381. pre-trained weights are used.
  382. progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
  383. **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
  384. Please refer to the `source code
  385. <https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
  386. for more details about this class.
  387. .. autoclass:: torchvision.models.video.R2Plus1D_18_Weights
  388. :members:
  389. """
  390. weights = R2Plus1D_18_Weights.verify(weights)
  391. return _video_resnet(
  392. BasicBlock,
  393. [Conv2Plus1D] * 4,
  394. [2, 2, 2, 2],
  395. R2Plus1dStem,
  396. weights,
  397. progress,
  398. **kwargs,
  399. )
  400. # The dictionary below is internal implementation detail and will be removed in v0.15
  401. from .._utils import _ModelURLs
  402. model_urls = _ModelURLs(
  403. {
  404. "r3d_18": R3D_18_Weights.KINETICS400_V1.url,
  405. "mc3_18": MC3_18_Weights.KINETICS400_V1.url,
  406. "r2plus1d_18": R2Plus1D_18_Weights.KINETICS400_V1.url,
  407. }
  408. )