retinanet.py
import math
import warnings
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import nn, Tensor

from ...ops import boxes as box_ops, misc as misc_nn_ops, sigmoid_focal_loss
from ...ops.feature_pyramid_network import LastLevelP6P7
from ...transforms._presets import ObjectDetection
from ...utils import _log_api_usage_once
from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..resnet import resnet50, ResNet50_Weights
from . import _utils as det_utils
from ._utils import _box_loss, overwrite_eps
from .anchor_utils import AnchorGenerator
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from .transform import GeneralizedRCNNTransform


__all__ = [
    "RetinaNet",
    "RetinaNet_ResNet50_FPN_Weights",
    "RetinaNet_ResNet50_FPN_V2_Weights",
    "retinanet_resnet50_fpn",
    "retinanet_resnet50_fpn_v2",
]


def _sum(x: List[Tensor]) -> Tensor:
    res = x[0]
    for i in x[1:]:
        res = res + i
    return res


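# The v1 heads stored their four 3x3 convs (interleaved with ReLUs in the original
# implementation) in a plain nn.Sequential, so their parameters lived at indices
# 0, 2, 4 and 6. The v2 heads wrap each conv in Conv2dNormActivation, so the same
# parameters live at conv.{0..3}.0. For example, "classification_head.conv.2.weight"
# is remapped to "classification_head.conv.1.0.weight" when loading a v1 checkpoint.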
def _v1_to_v2_weights(state_dict, prefix):
    for i in range(4):
        for type in ["weight", "bias"]:
            old_key = f"{prefix}conv.{2*i}.{type}"
            new_key = f"{prefix}conv.{i}.0.{type}"
            if old_key in state_dict:
                state_dict[new_key] = state_dict.pop(old_key)


def _default_anchorgen():
    anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512])
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
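    # With these defaults every FPN level gets 3 scales x 3 aspect ratios = 9 anchors
    # per spatial location; the size tuples evaluate to roughly (32, 40, 50),
    # (64, 80, 101), (128, 161, 203), (256, 322, 406) and (512, 645, 812).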
    return anchor_generator


class RetinaNetHead(nn.Module):
    """
    A regression and classification head for use in RetinaNet.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    """

    def __init__(self, in_channels, num_anchors, num_classes, norm_layer: Optional[Callable[..., nn.Module]] = None):
        super().__init__()
        self.classification_head = RetinaNetClassificationHead(
            in_channels, num_anchors, num_classes, norm_layer=norm_layer
        )
        self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors, norm_layer=norm_layer)

    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor]
        return {
            "classification": self.classification_head.compute_loss(targets, head_outputs, matched_idxs),
            "bbox_regression": self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
        }

    def forward(self, x):
        # type: (List[Tensor]) -> Dict[str, Tensor]
        return {"cls_logits": self.classification_head(x), "bbox_regression": self.regression_head(x)}


class RetinaNetClassificationHead(nn.Module):
    """
    A classification head for use in RetinaNet.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    """

    _version = 2

    def __init__(
        self,
        in_channels,
        num_anchors,
        num_classes,
        prior_probability=0.01,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()

        conv = []
        for _ in range(4):
            conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
        self.conv = nn.Sequential(*conv)

        for layer in self.conv.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)
                if layer.bias is not None:
                    torch.nn.init.constant_(layer.bias, 0)

        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
        torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
        torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))

        self.num_classes = num_classes
        self.num_anchors = num_anchors

        # This is to fix using det_utils.Matcher.BETWEEN_THRESHOLDS in TorchScript.
        # TorchScript doesn't support class attributes.
        # https://github.com/pytorch/vision/pull/1697#issuecomment-630255584
        self.BETWEEN_THRESHOLDS = det_utils.Matcher.BETWEEN_THRESHOLDS

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            _v1_to_v2_weights(state_dict, prefix)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def compute_loss(self, targets, head_outputs, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor
        losses = []

        cls_logits = head_outputs["cls_logits"]

        for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs):
            # determine only the foreground
            foreground_idxs_per_image = matched_idxs_per_image >= 0
            num_foreground = foreground_idxs_per_image.sum()

            # create the target classification
            gt_classes_target = torch.zeros_like(cls_logits_per_image)
            gt_classes_target[
                foreground_idxs_per_image,
                targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]],
            ] = 1.0

            # find indices for which anchors should be ignored
            valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
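
            # Anchors matched to a ground-truth box (matched index >= 0) receive a one-hot
            # target at the label of that box; anchors whose IoU falls between the background
            # and foreground thresholds are excluded from the loss entirely; every other
            # anchor is background and keeps an all-zero target.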

            # compute the classification loss
            losses.append(
                sigmoid_focal_loss(
                    cls_logits_per_image[valid_idxs_per_image],
                    gt_classes_target[valid_idxs_per_image],
                    reduction="sum",
                )
                / max(1, num_foreground)
            )

        return _sum(losses) / len(targets)

    def forward(self, x):
        # type: (List[Tensor]) -> Tensor
        all_cls_logits = []

        for features in x:
            cls_logits = self.conv(features)
            cls_logits = self.cls_logits(cls_logits)

            # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
            N, _, H, W = cls_logits.shape
            cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
            cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
            cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, K)

            all_cls_logits.append(cls_logits)

        return torch.cat(all_cls_logits, dim=1)


class RetinaNetRegressionHead(nn.Module):
    """
    A regression head for use in RetinaNet.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    """

    _version = 2

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
    }

    def __init__(self, in_channels, num_anchors, norm_layer: Optional[Callable[..., nn.Module]] = None):
        super().__init__()

        conv = []
        for _ in range(4):
            conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
        self.conv = nn.Sequential(*conv)

        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
        torch.nn.init.normal_(self.bbox_reg.weight, std=0.01)
        torch.nn.init.zeros_(self.bbox_reg.bias)

        for layer in self.conv.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)
                if layer.bias is not None:
                    torch.nn.init.zeros_(layer.bias)

        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
        self._loss_type = "l1"

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            _v1_to_v2_weights(state_dict, prefix)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor
        losses = []

        bbox_regression = head_outputs["bbox_regression"]

        for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(
            targets, bbox_regression, anchors, matched_idxs
        ):
            # determine only the foreground indices, ignore the rest
            foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
            num_foreground = foreground_idxs_per_image.numel()

            # select only the foreground boxes
            matched_gt_boxes_per_image = targets_per_image["boxes"][matched_idxs_per_image[foreground_idxs_per_image]]
            bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
            anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]

            # compute the loss
            losses.append(
                _box_loss(
                    self._loss_type,
                    self.box_coder,
                    anchors_per_image,
                    matched_gt_boxes_per_image,
                    bbox_regression_per_image,
                )
                / max(1, num_foreground)
            )

        return _sum(losses) / max(1, len(targets))

    def forward(self, x):
        # type: (List[Tensor]) -> Tensor
        all_bbox_regression = []

        for features in x:
            bbox_regression = self.conv(features)
            bbox_regression = self.bbox_reg(bbox_regression)

            # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
            N, _, H, W = bbox_regression.shape
            bbox_regression = bbox_regression.view(N, -1, 4, H, W)
            bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
            bbox_regression = bbox_regression.reshape(N, -1, 4)  # Size=(N, HWA, 4)

            all_bbox_regression.append(bbox_regression)

        return torch.cat(all_bbox_regression, dim=1)


class RetinaNet(nn.Module):
    """
    Implements RetinaNet.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained.
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained.
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): Module run on top of the feature pyramid.
            Defaults to a module containing a classification and regression module.
        score_thresh (float): Score threshold used for postprocessing the detections.
        nms_thresh (float): NMS threshold used for postprocessing the detections.
        detections_per_img (int): Number of best detections to keep after NMS.
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training.
        topk_candidates (int): Number of best detections to keep before NMS.

    Example:

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models import MobileNet_V2_Weights
        >>> from torchvision.models.detection import RetinaNet
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>>
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # RetinaNet needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the network generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(
        >>>     sizes=((32, 64, 128, 256, 512),),
        >>>     aspect_ratios=((0.5, 1.0, 2.0),)
        >>> )
        >>>
        >>> # put the pieces together inside a RetinaNet model
        >>> model = RetinaNet(backbone,
        >>>                   num_classes=2,
        >>>                   anchor_generator=anchor_generator)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
    }

    def __init__(
        self,
        backbone,
        num_classes,
        # transform parameters
        min_size=800,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # Anchor parameters
        anchor_generator=None,
        head=None,
        proposal_matcher=None,
        score_thresh=0.05,
        nms_thresh=0.5,
        detections_per_img=300,
        fg_iou_thresh=0.5,
        bg_iou_thresh=0.4,
        topk_candidates=1000,
        **kwargs,
    ):
        super().__init__()
        _log_api_usage_once(self)

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )
        self.backbone = backbone

        if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"anchor_generator should be of type AnchorGenerator or None instead of {type(anchor_generator)}"
            )

        if anchor_generator is None:
            anchor_generator = _default_anchorgen()
        self.anchor_generator = anchor_generator

        if head is None:
            head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
        self.head = head

        if proposal_matcher is None:
            proposal_matcher = det_utils.Matcher(
                fg_iou_thresh,
                bg_iou_thresh,
                allow_low_quality_matches=True,
            )
        self.proposal_matcher = proposal_matcher

        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img
        self.topk_candidates = topk_candidates

        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        if self.training:
            return losses

        return detections

    def compute_loss(self, targets, head_outputs, anchors):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor]
        matched_idxs = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            if targets_per_image["boxes"].numel() == 0:
                matched_idxs.append(
                    torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
                )
                continue

            match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
            matched_idxs.append(self.proposal_matcher(match_quality_matrix))

        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)

    def postprocess_detections(self, head_outputs, anchors, image_shapes):
        # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
        class_logits = head_outputs["cls_logits"]
        box_regression = head_outputs["bbox_regression"]

        num_images = len(image_shapes)

        detections: List[Dict[str, Tensor]] = []

        for index in range(num_images):
            box_regression_per_image = [br[index] for br in box_regression]
            logits_per_image = [cl[index] for cl in class_logits]
            anchors_per_image, image_shape = anchors[index], image_shapes[index]

            image_boxes = []
            image_scores = []
            image_labels = []

            for box_regression_per_level, logits_per_level, anchors_per_level in zip(
                box_regression_per_image, logits_per_image, anchors_per_image
            ):
                num_classes = logits_per_level.shape[-1]

                # remove low scoring boxes
                scores_per_level = torch.sigmoid(logits_per_level).flatten()
                keep_idxs = scores_per_level > self.score_thresh
                scores_per_level = scores_per_level[keep_idxs]
                topk_idxs = torch.where(keep_idxs)[0]

                # keep only topk scoring predictions
                num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
                scores_per_level, idxs = scores_per_level.topk(num_topk)
                topk_idxs = topk_idxs[idxs]

                anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
                labels_per_level = topk_idxs % num_classes
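                # The per-level scores were flattened from shape (num_anchors, num_classes),
                # so each flattened index encodes anchor_index * num_classes + class_index;
                # the floor division above recovers the anchor and the modulo the class label.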

                boxes_per_level = self.box_coder.decode_single(
                    box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
                )
                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)

                image_boxes.append(boxes_per_level)
                image_scores.append(scores_per_level)
                image_labels.append(labels_per_level)

            image_boxes = torch.cat(image_boxes, dim=0)
            image_scores = torch.cat(image_scores, dim=0)
            image_labels = torch.cat(image_labels, dim=0)

            # non-maximum suppression
            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
            keep = keep[: self.detections_per_img]

            detections.append(
                {
                    "boxes": image_boxes[keep],
                    "scores": image_scores[keep],
                    "labels": image_labels[keep],
                }
            )

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns a list[BoxList] that contains additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).
        """
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                for target in targets:
                    boxes = target["boxes"]
                    torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
                    torch._assert(
                        len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                        "Expected target boxes to be a tensor of shape [N, 4].",
                    )

        # get the original image sizes
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        # transform the input
        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        # TODO: Move this to a function
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    torch._assert(
                        False,
                        "All bounding boxes should have positive height and width."
                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
                    )

        # get the features from the backbone
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])

        # TODO: Do we want a list or a dict?
        features = list(features.values())

        # compute the retinanet heads outputs using the features
        head_outputs = self.head(features)

        # create the set of anchors
        anchors = self.anchor_generator(images, features)

        losses = {}
        detections: List[Dict[str, Tensor]] = []
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                # compute the losses
                losses = self.compute_loss(targets, head_outputs, anchors)
        else:
            # recover level sizes
            num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
            HW = 0
            for v in num_anchors_per_level:
                HW += v
            HWA = head_outputs["cls_logits"].size(1)
            A = HWA // HW
            num_anchors_per_level = [hw * A for hw in num_anchors_per_level]
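            # HW is the number of spatial locations summed over all pyramid levels and HWA
            # the total number of predictions returned by the head, so A is the number of
            # anchors per location; each level therefore contributes H * W * A predictions,
            # which are the split sizes used below.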

            # split outputs per level
            split_head_outputs: Dict[str, List[Tensor]] = {}
            for k in head_outputs:
                split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
            split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]

            # compute the detections
            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        return self.eager_outputs(losses, detections)


_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}


class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 34014999,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 36.4,
                }
            },
            "_ops": 151.54,
            "_file_size": 130.267,
            "_docs": """These weights were produced by following a training recipe similar to the one in the paper.""",
        },
    )
    DEFAULT = COCO_V1


class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 38198935,
            "recipe": "https://github.com/pytorch/vision/pull/5756",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 41.5,
                }
            },
            "_ops": 152.238,
            "_file_size": 146.037,
            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
        },
    )
    DEFAULT = COCO_V1


@register_model()
@handle_legacy_interface(
    weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def retinanet_resnet50_fpn(
    *,
    weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> RetinaNet:
    """
    Constructs a RetinaNet model with a ResNet-50-FPN backbone.

    .. betastatus:: detection module

    Reference: `Focal Loss for Dense Object Detection <https://arxiv.org/abs/1708.02002>`_.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detections:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the predicted labels for each detection
        - scores (``Tensor[N]``): the scores of each detection

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    Example::

        >>> model = torchvision.models.detection.retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
            the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
            passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights
        :members:
    """
    weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    # skip P2 because it generates too many anchors (according to their paper)
    backbone = _resnet_fpn_extractor(
        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
    )
    model = RetinaNet(backbone, num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model


@register_model()
@handle_legacy_interface(
    weights=("pretrained", RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def retinanet_resnet50_fpn_v2(
    *,
    weights: Optional[RetinaNet_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> RetinaNet:
    """
    Constructs an improved RetinaNet model with a ResNet-50-FPN backbone.

    .. betastatus:: detection module

    Reference: `Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection
    <https://arxiv.org/abs/1912.02424>`_.

    See :func:`~torchvision.models.detection.retinanet_resnet50_fpn` for more details.
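
    Example (mirroring the v1 builder above)::

        >>> model = torchvision.models.detection.retinanet_resnet50_fpn_v2(weights=RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)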

    Args:
        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
            the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
            passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights
        :members:
    """
    weights = RetinaNet_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    backbone = resnet50(weights=weights_backbone, progress=progress)
    backbone = _resnet_fpn_extractor(
        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(2048, 256)
    )
    anchor_generator = _default_anchorgen()
    head = RetinaNetHead(
        backbone.out_channels,
        anchor_generator.num_anchors_per_location()[0],
        num_classes,
        norm_layer=partial(nn.GroupNorm, 32),
    )
    head.regression_head._loss_type = "giou"
    model = RetinaNet(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
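

# A minimal fine-tuning sketch, kept as a comment so importing this module stays
# side-effect free. It assumes you want the COCO-pretrained backbone and feature
# pyramid but your own set of classes; the class count below is purely illustrative.
#
#     from functools import partial
#     from torch import nn
#     from torchvision.models.detection import retinanet_resnet50_fpn_v2
#     from torchvision.models.detection.retinanet import RetinaNetClassificationHead
#
#     model = retinanet_resnet50_fpn_v2(weights="DEFAULT")
#     num_anchors = model.head.classification_head.num_anchors
#     model.head.classification_head = RetinaNetClassificationHead(
#         in_channels=model.backbone.out_channels,
#         num_anchors=num_anchors,
#         num_classes=3,  # illustrative: 2 object classes + background
#         norm_layer=partial(nn.GroupNorm, 32),
#     )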