from typing import Any, Callable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from ...ops import misc as misc_nn_ops
from ...transforms._presets import ObjectDetection
from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
from ..resnet import resnet50, ResNet50_Weights
from ._utils import overwrite_eps
from .anchor_utils import AnchorGenerator
from .backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers
from .generalized_rcnn import GeneralizedRCNN
from .roi_heads import RoIHeads
from .rpn import RegionProposalNetwork, RPNHead
from .transform import GeneralizedRCNNTransform

__all__ = [
    "FasterRCNN",
    "FasterRCNN_ResNet50_FPN_Weights",
    "FasterRCNN_ResNet50_FPN_V2_Weights",
    "FasterRCNN_MobileNet_V3_Large_FPN_Weights",
    "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
    "fasterrcnn_resnet50_fpn",
    "fasterrcnn_resnet50_fpn_v2",
    "fasterrcnn_mobilenet_v3_large_fpn",
    "fasterrcnn_mobilenet_v3_large_320_fpn",
]

def _default_anchorgen():
    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    return AnchorGenerator(anchor_sizes, aspect_ratios)
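
# For reference: with one size per FPN level and three aspect ratios per level, the generator
# above places 3 anchors per spatial location on each of the 5 feature maps
# (num_anchors_per_location() == [3, 3, 3, 3, 3]), which is the value RPNHead is sized with below.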

class FasterRCNN(GeneralizedRCNN):
    """
    Implements Faster R-CNN.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (a list of dictionaries),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores of each prediction

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
            If box_predictor is specified, num_classes should be None.
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained.
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained.
        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        rpn_score_thresh (float): during inference, only return proposals with a classification score
            greater than rpn_score_thresh
        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes
        box_head (nn.Module): module that takes the cropped feature maps as input
        box_predictor (nn.Module): module that takes the output of box_head and returns the
            classification logits and box regression deltas.
        box_score_thresh (float): during inference, only return proposals with a classification score
            greater than box_score_thresh
        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
        box_detections_per_img (int): maximum number of detections per image, for all classes.
        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
            considered as positive during training of the classification head
        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
            considered as negative during training of the classification head
        box_batch_size_per_image (int): number of proposals that are sampled during training of the
            classification head
        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
            of the classification head
        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
            bounding boxes

    Example::

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models import MobileNet_V2_Weights
        >>> from torchvision.models.detection import FasterRCNN
        >>> from torchvision.models.detection.rpn import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # FasterRCNN needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the RPN generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
        >>>
        >>> # let's define what are the feature maps that we will
        >>> # use to perform the region of interest cropping, as well as
        >>> # the size of the crop after rescaling.
        >>> # if your backbone returns a Tensor, featmap_names is expected to
        >>> # be ['0']. More generally, the backbone should return an
        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
        >>> # feature maps to use.
        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                 output_size=7,
        >>>                                                 sampling_ratio=2)
        >>>
        >>> # put the pieces together inside a FasterRCNN model
        >>> model = FasterRCNN(backbone,
        >>>                    num_classes=2,
        >>>                    rpn_anchor_generator=anchor_generator,
        >>>                    box_roi_pool=roi_pooler)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=800,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        **kwargs,
    ):

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )

        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
            )
        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
            )

        if num_classes is not None:
            if box_predictor is not None:
                raise ValueError("num_classes should be None when box_predictor is specified")
        else:
            if box_predictor is None:
                raise ValueError("num_classes should not be None when box_predictor is not specified")

        out_channels = backbone.out_channels

        if rpn_anchor_generator is None:
            rpn_anchor_generator = _default_anchorgen()
        if rpn_head is None:
            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])

        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)

        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )

        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)

        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(out_channels * resolution**2, representation_size)

        if box_predictor is None:
            representation_size = 1024
            box_predictor = FastRCNNPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        super().__init__(backbone, rpn, roi_heads, transform)

class TwoMLPHead(nn.Module):
    """
    Standard heads for FPN-based models

    Args:
        in_channels (int): number of input channels
        representation_size (int): size of the intermediate representation
    """

    def __init__(self, in_channels, representation_size):
        super().__init__()

        self.fc6 = nn.Linear(in_channels, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)

    def forward(self, x):
        x = x.flatten(start_dim=1)

        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))

        return x
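
# Shape sketch (illustrative): with the default ResNet-50 FPN configuration below, the RoI-pooled
# features are [num_proposals, 256, 7, 7]; flattening gives 256 * 7 * 7 = 12544 inputs, so the
# default head is TwoMLPHead(12544, 1024) and returns a [num_proposals, 1024] representation, e.g.:
#   head = TwoMLPHead(in_channels=256 * 7 * 7, representation_size=1024)
#   out = head(torch.rand(8, 256, 7, 7))  # out.shape == (8, 1024)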

class FastRCNNConvFCHead(nn.Sequential):
    def __init__(
        self,
        input_size: Tuple[int, int, int],
        conv_layers: List[int],
        fc_layers: List[int],
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        """
        Args:
            input_size (Tuple[int, int, int]): the input size in CHW format.
            conv_layers (list): feature dimensions of each Convolution layer
            fc_layers (list): feature dimensions of each FCN layer
            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
        """
        in_channels, in_height, in_width = input_size

        blocks = []
        previous_channels = in_channels
        for current_channels in conv_layers:
            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
            previous_channels = current_channels
        blocks.append(nn.Flatten())
        previous_channels = previous_channels * in_height * in_width
        for current_channels in fc_layers:
            blocks.append(nn.Linear(previous_channels, current_channels))
            blocks.append(nn.ReLU(inplace=True))
            previous_channels = current_channels

        super().__init__(*blocks)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
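
# Usage sketch (illustrative): this is the heavier box head used by the v2 builder below, e.g.
#   FastRCNNConvFCHead((256, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d)
# builds four 3x3 conv + norm + ReLU blocks that keep the 7x7 resolution, then a Flatten feeding
# a single 1024-wide fully connected layer, so its output is [num_proposals, 1024] like TwoMLPHead.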

class FastRCNNPredictor(nn.Module):
    """
    Standard classification + bounding box regression layers
    for Fast R-CNN.

    Args:
        in_channels (int): number of input channels
        num_classes (int): number of output classes (including background)
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.cls_score = nn.Linear(in_channels, num_classes)
        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)

    def forward(self, x):
        if x.dim() == 4:
            torch._assert(
                list(x.shape[2:]) == [1, 1],
                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
            )
        x = x.flatten(start_dim=1)
        scores = self.cls_score(x)
        bbox_deltas = self.bbox_pred(x)

        return scores, bbox_deltas
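
# Usage sketch (illustrative): given the 1024-d box-head output and the 91 COCO categories,
#   predictor = FastRCNNPredictor(in_channels=1024, num_classes=91)
#   scores, bbox_deltas = predictor(torch.rand(8, 1024))
#   # scores: [8, 91] class logits; bbox_deltas: [8, 91 * 4] per-class box regression deltas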

_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}


class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 41755286,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 37.0,
                }
            },
            "_ops": 134.38,
            "_file_size": 159.743,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 43712278,
            "recipe": "https://github.com/pytorch/vision/pull/5763",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 46.7,
                }
            },
            "_ops": 280.371,
            "_file_size": 167.104,
            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
        },
    )
    DEFAULT = COCO_V1


class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 32.8,
                }
            },
            "_ops": 4.494,
            "_file_size": 74.239,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 22.8,
                }
            },
            "_ops": 0.719,
            "_file_size": 74.239,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1

@register_model()
@handle_legacy_interface(
    weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fasterrcnn_resnet50_fpn(
    *,
    weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
    """
    Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object
    Detection with Region Proposal Networks <https://arxiv.org/abs/1506.01497>`__
    paper.

    .. betastatus:: detection module

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (a list of dictionaries),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detections:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the predicted labels for each detection
        - scores (``Tensor[N]``): the scores of each detection

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    Faster R-CNN is exportable to ONNX for a fixed batch size with input images of fixed size.

    Example::

        >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
        >>> # For training
        >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
        >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
        >>> labels = torch.randint(1, 91, (4, 11))
        >>> images = list(image for image in images)
        >>> targets = []
        >>> for i in range(len(images)):
        >>>     d = {}
        >>>     d['boxes'] = boxes[i]
        >>>     d['labels'] = labels[i]
        >>>     targets.append(d)
        >>> output = model(images, targets)
        >>> # For inference
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
        >>>
        >>> # optionally, if you want to export the model to ONNX:
        >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11)

    Args:
        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights
        :members:
    """
    weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model
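
# A common fine-tuning sketch (illustrative): load the COCO weights, then swap the box predictor
# to match a custom number of classes (background included):
#   model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
#   in_features = model.roi_heads.box_predictor.cls_score.in_features  # 1024 for this head
#   model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=2)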

@register_model()
@handle_legacy_interface(
    weights=("pretrained", FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fasterrcnn_resnet50_fpn_v2(
    *,
    weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
    """
    Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone from the `Benchmarking Detection
    Transfer Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`__ paper.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
    details.

    Args:
        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights
        :members:
    """
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    backbone = resnet50(weights=weights_backbone, progress=progress)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
    rpn_anchor_generator = _default_anchorgen()
    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
    box_head = FastRCNNConvFCHead(
        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
    )
    model = FasterRCNN(
        backbone,
        num_classes=num_classes,
        rpn_anchor_generator=rpn_anchor_generator,
        rpn_head=rpn_head,
        box_head=box_head,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
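
# Inference sketch for the v2 builder (illustrative), mirroring the example in the
# fasterrcnn_resnet50_fpn docstring above:
#   model = fasterrcnn_resnet50_fpn_v2(weights=FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
#   model.eval()
#   predictions = model([torch.rand(3, 300, 400), torch.rand(3, 500, 400)])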

def _fasterrcnn_mobilenet_v3_large_fpn(
    *,
    weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
    progress: bool,
    num_classes: Optional[int],
    weights_backbone: Optional[MobileNet_V3_Large_Weights],
    trainable_backbone_layers: Optional[int],
    **kwargs: Any,
) -> FasterRCNN:
    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
    anchor_sizes = (
        (
            32,
            64,
            128,
            256,
            512,
        ),
    ) * 3
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    model = FasterRCNN(
        backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
    return model

@register_model()
@handle_legacy_interface(
    weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def fasterrcnn_mobilenet_v3_large_320_fpn(
    *,
    weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
    """
    Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
    details.

    Example::

        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights
        :members:
    """
    weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    defaults = {
        "min_size": 320,
        "max_size": 640,
        "rpn_pre_nms_top_n_test": 150,
        "rpn_post_nms_top_n_test": 150,
        "rpn_score_thresh": 0.05,
    }

    kwargs = {**defaults, **kwargs}
    return _fasterrcnn_mobilenet_v3_large_fpn(
        weights=weights,
        progress=progress,
        num_classes=num_classes,
        weights_backbone=weights_backbone,
        trainable_backbone_layers=trainable_backbone_layers,
        **kwargs,
    )

@register_model()
@handle_legacy_interface(
    weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def fasterrcnn_mobilenet_v3_large_fpn(
    *,
    weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
    """
    Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
    details.

    Example::

        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights
        :members:
    """
    weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    defaults = {
        "rpn_score_thresh": 0.05,
    }

    kwargs = {**defaults, **kwargs}
    return _fasterrcnn_mobilenet_v3_large_fpn(
        weights=weights,
        progress=progress,
        num_classes=num_classes,
        weights_backbone=weights_backbone,
        trainable_backbone_layers=trainable_backbone_layers,
        **kwargs,
    )