fcos.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771
  1. import math
  2. import warnings
  3. from collections import OrderedDict
  4. from functools import partial
  5. from typing import Any, Callable, Dict, List, Optional, Tuple
  6. import torch
  7. from torch import nn, Tensor
  8. from ...ops import boxes as box_ops, generalized_box_iou_loss, misc as misc_nn_ops, sigmoid_focal_loss
  9. from ...ops.feature_pyramid_network import LastLevelP6P7
  10. from ...transforms._presets import ObjectDetection
  11. from ...utils import _log_api_usage_once
  12. from .._api import register_model, Weights, WeightsEnum
  13. from .._meta import _COCO_CATEGORIES
  14. from .._utils import _ovewrite_value_param, handle_legacy_interface
  15. from ..resnet import resnet50, ResNet50_Weights
  16. from . import _utils as det_utils
  17. from .anchor_utils import AnchorGenerator
  18. from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
  19. from .transform import GeneralizedRCNNTransform
  20. __all__ = [
  21. "FCOS",
  22. "FCOS_ResNet50_FPN_Weights",
  23. "fcos_resnet50_fpn",
  24. ]
  25. class FCOSHead(nn.Module):
  26. """
  27. A regression and classification head for use in FCOS.
  28. Args:
  29. in_channels (int): number of channels of the input feature
  30. num_anchors (int): number of anchors to be predicted
  31. num_classes (int): number of classes to be predicted
  32. num_convs (Optional[int]): number of conv layer of head. Default: 4.
  33. """
  34. __annotations__ = {
  35. "box_coder": det_utils.BoxLinearCoder,
  36. }
  37. def __init__(self, in_channels: int, num_anchors: int, num_classes: int, num_convs: Optional[int] = 4) -> None:
  38. super().__init__()
  39. self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)
  40. self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes, num_convs)
  41. self.regression_head = FCOSRegressionHead(in_channels, num_anchors, num_convs)
  42. def compute_loss(
  43. self,
  44. targets: List[Dict[str, Tensor]],
  45. head_outputs: Dict[str, Tensor],
  46. anchors: List[Tensor],
  47. matched_idxs: List[Tensor],
  48. ) -> Dict[str, Tensor]:
  49. cls_logits = head_outputs["cls_logits"] # [N, HWA, C]
  50. bbox_regression = head_outputs["bbox_regression"] # [N, HWA, 4]
  51. bbox_ctrness = head_outputs["bbox_ctrness"] # [N, HWA, 1]
  52. all_gt_classes_targets = []
  53. all_gt_boxes_targets = []
  54. for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs):
  55. if len(targets_per_image["labels"]) == 0:
  56. gt_classes_targets = targets_per_image["labels"].new_zeros((len(matched_idxs_per_image),))
  57. gt_boxes_targets = targets_per_image["boxes"].new_zeros((len(matched_idxs_per_image), 4))
  58. else:
  59. gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
  60. gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
  61. gt_classes_targets[matched_idxs_per_image < 0] = -1 # background
  62. all_gt_classes_targets.append(gt_classes_targets)
  63. all_gt_boxes_targets.append(gt_boxes_targets)
  64. # List[Tensor] to Tensor conversion of `all_gt_boxes_target`, `all_gt_classes_targets` and `anchors`
  65. all_gt_boxes_targets, all_gt_classes_targets, anchors = (
  66. torch.stack(all_gt_boxes_targets),
  67. torch.stack(all_gt_classes_targets),
  68. torch.stack(anchors),
  69. )
  70. # compute foregroud
  71. foregroud_mask = all_gt_classes_targets >= 0
  72. num_foreground = foregroud_mask.sum().item()
  73. # classification loss
  74. gt_classes_targets = torch.zeros_like(cls_logits)
  75. gt_classes_targets[foregroud_mask, all_gt_classes_targets[foregroud_mask]] = 1.0
  76. loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum")
  77. # amp issue: pred_boxes need to convert float
  78. pred_boxes = self.box_coder.decode(bbox_regression, anchors)
  79. # regression loss: GIoU loss
  80. loss_bbox_reg = generalized_box_iou_loss(
  81. pred_boxes[foregroud_mask],
  82. all_gt_boxes_targets[foregroud_mask],
  83. reduction="sum",
  84. )
  85. # ctrness loss
  86. bbox_reg_targets = self.box_coder.encode(anchors, all_gt_boxes_targets)
  87. if len(bbox_reg_targets) == 0:
  88. gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1])
  89. else:
  90. left_right = bbox_reg_targets[:, :, [0, 2]]
  91. top_bottom = bbox_reg_targets[:, :, [1, 3]]
  92. gt_ctrness_targets = torch.sqrt(
  93. (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
  94. * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
  95. )
  96. pred_centerness = bbox_ctrness.squeeze(dim=2)
  97. loss_bbox_ctrness = nn.functional.binary_cross_entropy_with_logits(
  98. pred_centerness[foregroud_mask], gt_ctrness_targets[foregroud_mask], reduction="sum"
  99. )
  100. return {
  101. "classification": loss_cls / max(1, num_foreground),
  102. "bbox_regression": loss_bbox_reg / max(1, num_foreground),
  103. "bbox_ctrness": loss_bbox_ctrness / max(1, num_foreground),
  104. }
  105. def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
  106. cls_logits = self.classification_head(x)
  107. bbox_regression, bbox_ctrness = self.regression_head(x)
  108. return {
  109. "cls_logits": cls_logits,
  110. "bbox_regression": bbox_regression,
  111. "bbox_ctrness": bbox_ctrness,
  112. }
  113. class FCOSClassificationHead(nn.Module):
  114. """
  115. A classification head for use in FCOS.
  116. Args:
  117. in_channels (int): number of channels of the input feature.
  118. num_anchors (int): number of anchors to be predicted.
  119. num_classes (int): number of classes to be predicted.
  120. num_convs (Optional[int]): number of conv layer. Default: 4.
  121. prior_probability (Optional[float]): probability of prior. Default: 0.01.
  122. norm_layer: Module specifying the normalization layer to use.
  123. """
  124. def __init__(
  125. self,
  126. in_channels: int,
  127. num_anchors: int,
  128. num_classes: int,
  129. num_convs: int = 4,
  130. prior_probability: float = 0.01,
  131. norm_layer: Optional[Callable[..., nn.Module]] = None,
  132. ) -> None:
  133. super().__init__()
  134. self.num_classes = num_classes
  135. self.num_anchors = num_anchors
  136. if norm_layer is None:
  137. norm_layer = partial(nn.GroupNorm, 32)
  138. conv = []
  139. for _ in range(num_convs):
  140. conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
  141. conv.append(norm_layer(in_channels))
  142. conv.append(nn.ReLU())
  143. self.conv = nn.Sequential(*conv)
  144. for layer in self.conv.children():
  145. if isinstance(layer, nn.Conv2d):
  146. torch.nn.init.normal_(layer.weight, std=0.01)
  147. torch.nn.init.constant_(layer.bias, 0)
  148. self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
  149. torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
  150. torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))
  151. def forward(self, x: List[Tensor]) -> Tensor:
  152. all_cls_logits = []
  153. for features in x:
  154. cls_logits = self.conv(features)
  155. cls_logits = self.cls_logits(cls_logits)
  156. # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
  157. N, _, H, W = cls_logits.shape
  158. cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
  159. cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
  160. cls_logits = cls_logits.reshape(N, -1, self.num_classes) # Size=(N, HWA, 4)
  161. all_cls_logits.append(cls_logits)
  162. return torch.cat(all_cls_logits, dim=1)
  163. class FCOSRegressionHead(nn.Module):
  164. """
  165. A regression head for use in FCOS, which combines regression branch and center-ness branch.
  166. This can obtain better performance.
  167. Reference: `FCOS: A simple and strong anchor-free object detector <https://arxiv.org/abs/2006.09214>`_.
  168. Args:
  169. in_channels (int): number of channels of the input feature
  170. num_anchors (int): number of anchors to be predicted
  171. num_convs (Optional[int]): number of conv layer. Default: 4.
  172. norm_layer: Module specifying the normalization layer to use.
  173. """
  174. def __init__(
  175. self,
  176. in_channels: int,
  177. num_anchors: int,
  178. num_convs: int = 4,
  179. norm_layer: Optional[Callable[..., nn.Module]] = None,
  180. ):
  181. super().__init__()
  182. if norm_layer is None:
  183. norm_layer = partial(nn.GroupNorm, 32)
  184. conv = []
  185. for _ in range(num_convs):
  186. conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
  187. conv.append(norm_layer(in_channels))
  188. conv.append(nn.ReLU())
  189. self.conv = nn.Sequential(*conv)
  190. self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
  191. self.bbox_ctrness = nn.Conv2d(in_channels, num_anchors * 1, kernel_size=3, stride=1, padding=1)
  192. for layer in [self.bbox_reg, self.bbox_ctrness]:
  193. torch.nn.init.normal_(layer.weight, std=0.01)
  194. torch.nn.init.zeros_(layer.bias)
  195. for layer in self.conv.children():
  196. if isinstance(layer, nn.Conv2d):
  197. torch.nn.init.normal_(layer.weight, std=0.01)
  198. torch.nn.init.zeros_(layer.bias)
  199. def forward(self, x: List[Tensor]) -> Tuple[Tensor, Tensor]:
  200. all_bbox_regression = []
  201. all_bbox_ctrness = []
  202. for features in x:
  203. bbox_feature = self.conv(features)
  204. bbox_regression = nn.functional.relu(self.bbox_reg(bbox_feature))
  205. bbox_ctrness = self.bbox_ctrness(bbox_feature)
  206. # permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
  207. N, _, H, W = bbox_regression.shape
  208. bbox_regression = bbox_regression.view(N, -1, 4, H, W)
  209. bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
  210. bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4)
  211. all_bbox_regression.append(bbox_regression)
  212. # permute bbox ctrness output from (N, 1 * A, H, W) to (N, HWA, 1).
  213. bbox_ctrness = bbox_ctrness.view(N, -1, 1, H, W)
  214. bbox_ctrness = bbox_ctrness.permute(0, 3, 4, 1, 2)
  215. bbox_ctrness = bbox_ctrness.reshape(N, -1, 1)
  216. all_bbox_ctrness.append(bbox_ctrness)
  217. return torch.cat(all_bbox_regression, dim=1), torch.cat(all_bbox_ctrness, dim=1)
class FCOS(nn.Module):
    """
    Implements FCOS.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification, regression
    and centerness losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction

    Under TorchScript, ``forward`` always returns a ``(losses, detections)`` tuple instead.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
            on. Default: ``[0.485, 0.456, 0.406]``.
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on.
            Default: ``[0.229, 0.224, 0.225]``.
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps. For FCOS, only set one anchor for per position of each level, the width and height equal to
            the stride of feature map, and set aspect ratio = 1.0, so the center of anchor is equivalent to the point
            in FCOS paper. Exactly one anchor per location is required (checked at construction).
            Default: square anchors of sizes 8, 16, 32, 64 and 128 for five levels.
        head (nn.Module): Module run on top of the feature pyramid.
            Defaults to a module containing a classification and regression module.
        center_sampling_radius (float): radius of the "center" of a groundtruth box, in units of the
            anchor size: an anchor point is a positive candidate only if its center lies within this
            distance of the box center along both axes. Default: 1.5.
        score_thresh (float): Score threshold used for postprocessing the detections. The score is
            ``sqrt(sigmoid(class_logit) * sigmoid(centerness_logit))``. Default: 0.2.
        nms_thresh (float): NMS threshold used for postprocessing the detections. Default: 0.6.
        detections_per_img (int): Number of best detections to keep after NMS. Default: 100.
        topk_candidates (int): Number of best detections to keep before NMS, per feature level.
            Default: 1000.
        **kwargs: forwarded to ``GeneralizedRCNNTransform``.

    Example:

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models import MobileNet_V2_Weights
        >>> from torchvision.models.detection import FCOS
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # FCOS needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # FCOS requires exactly one square anchor (aspect ratio 1.0) per spatial
        >>> # location; the anchor size plays the role of the feature-map stride.
        >>> # We have a Tuple[Tuple[int]] because each feature map could
        >>> # potentially use a different size.
        >>> anchor_generator = AnchorGenerator(
        >>>     sizes=((8,), (16,), (32,), (64,), (128,)),
        >>>     aspect_ratios=((1.0,),)
        >>> )
        >>>
        >>> # put the pieces together inside a FCOS model
        >>> model = FCOS(
        >>>     backbone,
        >>>     num_classes=80,
        >>>     anchor_generator=anchor_generator,
        >>> )
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    __annotations__ = {
        "box_coder": det_utils.BoxLinearCoder,
    }

    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        # transform parameters
        min_size: int = 800,
        max_size: int = 1333,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        # Anchor parameters
        anchor_generator: Optional[AnchorGenerator] = None,
        head: Optional[nn.Module] = None,
        center_sampling_radius: float = 1.5,
        score_thresh: float = 0.2,
        nms_thresh: float = 0.6,
        detections_per_img: int = 100,
        topk_candidates: int = 1000,
        **kwargs,
    ):
        super().__init__()
        _log_api_usage_once(self)

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )
        self.backbone = backbone

        if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"anchor_generator should be of type AnchorGenerator or None, instead got {type(anchor_generator)}"
            )

        if anchor_generator is None:
            anchor_sizes = ((8,), (16,), (32,), (64,), (128,))  # equal to strides of multi-level feature map
            aspect_ratios = ((1.0,),) * len(anchor_sizes)  # set only one anchor
            anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
        self.anchor_generator = anchor_generator
        # Each location must be represented by a single anchor (the FCOS "point").
        if self.anchor_generator.num_anchors_per_location()[0] != 1:
            raise ValueError(
                f"anchor_generator.num_anchors_per_location()[0] should be 1 instead of {anchor_generator.num_anchors_per_location()[0]}"
            )

        if head is None:
            head = FCOSHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
        self.head = head

        # Used by postprocess_detections to turn per-anchor regression outputs into boxes.
        self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        self.center_sampling_radius = center_sampling_radius
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img
        self.topk_candidates = topk_candidates

        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(
        self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]]
    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
        # Eager mode returns only the losses (training) or only the detections (eval).
        # NOTE(review): the return annotation describes the TorchScript tuple, not what is returned here.
        if self.training:
            return losses

        return detections

    def compute_loss(
        self,
        targets: List[Dict[str, Tensor]],
        head_outputs: Dict[str, Tensor],
        anchors: List[Tensor],
        num_anchors_per_level: List[int],
    ) -> Dict[str, Tensor]:
        """
        Assign every anchor to at most one ground-truth box and delegate to ``self.head.compute_loss``.

        An anchor (point) is a positive candidate for a GT box when all of the following hold:
          1. its center lies within ``center_sampling_radius * anchor_size`` of the box center along both axes;
          2. its center lies strictly inside the box;
          3. the largest of its four distances to the box sides is in ``(4 * anchor_size, 8 * anchor_size)``;
             the lower bound is dropped for the first ``num_anchors_per_level[0]`` anchors and the upper
             bound for the last ``num_anchors_per_level[-1]`` anchors (presumably the first / last pyramid
             level, assuming anchors are concatenated level by level).
        Among the candidates the GT box with the smallest area wins; anchors with no candidate get -1.

        Args:
            targets: per-image dicts with ``"boxes"`` (``[x1, y1, x2, y2]``) and ``"labels"``.
            head_outputs: the head's raw outputs, passed through unchanged.
            anchors: per-image anchor boxes of shape (num_anchors, 4) in ``[x1, y1, x2, y2]`` format.
            num_anchors_per_level: number of anchors contributed by each feature level.

        Returns:
            The loss dict produced by ``self.head.compute_loss``.
        """
        matched_idxs = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            if targets_per_image["boxes"].numel() == 0:
                # No ground truth: every anchor is background.
                matched_idxs.append(
                    torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
                )
                continue

            gt_boxes = targets_per_image["boxes"]
            gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:]) / 2  # (M, 2)
            anchor_centers = (anchors_per_image[:, :2] + anchors_per_image[:, 2:]) / 2  # (N, 2)
            anchor_sizes = anchors_per_image[:, 2] - anchors_per_image[:, 0]  # (N,) anchor widths
            # center sampling: anchor point must be close enough to gt center.
            # Chebyshev (max over x/y) distance between centers, compared per anchor -> (N, M) bool.
            pairwise_match = (anchor_centers[:, None, :] - gt_centers[None, :, :]).abs_().max(
                dim=2
            ).values < self.center_sampling_radius * anchor_sizes[:, None]

            # compute pairwise distance between N points and M boxes
            x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1)  # (N, 1)
            x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2)  # (1, M)
            pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2)  # (N, M, 4): left, top, right, bottom

            # anchor point must be inside gt
            pairwise_match &= pairwise_dist.min(dim=2).values > 0

            # each anchor is only responsible for certain scale range.
            lower_bound = anchor_sizes * 4
            lower_bound[: num_anchors_per_level[0]] = 0
            upper_bound = anchor_sizes * 8
            upper_bound[-num_anchors_per_level[-1] :] = float("inf")
            pairwise_dist = pairwise_dist.max(dim=2).values  # (N, M): largest side distance
            pairwise_match &= (pairwise_dist > lower_bound[:, None]) & (pairwise_dist < upper_bound[:, None])

            # match the GT box with minimum area, if there are multiple GT matches
            gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1])  # (M,)
            # Matches score (1e8 - area), non-matches 0, so the max picks the smallest-area match.
            # NOTE(review): assumes every GT area is below 1e8; larger boxes can never be matched.
            pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :])
            min_values, matched_idx = pairwise_match.max(dim=1)  # (N,) per-anchor best score and GT index
            matched_idx[min_values < 1e-5] = -1  # unmatched anchors are assigned -1

            matched_idxs.append(matched_idx)

        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)

    def postprocess_detections(
        self, head_outputs: Dict[str, List[Tensor]], anchors: List[List[Tensor]], image_shapes: List[Tuple[int, int]]
    ) -> List[Dict[str, Tensor]]:
        """
        Turn per-level head outputs into final detections for each image.

        Per image and per level: score every (anchor, class) pair as
        ``sqrt(sigmoid(cls_logit) * sigmoid(ctrness_logit))``, drop scores not above ``score_thresh``,
        keep at most ``topk_candidates`` of the rest, decode and clip their boxes. Levels are then
        merged, class-aware NMS (``batched_nms`` keyed by label) is applied, and at most
        ``detections_per_img`` results are kept.

        Args:
            head_outputs: ``cls_logits`` / ``bbox_regression`` / ``bbox_ctrness``, each already split
                per level; element ``[level][image]`` is indexed below.
            anchors: ``anchors[image][level]`` anchor boxes.
            image_shapes: (height, width) of each transformed image, used for clipping.
        """
        class_logits = head_outputs["cls_logits"]
        box_regression = head_outputs["bbox_regression"]
        box_ctrness = head_outputs["bbox_ctrness"]

        num_images = len(image_shapes)

        detections: List[Dict[str, Tensor]] = []

        for index in range(num_images):
            box_regression_per_image = [br[index] for br in box_regression]
            logits_per_image = [cl[index] for cl in class_logits]
            box_ctrness_per_image = [bc[index] for bc in box_ctrness]
            anchors_per_image, image_shape = anchors[index], image_shapes[index]

            image_boxes = []
            image_scores = []
            image_labels = []

            for box_regression_per_level, logits_per_level, box_ctrness_per_level, anchors_per_level in zip(
                box_regression_per_image, logits_per_image, box_ctrness_per_image, anchors_per_image
            ):
                num_classes = logits_per_level.shape[-1]

                # remove low scoring boxes
                # (HWA, C) * (HWA, 1) broadcasts to (HWA, C); flatten index = anchor_idx * C + class.
                scores_per_level = torch.sqrt(
                    torch.sigmoid(logits_per_level) * torch.sigmoid(box_ctrness_per_level)
                ).flatten()
                keep_idxs = scores_per_level > self.score_thresh
                scores_per_level = scores_per_level[keep_idxs]
                topk_idxs = torch.where(keep_idxs)[0]

                # keep only topk scoring predictions
                num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
                scores_per_level, idxs = scores_per_level.topk(num_topk)
                topk_idxs = topk_idxs[idxs]

                # Recover (anchor, class) from the flattened index.
                anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
                labels_per_level = topk_idxs % num_classes

                boxes_per_level = self.box_coder.decode(
                    box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
                )
                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)

                image_boxes.append(boxes_per_level)
                image_scores.append(scores_per_level)
                image_labels.append(labels_per_level)

            image_boxes = torch.cat(image_boxes, dim=0)
            image_scores = torch.cat(image_scores, dim=0)
            image_labels = torch.cat(image_labels, dim=0)

            # non-maximum suppression
            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
            keep = keep[: self.detections_per_img]

            detections.append(
                {
                    "boxes": image_boxes[keep],
                    "scores": image_scores[keep],
                    "labels": image_labels[keep],
                }
            )

        return detections

    def forward(
        self,
        images: List[Tensor],
        targets: Optional[List[Dict[str, Tensor]]] = None,
    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
        """
        Args:
            images (list[Tensor]): images to be processed, each of shape [C, H, W]
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (required in training mode)

        Returns:
            In eager mode: during training a dict of losses (``classification``, ``bbox_regression``,
            ``bbox_ctrness``); during evaluation a list with one dict per image holding ``boxes``,
            ``scores`` and ``labels``, rescaled to the original image sizes.
            Under TorchScript: always the tuple ``(losses, detections)``; the unused element is empty.
        """
        if self.training:

            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                for target in targets:
                    boxes = target["boxes"]
                    torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
                    torch._assert(
                        len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                        f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                    )

        # Remember the input sizes so detections can be mapped back after resizing.
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        # transform the input
        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    torch._assert(
                        False,
                        f"All bounding boxes should have positive height and width. Found invalid box {degen_bb} for target at index {target_idx}.",
                    )

        # get the features from the backbone
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])

        features = list(features.values())

        # compute the fcos heads outputs using the features
        head_outputs = self.head(features)

        # create the set of anchors
        anchors = self.anchor_generator(images, features)
        # recover level sizes (one anchor per location, so H * W anchors per level)
        num_anchors_per_level = [x.size(2) * x.size(3) for x in features]

        losses = {}
        detections: List[Dict[str, Tensor]] = []
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                # compute the losses
                losses = self.compute_loss(targets, head_outputs, anchors, num_anchors_per_level)
        else:
            # split outputs per level
            split_head_outputs: Dict[str, List[Tensor]] = {}
            for k in head_outputs:
                split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
            split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]

            # compute the detections
            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("FCOS always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        return self.eager_outputs(losses, detections)
class FCOS_ResNet50_FPN_Weights(WeightsEnum):
    # Pretrained checkpoints for ``fcos_resnet50_fpn``. Each member is a ``Weights`` record that bundles
    # the download URL, the inference transform and descriptive metadata.
    # NOTE: this is an enum body — every plain assignment here becomes a member (or an alias).
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth",
        transforms=ObjectDetection,
        meta={
            "num_params": 32269600,
            # fcos_resnet50_fpn sets num_classes to len(categories) when these weights are requested.
            "categories": _COCO_CATEGORIES,
            "min_size": (1, 1),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 39.2,
                }
            },
            # NOTE(review): units of _ops / _file_size are not stated here (presumably GFLOPs / MB) — confirm.
            "_ops": 128.207,
            "_file_size": 123.608,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    # Alias of COCO_V1 (same value, so the enum treats it as an alias rather than a new member).
    DEFAULT = COCO_V1
  562. @register_model()
  563. @handle_legacy_interface(
  564. weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1),
  565. weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
  566. )
  567. def fcos_resnet50_fpn(
  568. *,
  569. weights: Optional[FCOS_ResNet50_FPN_Weights] = None,
  570. progress: bool = True,
  571. num_classes: Optional[int] = None,
  572. weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
  573. trainable_backbone_layers: Optional[int] = None,
  574. **kwargs: Any,
  575. ) -> FCOS:
  576. """
  577. Constructs a FCOS model with a ResNet-50-FPN backbone.
  578. .. betastatus:: detection module
  579. Reference: `FCOS: Fully Convolutional One-Stage Object Detection <https://arxiv.org/abs/1904.01355>`_.
  580. `FCOS: A simple and strong anchor-free object detector <https://arxiv.org/abs/2006.09214>`_.
  581. The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
  582. image, and should be in ``0-1`` range. Different images can have different sizes.
  583. The behavior of the model changes depending on if it is in training or evaluation mode.
  584. During training, the model expects both the input tensors and targets (list of dictionary),
  585. containing:
  586. - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
  587. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  588. - labels (``Int64Tensor[N]``): the class label for each ground-truth box
  589. The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
  590. losses.
  591. During inference, the model requires only the input tensors, and returns the post-processed
  592. predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
  593. follows, where ``N`` is the number of detections:
  594. - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
  595. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  596. - labels (``Int64Tensor[N]``): the predicted labels for each detection
  597. - scores (``Tensor[N]``): the scores of each detection
  598. For more details on the output, you may refer to :ref:`instance_seg_output`.
  599. Example:
  600. >>> model = torchvision.models.detection.fcos_resnet50_fpn(weights=FCOS_ResNet50_FPN_Weights.DEFAULT)
  601. >>> model.eval()
  602. >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
  603. >>> predictions = model(x)
  604. Args:
  605. weights (:class:`~torchvision.models.detection.FCOS_ResNet50_FPN_Weights`, optional): The
  606. pretrained weights to use. See
  607. :class:`~torchvision.models.detection.FCOS_ResNet50_FPN_Weights`
  608. below for more details, and possible values. By default, no
  609. pre-trained weights are used.
  610. progress (bool): If True, displays a progress bar of the download to stderr
  611. num_classes (int, optional): number of output classes of the model (including the background)
  612. weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
  613. the backbone.
  614. trainable_backbone_layers (int, optional): number of trainable (not frozen) resnet layers starting
  615. from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
  616. trainable. If ``None`` is passed (the default) this value is set to 3. Default: None
  617. **kwargs: parameters passed to the ``torchvision.models.detection.FCOS``
  618. base class. Please refer to the `source code
  619. <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/fcos.py>`_
  620. for more details about this class.
  621. .. autoclass:: torchvision.models.detection.FCOS_ResNet50_FPN_Weights
  622. :members:
  623. """
  624. weights = FCOS_ResNet50_FPN_Weights.verify(weights)
  625. weights_backbone = ResNet50_Weights.verify(weights_backbone)
  626. if weights is not None:
  627. weights_backbone = None
  628. num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
  629. elif num_classes is None:
  630. num_classes = 91
  631. is_trained = weights is not None or weights_backbone is not None
  632. trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
  633. norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
  634. backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
  635. backbone = _resnet_fpn_extractor(
  636. backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
  637. )
  638. model = FCOS(backbone, num_classes, **kwargs)
  639. if weights is not None:
  640. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  641. return model