# keypoint_rcnn.py

from typing import Any, Optional

import torch
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from ...ops import misc as misc_nn_ops
from ...transforms._presets import ObjectDetection
from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..resnet import resnet50, ResNet50_Weights
from ._utils import overwrite_eps
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from .faster_rcnn import FasterRCNN


__all__ = [
    "KeypointRCNN",
    "KeypointRCNN_ResNet50_FPN_Weights",
    "keypointrcnn_resnet50_fpn",
]


class KeypointRCNN(FasterRCNN):
    """
    Implements Keypoint R-CNN.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box
        - keypoints (FloatTensor[N, K, 3]): the K keypoints location for each of the N instances, in the
          format [x, y, visibility], where visibility=0 means that the keypoint is not visible.

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the keypoint loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores of each prediction
        - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
            If box_predictor is specified, num_classes should be None.
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        rpn_score_thresh (float): during inference, only return proposals with a classification score
            greater than rpn_score_thresh
        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes
        box_head (nn.Module): module that takes the cropped feature maps as input
        box_predictor (nn.Module): module that takes the output of box_head and returns the
            classification logits and box regression deltas.
        box_score_thresh (float): during inference, only return proposals with a classification score
            greater than box_score_thresh
        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
        box_detections_per_img (int): maximum number of detections per image, for all classes.
        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
            considered as positive during training of the classification head
        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
            considered as negative during training of the classification head
        box_batch_size_per_image (int): number of proposals that are sampled during training of the
            classification head
        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
            of the classification head
        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
            bounding boxes
        keypoint_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes, which will be used for the keypoint head.
        keypoint_head (nn.Module): module that takes the cropped feature maps as input
        keypoint_predictor (nn.Module): module that takes the output of the keypoint_head and returns the
            heatmap logits

    Example::

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models import MobileNet_V2_Weights
        >>> from torchvision.models.detection import KeypointRCNN
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>>
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # KeypointRCNN needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the RPN generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
        >>>
        >>> # let's define what are the feature maps that we will
        >>> # use to perform the region of interest cropping, as well as
        >>> # the size of the crop after rescaling.
        >>> # if your backbone returns a Tensor, featmap_names is expected to
        >>> # be ['0']. More generally, the backbone should return an
        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
        >>> # feature maps to use.
        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                 output_size=7,
        >>>                                                 sampling_ratio=2)
        >>>
        >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                          output_size=14,
        >>>                                                          sampling_ratio=2)
        >>> # put the pieces together inside a KeypointRCNN model
        >>> model = KeypointRCNN(backbone,
        >>>                      num_classes=2,
        >>>                      rpn_anchor_generator=anchor_generator,
        >>>                      box_roi_pool=roi_pooler,
        >>>                      keypoint_roi_pool=keypoint_roi_pooler)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
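        >>>
        >>> # In training mode, targets must also be passed. The values below are
        >>> # illustrative placeholders only (a single box with 17 visible keypoints),
        >>> # not real annotations.
        >>> images = [torch.rand(3, 300, 400)]
        >>> targets = [{
        >>>     "boxes": torch.tensor([[50., 60., 200., 300.]]),
        >>>     "labels": torch.tensor([1], dtype=torch.int64),
        >>>     "keypoints": torch.tensor([[[100., 150., 1.]] * 17]),
        >>> }]
        >>> model.train()
        >>> loss_dict = model(images, targets)  # losses for the RPN, box head and keypoint head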
    """

    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=None,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        # keypoint parameters
        keypoint_roi_pool=None,
        keypoint_head=None,
        keypoint_predictor=None,
        num_keypoints=None,
        **kwargs,
    ):
        if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
            )
        if min_size is None:
            min_size = (640, 672, 704, 736, 768, 800)
        if num_keypoints is not None:
            if keypoint_predictor is not None:
                raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
        else:
            num_keypoints = 17
        out_channels = backbone.out_channels
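        # Defaults for the keypoint branch: 14x14 RoIAlign over the feature maps, a stack
        # of eight 3x3/512 convolutions, and a deconvolutional predictor that produces one
        # heatmap per keypoint.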
        if keypoint_roi_pool is None:
            keypoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)

        if keypoint_head is None:
            keypoint_layers = tuple(512 for _ in range(8))
            keypoint_head = KeypointRCNNHeads(out_channels, keypoint_layers)

        if keypoint_predictor is None:
            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
            keypoint_predictor = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints)
        super().__init__(
            backbone,
            num_classes,
            # transform parameters
            min_size,
            max_size,
            image_mean,
            image_std,
            # RPN-specific parameters
            rpn_anchor_generator,
            rpn_head,
            rpn_pre_nms_top_n_train,
            rpn_pre_nms_top_n_test,
            rpn_post_nms_top_n_train,
            rpn_post_nms_top_n_test,
            rpn_nms_thresh,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_score_thresh,
            # Box parameters
            box_roi_pool,
            box_head,
            box_predictor,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            **kwargs,
        )

        self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
        self.roi_heads.keypoint_head = keypoint_head
        self.roi_heads.keypoint_predictor = keypoint_predictor


class KeypointRCNNHeads(nn.Sequential):
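    """Default keypoint head: a stack of 3x3 convolutions, each followed by a ReLU, one block per entry in ``layers``."""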
    def __init__(self, in_channels, layers):
        d = []
        next_feature = in_channels
        for out_channels in layers:
            d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
            d.append(nn.ReLU(inplace=True))
            next_feature = out_channels
        super().__init__(*d)
        for m in self.children():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                nn.init.constant_(m.bias, 0)


class KeypointRCNNPredictor(nn.Module):
    def __init__(self, in_channels, num_keypoints):
        super().__init__()
        input_features = in_channels
        deconv_kernel = 4
        self.kps_score_lowres = nn.ConvTranspose2d(
            input_features,
            num_keypoints,
            deconv_kernel,
            stride=2,
            padding=deconv_kernel // 2 - 1,
        )
        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
        nn.init.constant_(self.kps_score_lowres.bias, 0)
        self.up_scale = 2
        self.out_channels = num_keypoints

    def forward(self, x):
        x = self.kps_score_lowres(x)
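        # Shape sketch (assuming the default 14x14 keypoint RoI features): the stride-2
        # deconv above maps [N, C, 14, 14] to [N, num_keypoints, 28, 28]; the bilinear
        # upsample below doubles that to [N, num_keypoints, 56, 56].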
        return torch.nn.functional.interpolate(
            x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
        )


_COMMON_META = {
    "categories": _COCO_PERSON_CATEGORIES,
    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
    "min_size": (1, 1),
}


class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
    COCO_LEGACY = Weights(
        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 59137258,
            "recipe": "https://github.com/pytorch/vision/issues/1606",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 50.6,
                    "kp_map": 61.1,
                }
            },
            "_ops": 133.924,
            "_file_size": 226.054,
            "_docs": """
                These weights were produced by following a similar training recipe as on the paper but use a checkpoint
                from an early epoch.
            """,
        },
    )
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 59137258,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 54.6,
                    "kp_map": 65.0,
                }
            },
            "_ops": 137.42,
            "_file_size": 226.054,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


@register_model()
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY
        if kwargs["pretrained"] == "legacy"
        else KeypointRCNN_ResNet50_FPN_Weights.COCO_V1,
    ),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def keypointrcnn_resnet50_fpn(
    *,
    weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    num_keypoints: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> KeypointRCNN:
    """
    Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.

    .. betastatus:: detection module

    Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the keypoint loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detected instances:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the predicted labels for each instance
        - scores (``Tensor[N]``): the scores of each instance
        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    Keypoint R-CNN is exportable to ONNX for a fixed batch size with input images of fixed size.

    Example::

        >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
        >>>
        >>> # optionally, if you want to export the model to ONNX:
        >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version=11)

    Args:
        weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): number of output classes of the model (including the background)
        num_keypoints (int, optional): number of keypoints
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            the final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.

    .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
        :members:
    """
    weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)
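
    # If pretrained weights are given, num_classes and num_keypoints are taken from the
    # weights metadata (2 person categories, 17 COCO keypoints), and conflicting explicit
    # values raise an error; without weights, the same COCO defaults fill in None values.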
    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
        num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
    else:
        if num_classes is None:
            num_classes = 2
        if num_keypoints is None:
            num_keypoints = 17

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model
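

if __name__ == "__main__":
    # Minimal smoke test, not part of the upstream torchvision module: builds the
    # ResNet-50-FPN variant without pretrained weights (required when using a custom
    # number of keypoints) and runs a forward pass on random images. Because this file
    # uses relative imports, run it as
    #   python -m torchvision.models.detection.keypoint_rcnn
    model = keypointrcnn_resnet50_fpn(
        weights=None, weights_backbone=None, num_classes=2, num_keypoints=5
    )
    model.eval()
    images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
    with torch.no_grad():
        predictions = model(images)
    # Each prediction dict holds boxes, labels, scores, keypoints (shape [N, 5, 3]) and keypoint scores.
    print({k: tuple(v.shape) for k, v in predictions[0].items()})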