from typing import Dict, List, Optional, Tuple

import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.ops import boxes as box_ops, Conv2dNormActivation

from . import _utils as det_utils

# Import AnchorGenerator to keep compatibility.
from .anchor_utils import AnchorGenerator  # noqa: 401
from .image_list import ImageList


class RPNHead(nn.Module):
    """
    Adds a simple RPN Head with classification and regression heads

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        conv_depth (int, optional): number of convolutions
    """

    _version = 2

    def __init__(self, in_channels: int, num_anchors: int, conv_depth=1) -> None:
        super().__init__()
        convs = []
        for _ in range(conv_depth):
            convs.append(Conv2dNormActivation(in_channels, in_channels, kernel_size=3, norm_layer=None))
        self.conv = nn.Sequential(*convs)
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
        self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)

        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
                if layer.bias is not None:
                    torch.nn.init.constant_(layer.bias, 0)  # type: ignore[arg-type]

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            for type in ["weight", "bias"]:
                old_key = f"{prefix}conv.{type}"
                new_key = f"{prefix}conv.0.0.{type}"
                if old_key in state_dict:
                    state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        logits = []
        bbox_reg = []
        for feature in x:
            t = self.conv(feature)
            logits.append(self.cls_logits(t))
            bbox_reg.append(self.bbox_pred(t))
        return logits, bbox_reg
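

# Shape sketch (illustrative, not part of the original module): RPNHead is applied to
# each feature map independently. For one feature map of shape (N, in_channels, H, W)
# and num_anchors anchors per location, it returns objectness logits of shape
# (N, num_anchors, H, W) and box deltas of shape (N, 4 * num_anchors, H, W):
#
#     head = RPNHead(in_channels=256, num_anchors=3)
#     logits, bbox_reg = head([torch.rand(2, 256, 32, 32)])
#     # logits[0].shape   -> torch.Size([2, 3, 32, 32])
#     # bbox_reg[0].shape -> torch.Size([2, 12, 32, 32])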


def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int, W: int) -> Tensor:
    layer = layer.view(N, -1, C, H, W)
    layer = layer.permute(0, 3, 4, 1, 2)
    layer = layer.reshape(N, -1, C)
    return layer
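

# Shape sketch (illustrative): permute_and_flatten turns a per-level prediction of shape
# (N, A * C, H, W) into (N, H * W * A, C), so that the anchor dimension lines up with the
# order in which anchors are generated for each level, e.g.
#
#     permute_and_flatten(torch.rand(2, 12, 32, 32), N=2, A=3, C=4, H=32, W=32).shape
#     # -> torch.Size([2, 3072, 4])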


def concat_box_prediction_layers(box_cls: List[Tensor], box_regression: List[Tensor]) -> Tuple[Tensor, Tensor]:
    box_cls_flattened = []
    box_regression_flattened = []
    # for each feature level, permute the outputs to make them be in the
    # same format as the labels. Note that the labels are computed for
    # all feature levels concatenated, so we keep the same representation
    # for the objectness and the box_regression
    for box_cls_per_level, box_regression_per_level in zip(box_cls, box_regression):
        N, AxC, H, W = box_cls_per_level.shape
        Ax4 = box_regression_per_level.shape[1]
        A = Ax4 // 4
        C = AxC // A
        box_cls_per_level = permute_and_flatten(box_cls_per_level, N, A, C, H, W)
        box_cls_flattened.append(box_cls_per_level)

        box_regression_per_level = permute_and_flatten(box_regression_per_level, N, A, 4, H, W)
        box_regression_flattened.append(box_regression_per_level)
    # concatenate on the first dimension (representing the feature levels), to
    # take into account the way the labels were generated (with all feature maps
    # being concatenated as well)
    box_cls = torch.cat(box_cls_flattened, dim=1).flatten(0, -2)
    box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4)
    return box_cls, box_regression
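

# Shape sketch (illustrative): with two feature levels producing objectness logits of
# shapes (2, 3, 32, 32) and (2, 3, 16, 16) and box deltas of shapes (2, 12, 32, 32)
# and (2, 12, 16, 16), concat_box_prediction_layers returns
#
#     box_cls        of shape (2 * (32*32*3 + 16*16*3), 1) == (7680, 1)
#     box_regression of shape (7680, 4)
#
# i.e. one row per anchor, over all images and all feature levels.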


class RegionProposalNetwork(torch.nn.Module):
    """
    Implements Region Proposal Network (RPN).

    Args:
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): module that computes the objectness and regression deltas
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS. It should
            contain two fields: training and testing, to allow for different values depending
            on training or evaluation
        post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS. It should
            contain two fields: training and testing, to allow for different values depending
            on training or evaluation
        nms_thresh (float): NMS threshold used for postprocessing the RPN proposals

    A minimal construction sketch is included at the bottom of this file.
    """

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
    }

    def __init__(
        self,
        anchor_generator: AnchorGenerator,
        head: nn.Module,
        # Faster-RCNN Training
        fg_iou_thresh: float,
        bg_iou_thresh: float,
        batch_size_per_image: int,
        positive_fraction: float,
        # Faster-RCNN Inference
        pre_nms_top_n: Dict[str, int],
        post_nms_top_n: Dict[str, int],
        nms_thresh: float,
        score_thresh: float = 0.0,
    ) -> None:
        super().__init__()
        self.anchor_generator = anchor_generator
        self.head = head
        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        # used during training
        self.box_similarity = box_ops.box_iou

        self.proposal_matcher = det_utils.Matcher(
            fg_iou_thresh,
            bg_iou_thresh,
            allow_low_quality_matches=True,
        )

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)

        # used during testing
        self._pre_nms_top_n = pre_nms_top_n
        self._post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.score_thresh = score_thresh
        self.min_size = 1e-3

    def pre_nms_top_n(self) -> int:
        if self.training:
            return self._pre_nms_top_n["training"]
        return self._pre_nms_top_n["testing"]

    def post_nms_top_n(self) -> int:
        if self.training:
            return self._post_nms_top_n["training"]
        return self._post_nms_top_n["testing"]

    def assign_targets_to_anchors(
        self, anchors: List[Tensor], targets: List[Dict[str, Tensor]]
    ) -> Tuple[List[Tensor], List[Tensor]]:
        labels = []
        matched_gt_boxes = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            gt_boxes = targets_per_image["boxes"]

            if gt_boxes.numel() == 0:
                # Background image (negative example)
                device = anchors_per_image.device
                matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
                labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
            else:
                match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
                matched_idxs = self.proposal_matcher(match_quality_matrix)
                # get the targets corresponding GT for each proposal
                # NB: need to clamp the indices because we can have a single
                # GT in the image, and matched_idxs can be -2, which goes
                # out of bounds
                matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]

                labels_per_image = matched_idxs >= 0
                labels_per_image = labels_per_image.to(dtype=torch.float32)

                # Background (negative examples)
                bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_per_image[bg_indices] = 0.0

                # discard indices that are between thresholds
                inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_per_image[inds_to_discard] = -1.0

            labels.append(labels_per_image)
            matched_gt_boxes.append(matched_gt_boxes_per_image)
        return labels, matched_gt_boxes
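
    # Label convention sketch (illustrative): labels_per_image holds 1.0 for anchors matched
    # to a GT box (IoU >= fg_iou_thresh, plus low-quality matches since the Matcher is built
    # with allow_low_quality_matches=True), 0.0 for background anchors (IoU < bg_iou_thresh),
    # and -1.0 for anchors between the two thresholds, which the fg/bg sampler ignores.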

    def _get_top_n_idx(self, objectness: Tensor, num_anchors_per_level: List[int]) -> Tensor:
        r = []
        offset = 0
        for ob in objectness.split(num_anchors_per_level, 1):
            num_anchors = ob.shape[1]
            pre_nms_top_n = det_utils._topk_min(ob, self.pre_nms_top_n(), 1)
            _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
            r.append(top_n_idx + offset)
            offset += num_anchors
        return torch.cat(r, dim=1)

    def filter_proposals(
        self,
        proposals: Tensor,
        objectness: Tensor,
        image_shapes: List[Tuple[int, int]],
        num_anchors_per_level: List[int],
    ) -> Tuple[List[Tensor], List[Tensor]]:
        num_images = proposals.shape[0]
        device = proposals.device
        # do not backprop through objectness
        objectness = objectness.detach()
        objectness = objectness.reshape(num_images, -1)

        levels = [
            torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level)
        ]
        levels = torch.cat(levels, 0)
        levels = levels.reshape(1, -1).expand_as(objectness)

        # select top_n boxes independently per level before applying nms
        top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)

        image_range = torch.arange(num_images, device=device)
        batch_idx = image_range[:, None]

        objectness = objectness[batch_idx, top_n_idx]
        levels = levels[batch_idx, top_n_idx]
        proposals = proposals[batch_idx, top_n_idx]

        objectness_prob = torch.sigmoid(objectness)

        final_boxes = []
        final_scores = []
        for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, img_shape)

            # remove small boxes
            keep = box_ops.remove_small_boxes(boxes, self.min_size)
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # remove low scoring boxes
            # use >= for Backwards compatibility
            keep = torch.where(scores >= self.score_thresh)[0]
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # non-maximum suppression, independently done per level
            keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)

            # keep only topk scoring predictions
            keep = keep[: self.post_nms_top_n()]
            boxes, scores = boxes[keep], scores[keep]

            final_boxes.append(boxes)
            final_scores.append(scores)
        return final_boxes, final_scores
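
    # Post-processing sketch (illustrative): for each image, proposals are clipped to the image,
    # boxes smaller than min_size or scoring below score_thresh are dropped, NMS is run per
    # feature level (batched_nms keyed on lvl), and only the post_nms_top_n() highest-scoring
    # proposals are kept.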

    def compute_loss(
        self, objectness: Tensor, pred_bbox_deltas: Tensor, labels: List[Tensor], regression_targets: List[Tensor]
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            objectness (Tensor)
            pred_bbox_deltas (Tensor)
            labels (List[Tensor])
            regression_targets (List[Tensor])

        Returns:
            objectness_loss (Tensor)
            box_loss (Tensor)
        """

        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0]
        sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0]

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness = objectness.flatten()

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

        box_loss = F.smooth_l1_loss(
            pred_bbox_deltas[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            beta=1 / 9,
            reduction="sum",
        ) / (sampled_inds.numel())

        objectness_loss = F.binary_cross_entropy_with_logits(objectness[sampled_inds], labels[sampled_inds])

        return objectness_loss, box_loss
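
    # Loss sketch (illustrative): with batch_size_per_image=256 and positive_fraction=0.5,
    # the fg/bg sampler keeps at most 128 positive anchors per image and fills the rest of
    # the 256 samples with negatives. The smooth L1 box loss is summed over positive samples
    # only and divided by the total number of sampled anchors, while the BCE objectness loss
    # is averaged over all sampled anchors.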

    def forward(
        self,
        images: ImageList,
        features: Dict[str, Tensor],
        targets: Optional[List[Dict[str, Tensor]]] = None,
    ) -> Tuple[List[Tensor], Dict[str, Tensor]]:
        """
        Args:
            images (ImageList): images for which we want to compute the predictions
            features (Dict[str, Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the dict
                corresponds to a different feature level
            targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image (optional).
                If provided, each element in the dict should contain a field `boxes`,
                with the locations of the ground-truth boxes.

        Returns:
            boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
                image.
            losses (Dict[str, Tensor]): the losses for the model during training. During
                testing, it is an empty dict.
        """
        # RPN uses all feature maps that are available
        features = list(features.values())
        objectness, pred_bbox_deltas = self.head(features)
        anchors = self.anchor_generator(images, features)

        num_images = len(anchors)
        num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
        num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
        objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)
        # apply pred_bbox_deltas to anchors to obtain the decoded proposals
        # note that we detach the deltas because Faster R-CNN does not backprop through
        # the proposals
        proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
        proposals = proposals.view(num_images, -1, 4)
        boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)

        losses = {}
        if self.training:
            if targets is None:
                raise ValueError("targets should not be None")
            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
            loss_objectness, loss_rpn_box_reg = self.compute_loss(
                objectness, pred_bbox_deltas, labels, regression_targets
            )
            losses = {
                "loss_objectness": loss_objectness,
                "loss_rpn_box_reg": loss_rpn_box_reg,
            }
        return boxes, losses
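

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of torchvision): builds an RPN
    # over two FPN-style feature maps and runs one inference forward pass. Channel counts,
    # anchor sizes and the top-n/NMS settings below are assumptions, not the values used by
    # any particular torchvision model. Because this module uses relative imports, run it as
    # part of the package (python -m ...) rather than as a standalone script.
    images = ImageList(torch.rand(2, 3, 224, 224), [(224, 224), (224, 224)])
    features = {"0": torch.rand(2, 256, 28, 28), "1": torch.rand(2, 256, 14, 14)}

    anchor_generator = AnchorGenerator(sizes=((32,), (64,)), aspect_ratios=((0.5, 1.0, 2.0),) * 2)
    head = RPNHead(in_channels=256, num_anchors=anchor_generator.num_anchors_per_location()[0])

    rpn = RegionProposalNetwork(
        anchor_generator,
        head,
        fg_iou_thresh=0.7,
        bg_iou_thresh=0.3,
        batch_size_per_image=256,
        positive_fraction=0.5,
        pre_nms_top_n={"training": 2000, "testing": 1000},
        post_nms_top_n={"training": 2000, "testing": 1000},
        nms_thresh=0.7,
    )
    rpn.eval()
    with torch.no_grad():
        boxes, losses = rpn(images, features)
    # boxes is a list with one (num_proposals, 4) tensor per image; losses is {} in eval mode.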