ssd.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682
  1. import warnings
  2. from collections import OrderedDict
  3. from typing import Any, Dict, List, Optional, Tuple
  4. import torch
  5. import torch.nn.functional as F
  6. from torch import nn, Tensor
  7. from ...ops import boxes as box_ops
  8. from ...transforms._presets import ObjectDetection
  9. from ...utils import _log_api_usage_once
  10. from .._api import register_model, Weights, WeightsEnum
  11. from .._meta import _COCO_CATEGORIES
  12. from .._utils import _ovewrite_value_param, handle_legacy_interface
  13. from ..vgg import VGG, vgg16, VGG16_Weights
  14. from . import _utils as det_utils
  15. from .anchor_utils import DefaultBoxGenerator
  16. from .backbone_utils import _validate_trainable_layers
  17. from .transform import GeneralizedRCNNTransform
  # Public names re-exported by ``from ... import *``.
  __all__ = ["SSD300_VGG16_Weights", "ssd300_vgg16"]
  class SSD300_VGG16_Weights(WeightsEnum):
      """Pretrained checkpoints for :func:`ssd300_vgg16`.

      ``COCO_V1`` points at a checkpoint whose metadata lists the COCO categories and reports its
      COCO-val2017 box mAP; ``DEFAULT`` is an alias of ``COCO_V1``.
      """

      COCO_V1 = Weights(
          url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
          transforms=ObjectDetection,
          meta=dict(
              num_params=35641826,
              categories=_COCO_CATEGORIES,
              min_size=(1, 1),
              recipe="https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16",
              _metrics={"COCO-val2017": {"box_map": 25.1}},
              _ops=34.858,
              _file_size=135.988,
              _docs="These weights were produced by following a similar training recipe as on the paper.",
          ),
      )
      # Same value as COCO_V1, so Enum treats this as an alias rather than a new member.
      DEFAULT = COCO_V1
  42. def _xavier_init(conv: nn.Module):
  43. for layer in conv.modules():
  44. if isinstance(layer, nn.Conv2d):
  45. torch.nn.init.xavier_uniform_(layer.weight)
  46. if layer.bias is not None:
  47. torch.nn.init.constant_(layer.bias, 0.0)
  class SSDHead(nn.Module):
      """Classification + box-regression heads applied to the same list of feature maps.

      Args:
          in_channels: number of channels of each feature map.
          num_anchors: number of anchors per spatial location for each feature map.
          num_classes: number of classification outputs per anchor (background included).
      """

      def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int):
          super().__init__()
          # Construction order is kept (classification first) so parameter init / state_dict order match.
          self.classification_head = SSDClassificationHead(in_channels, num_anchors, num_classes)
          self.regression_head = SSDRegressionHead(in_channels, num_anchors)

      def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
          box_deltas = self.regression_head(x)
          class_logits = self.classification_head(x)
          return {"bbox_regression": box_deltas, "cls_logits": class_logits}
  class SSDScoringHead(nn.Module):
      """Apply one module per feature level and concatenate the per-anchor predictions.

      Args:
          module_list: module ``i`` is applied to feature map ``i``; its output is expected to have
              ``A * num_columns`` channels, ``A`` being the anchors per location.
          num_columns: number of values predicted for every anchor (``K``).
      """

      def __init__(self, module_list: nn.ModuleList, num_columns: int):
          super().__init__()
          self.module_list = module_list
          self.num_columns = num_columns

      def _get_result_from_module_list(self, x: Tensor, idx: int) -> Tensor:
          """Return ``self.module_list[idx](x)``; negative ``idx`` counts from the end.

          The list is walked explicitly because TorchScript cannot index a ModuleList with a
          runtime integer. If no module sits at ``idx``, ``x`` itself is returned.
          """
          if idx < 0:
              idx = idx + len(self.module_list)
          selected = x
          for position, block in enumerate(self.module_list):
              if position == idx:
                  selected = block(x)
          return selected

      def forward(self, x: List[Tensor]) -> Tensor:
          flattened: List[Tensor] = []
          for level, feature_map in enumerate(x):
              scores = self._get_result_from_module_list(feature_map, level)
              batch, _, height, width = scores.shape
              # (N, A * K, H, W) -> (N, A, K, H, W) -> (N, H, W, A, K) -> (N, H * W * A, K)
              scores = scores.view(batch, -1, self.num_columns, height, width).permute(0, 3, 4, 1, 2)
              flattened.append(scores.reshape(batch, -1, self.num_columns))
          return torch.cat(flattened, dim=1)
  class SSDClassificationHead(SSDScoringHead):
      """Per-level 3x3 convolutions producing ``num_classes`` logits for every anchor."""

      def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int):
          cls_logits = nn.ModuleList(
              nn.Conv2d(channels, num_classes * anchors, kernel_size=3, padding=1)
              for channels, anchors in zip(in_channels, num_anchors)
          )
          _xavier_init(cls_logits)
          super().__init__(cls_logits, num_classes)
  class SSDRegressionHead(SSDScoringHead):
      """Per-level 3x3 convolutions producing 4 box-regression values for every anchor."""

      def __init__(self, in_channels: List[int], num_anchors: List[int]):
          bbox_reg = nn.ModuleList(
              nn.Conv2d(channels, 4 * anchors, kernel_size=3, padding=1)
              for channels, anchors in zip(in_channels, num_anchors)
          )
          _xavier_init(bbox_reg)
          super().__init__(bbox_reg, 4)
  class SSD(nn.Module):
      """
      Implements SSD architecture from `"SSD: Single Shot MultiBox Detector" <https://arxiv.org/abs/1512.02325>`_.

      The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
      image, and should be in 0-1 range. Different images can have different sizes, but they will be resized
      to a fixed size before being passed to the backbone.

      The behavior of the model changes depending on whether it is in training or evaluation mode.

      During training, the model expects both the input tensors and targets (list of dictionary),
      containing:

          - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
          - labels (Int64Tensor[N]): the class label for each ground-truth box

      The model returns a Dict[Tensor] during training, containing the classification and regression
      losses.

      During inference, the model requires only the input tensors, and returns the post-processed
      predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
      follows, where ``N`` is the number of detections:

          - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
          - labels (Int64Tensor[N]): the predicted labels for each detection
          - scores (Tensor[N]): the scores for each detection

      Under TorchScript, ``forward`` always returns a ``(losses, detections)`` tuple instead.

      Args:
          backbone (nn.Module): the network used to compute the features for the model.
              The backbone should return a single Tensor or an OrderedDict[Tensor]. When ``head`` is not
              given, its ``out_channels`` attribute (the channels of each feature map) is used to build
              the default head; without that attribute the channels are obtained through
              ``det_utils.retrieve_out_channels(backbone, size)``.
          anchor_generator (DefaultBoxGenerator): module that generates the default boxes for a
              set of feature maps.
          size (Tuple[int, int]): the width and height to which images will be rescaled before feeding them
              to the backbone.
          num_classes (int): number of output classes of the model (including the background).
          image_mean (List[float], optional): mean values used for input normalization.
              They are generally the mean values of the dataset on which the backbone has been trained.
              Defaults to ``[0.485, 0.456, 0.406]``.
          image_std (List[float], optional): std values used for input normalization.
              They are generally the std values of the dataset on which the backbone has been trained.
              Defaults to ``[0.229, 0.224, 0.225]``.
          head (nn.Module, optional): Module run on top of the backbone features. Defaults to a module containing
              a classification and regression module.
          score_thresh (float): Score threshold used for postprocessing the detections.
          nms_thresh (float): NMS threshold used for postprocessing the detections.
          detections_per_img (int): Number of best detections to keep after NMS.
          iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
              considered as positive during training.
          topk_candidates (int): Number of best detections per class to keep before NMS.
          positive_fraction (float): a number between 0 and 1 which indicates the proportion of positive
              proposals used during the training of the classification head. It is used to derive the
              negative-to-positive ratio ``(1 - positive_fraction) / positive_fraction`` for hard negative
              mining.
          **kwargs: forwarded to ``GeneralizedRCNNTransform``.
      """

      # Attribute types for TorchScript.
      __annotations__ = {
          "box_coder": det_utils.BoxCoder,
          "proposal_matcher": det_utils.Matcher,
      }

      def __init__(
          self,
          backbone: nn.Module,
          anchor_generator: DefaultBoxGenerator,
          size: Tuple[int, int],
          num_classes: int,
          image_mean: Optional[List[float]] = None,
          image_std: Optional[List[float]] = None,
          head: Optional[nn.Module] = None,
          score_thresh: float = 0.01,
          nms_thresh: float = 0.45,
          detections_per_img: int = 200,
          iou_thresh: float = 0.5,
          topk_candidates: int = 400,
          positive_fraction: float = 0.25,
          **kwargs: Any,
      ):
          super().__init__()
          _log_api_usage_once(self)

          self.backbone = backbone

          self.anchor_generator = anchor_generator

          self.box_coder = det_utils.BoxCoder(weights=(10.0, 10.0, 5.0, 5.0))

          if head is None:
              if hasattr(backbone, "out_channels"):
                  out_channels = backbone.out_channels
              else:
                  out_channels = det_utils.retrieve_out_channels(backbone, size)

              # One set of aspect ratios is needed per feature map.
              if len(out_channels) != len(anchor_generator.aspect_ratios):
                  raise ValueError(
                      f"The length of the output channels from the backbone ({len(out_channels)}) do not match the length of the anchor generator aspect ratios ({len(anchor_generator.aspect_ratios)})"
                  )

              num_anchors = self.anchor_generator.num_anchors_per_location()
              head = SSDHead(out_channels, num_anchors, num_classes)
          self.head = head

          self.proposal_matcher = det_utils.SSDMatcher(iou_thresh)

          if image_mean is None:
              image_mean = [0.485, 0.456, 0.406]
          if image_std is None:
              image_std = [0.229, 0.224, 0.225]
          # size_divisible=1 with fixed_size=size: every image is resized to exactly ``size``.
          self.transform = GeneralizedRCNNTransform(
              min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size, **kwargs
          )

          self.score_thresh = score_thresh
          self.nms_thresh = nms_thresh
          self.detections_per_img = detections_per_img
          self.topk_candidates = topk_candidates
          # Number of hard negatives kept per positive anchor in compute_loss.
          self.neg_to_pos_ratio = (1.0 - positive_fraction) / positive_fraction

          # used only on torchscript mode
          self._has_warned = False

      @torch.jit.unused
      def eager_outputs(
          self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]]
      ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
          """Eager-mode return value: the losses while training, the detections otherwise."""
          if self.training:
              return losses

          return detections

      def compute_loss(
          self,
          targets: List[Dict[str, Tensor]],
          head_outputs: Dict[str, Tensor],
          anchors: List[Tensor],
          matched_idxs: List[Tensor],
      ) -> Dict[str, Tensor]:
          """Compute the SSD box-regression and classification losses for a batch.

          Args:
              targets: per-image dicts with ``"boxes"`` and ``"labels"``.
              head_outputs: ``"bbox_regression"`` and ``"cls_logits"`` produced by the head.
              anchors: the default boxes of each image.
              matched_idxs: for each image, the index of the ground-truth box matched to every anchor,
                  or a negative value for unmatched anchors.

          Returns:
              ``"bbox_regression"``: summed Smooth-L1 loss over matched anchors, and
              ``"classification"``: cross-entropy summed over matched anchors plus the hard-mined
              negatives. Both are divided by the number of matched anchors in the batch (at least 1).
          """
          bbox_regression = head_outputs["bbox_regression"]
          cls_logits = head_outputs["cls_logits"]

          # Match original targets with default boxes
          num_foreground = 0
          bbox_loss = []
          cls_targets = []
          for (
              targets_per_image,
              bbox_regression_per_image,
              cls_logits_per_image,
              anchors_per_image,
              matched_idxs_per_image,
          ) in zip(targets, bbox_regression, cls_logits, anchors, matched_idxs):
              # produce the matching between boxes and targets
              foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
              foreground_matched_idxs_per_image = matched_idxs_per_image[foreground_idxs_per_image]
              num_foreground += foreground_matched_idxs_per_image.numel()

              # Calculate regression loss (only on anchors matched to a ground-truth box)
              matched_gt_boxes_per_image = targets_per_image["boxes"][foreground_matched_idxs_per_image]
              bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
              anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
              target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
              bbox_loss.append(
                  torch.nn.functional.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
              )

              # Estimate ground truth for class targets: label 0 (background) unless matched
              gt_classes_target = torch.zeros(
                  (cls_logits_per_image.size(0),),
                  dtype=targets_per_image["labels"].dtype,
                  device=targets_per_image["labels"].device,
              )
              gt_classes_target[foreground_idxs_per_image] = targets_per_image["labels"][
                  foreground_matched_idxs_per_image
              ]
              cls_targets.append(gt_classes_target)

          bbox_loss = torch.stack(bbox_loss)
          cls_targets = torch.stack(cls_targets)

          # Calculate classification loss per anchor, shaped like cls_targets (images, anchors)
          num_classes = cls_logits.size(-1)
          cls_loss = F.cross_entropy(cls_logits.view(-1, num_classes), cls_targets.view(-1), reduction="none").view(
              cls_targets.size()
          )

          # Hard Negative Sampling: per image, keep the ``neg_to_pos_ratio * num_positives`` background
          # anchors with the highest classification loss.
          foreground_idxs = cls_targets > 0
          num_negative = self.neg_to_pos_ratio * foreground_idxs.sum(1, keepdim=True)
          negative_loss = cls_loss.clone()
          # Positives get -inf so they sort after every background anchor.
          negative_loss[foreground_idxs] = -float("inf")
          _, idx = negative_loss.sort(1, descending=True)
          # ``idx.sort(1)[1]`` is each anchor's rank in that ordering. If ``num_negative`` exceeds the
          # number of background anchors of an image, the -inf positives also land in the top ranks;
          # mask them out so their loss is not counted a second time as negative loss.
          background_idxs = (idx.sort(1)[1] < num_negative) & ~foreground_idxs

          N = max(1, num_foreground)
          return {
              "bbox_regression": bbox_loss.sum() / N,
              "classification": (cls_loss[foreground_idxs].sum() + cls_loss[background_idxs].sum()) / N,
          }

      def forward(
          self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
      ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
          """Run the detector.

          Args:
              images: images to process, each ``[C, H, W]``.
              targets: ground truth per image (``"boxes"`` ``[N, 4]`` and ``"labels"``); required in
                  training mode.

          Returns:
              In eager mode, the loss dict while training and the per-image detections otherwise.
              Under TorchScript, always the ``(losses, detections)`` tuple.
          """
          if self.training:
              if targets is None:
                  torch._assert(False, "targets should not be none when in training mode")
              else:
                  for target in targets:
                      boxes = target["boxes"]
                      if isinstance(boxes, torch.Tensor):
                          torch._assert(
                              len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                              f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                          )
                      else:
                          torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")

          # get the original image sizes (needed to map detections back after resizing)
          original_image_sizes: List[Tuple[int, int]] = []
          for img in images:
              val = img.shape[-2:]
              torch._assert(
                  len(val) == 2,
                  f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
              )
              original_image_sizes.append((val[0], val[1]))

          # transform the input (normalize + resize to the fixed size)
          images, targets = self.transform(images, targets)

          # Check for degenerate boxes
          if targets is not None:
              for target_idx, target in enumerate(targets):
                  boxes = target["boxes"]
                  degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                  if degenerate_boxes.any():
                      bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                      degen_bb: List[float] = boxes[bb_idx].tolist()
                      torch._assert(
                          False,
                          "All bounding boxes should have positive height and width."
                          f" Found invalid box {degen_bb} for target at index {target_idx}.",
                      )

          # get the features from the backbone
          features = self.backbone(images.tensors)
          if isinstance(features, torch.Tensor):
              features = OrderedDict([("0", features)])

          features = list(features.values())

          # compute the ssd heads outputs using the features
          head_outputs = self.head(features)

          # create the set of anchors
          anchors = self.anchor_generator(images, features)

          losses = {}
          detections: List[Dict[str, Tensor]] = []
          if self.training:
              matched_idxs = []
              if targets is None:
                  torch._assert(False, "targets should not be none when in training mode")
              else:
                  for anchors_per_image, targets_per_image in zip(anchors, targets):
                      if targets_per_image["boxes"].numel() == 0:
                          # No ground truth: every anchor is unmatched (-1).
                          matched_idxs.append(
                              torch.full(
                                  (anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device
                              )
                          )
                          continue

                      match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
                      matched_idxs.append(self.proposal_matcher(match_quality_matrix))

                  losses = self.compute_loss(targets, head_outputs, anchors, matched_idxs)
          else:
              detections = self.postprocess_detections(head_outputs, anchors, images.image_sizes)
              detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

          if torch.jit.is_scripting():
              if not self._has_warned:
                  warnings.warn("SSD always returns a (Losses, Detections) tuple in scripting")
                  self._has_warned = True
              return losses, detections
          return self.eager_outputs(losses, detections)

      def postprocess_detections(
          self, head_outputs: Dict[str, Tensor], image_anchors: List[Tensor], image_shapes: List[Tuple[int, int]]
      ) -> List[Dict[str, Tensor]]:
          """Decode raw head outputs into per-image detections.

          Class scores come from a softmax over the last dimension; label 0 is skipped as background.
          For each remaining class, anchors scoring above ``score_thresh`` are kept (at most
          ``topk_candidates`` of them). ``batched_nms`` with ``nms_thresh`` is then applied using the
          labels as groups, and at most ``detections_per_img`` detections are returned. Boxes are
          clipped to ``image_shapes`` (the sizes after the transform).
          """
          bbox_regression = head_outputs["bbox_regression"]
          pred_scores = F.softmax(head_outputs["cls_logits"], dim=-1)

          num_classes = pred_scores.size(-1)
          device = pred_scores.device

          detections: List[Dict[str, Tensor]] = []

          for boxes, scores, anchors, image_shape in zip(bbox_regression, pred_scores, image_anchors, image_shapes):
              boxes = self.box_coder.decode_single(boxes, anchors)
              boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

              image_boxes = []
              image_scores = []
              image_labels = []
              for label in range(1, num_classes):
                  score = scores[:, label]

                  keep_idxs = score > self.score_thresh
                  score = score[keep_idxs]
                  box = boxes[keep_idxs]

                  # keep only topk scoring predictions
                  num_topk = det_utils._topk_min(score, self.topk_candidates, 0)
                  score, idxs = score.topk(num_topk)
                  box = box[idxs]

                  image_boxes.append(box)
                  image_scores.append(score)
                  image_labels.append(torch.full_like(score, fill_value=label, dtype=torch.int64, device=device))

              image_boxes = torch.cat(image_boxes, dim=0)
              image_scores = torch.cat(image_scores, dim=0)
              image_labels = torch.cat(image_labels, dim=0)

              # non-maximum suppression
              keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
              keep = keep[: self.detections_per_img]

              detections.append(
                  {
                      "boxes": image_boxes[keep],
                      "scores": image_scores[keep],
                      "labels": image_labels[keep],
                  }
              )
          return detections
  class SSDFeatureExtractorVGG(nn.Module):
      """Wrap a VGG ``features`` Sequential as the multi-scale SSD feature extractor.

      Args:
          backbone: an indexable Sequential containing exactly five ``nn.MaxPool2d`` layers. It is
              sliced, not copied, and its third max-pool is switched to ``ceil_mode=True`` in place.
          highres: when True, one extra block (the SSD512 configuration) is appended.

      ``forward`` returns an ``OrderedDict`` keyed ``"0"``, ``"1"``, ...: first the L2-normalised,
      learnably rescaled output of ``self.features`` (up to conv4_3), then the output of every block
      of ``self.extra`` in turn.
      """

      def __init__(self, backbone: nn.Module, highres: bool):
          super().__init__()

          # Unpacking requires exactly five max-pool layers in the backbone.
          _, _, maxpool3_pos, maxpool4_pos, _ = (
              pos for pos, module in enumerate(backbone) if isinstance(module, nn.MaxPool2d)
          )

          # Patch ceil_mode for maxpool3 to get the same WxH output sizes as the paper
          backbone[maxpool3_pos].ceil_mode = True

          # Learnable per-channel scale applied after L2 normalisation of the conv4_3 map (init 20).
          self.scale_weight = nn.Parameter(torch.full((512,), 20.0))

          # Multiple Feature maps - page 4, Fig 2 of SSD paper
          self.features = nn.Sequential(*backbone[:maxpool4_pos])  # until conv4_3

          def _squeeze_expand(in_ch: int, mid_ch: int, out_ch: int, **conv_kwargs: Any) -> nn.Sequential:
              # 1x1 bottleneck conv followed by a wider conv, each followed by an in-place ReLU.
              return nn.Sequential(
                  nn.Conv2d(in_ch, mid_ch, kernel_size=1),
                  nn.ReLU(inplace=True),
                  nn.Conv2d(mid_ch, out_ch, **conv_kwargs),
                  nn.ReLU(inplace=True),
              )

          # SSD300 extra layers - page 4, Fig 2 of SSD paper
          extra = nn.ModuleList(
              [
                  _squeeze_expand(1024, 256, 512, kernel_size=3, padding=1, stride=2),  # conv8_2
                  _squeeze_expand(512, 128, 256, kernel_size=3, padding=1, stride=2),  # conv9_2
                  _squeeze_expand(256, 128, 256, kernel_size=3),  # conv10_2
                  _squeeze_expand(256, 128, 256, kernel_size=3),  # conv11_2
              ]
          )
          if highres:
              # SSD512 adds one more block; see page 11, footnote 5 of the paper.
              extra.append(_squeeze_expand(256, 128, 256, kernel_size=4))  # conv12_2
          _xavier_init(extra)

          fc = nn.Sequential(
              nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=False),  # modified maxpool5
              nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6),  # atrous FC6
              nn.ReLU(inplace=True),
              nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1),  # FC7
              nn.ReLU(inplace=True),
          )
          _xavier_init(fc)

          # VGG conv5 stage (its own maxpool5 dropped) followed by FC6/FC7 becomes extra block 0.
          extra.insert(0, nn.Sequential(*backbone[maxpool4_pos:-1], fc))
          self.extra = extra

      def forward(self, x: Tensor) -> Dict[str, Tensor]:
          x = self.features(x)
          # L2-normalise the conv4_3 map (F.normalize default dim=1) and apply the learnable scale.
          scale = self.scale_weight.view(1, -1, 1, 1)
          output = [scale * F.normalize(x)]

          # Each extra block consumes the previous block's raw output, not the rescaled map.
          for block in self.extra:
              x = block(x)
              output.append(x)

          return OrderedDict([(str(i), v) for i, v in enumerate(output)])
  def _vgg_extractor(backbone: VGG, highres: bool, trainable_layers: int):
      """Freeze the early stages of a VGG backbone and wrap it in ``SSDFeatureExtractorVGG``.

      Stages begin at index 0 and at every max-pool except the last. The last ``trainable_layers``
      stages keep trainable parameters; ``0`` freezes the whole ``features`` module.
      """
      layers = backbone.features
      pool_positions = [i for i, b in enumerate(layers) if isinstance(b, nn.MaxPool2d)]
      stage_indices = [0, *pool_positions[:-1]]
      num_stages = len(stage_indices)

      torch._assert(
          0 <= trainable_layers <= num_stages,
          f"trainable_layers should be in the range [0, {num_stages}]. Instead got {trainable_layers}",
      )
      if trainable_layers == 0:
          freeze_before = len(layers)
      else:
          freeze_before = stage_indices[num_stages - trainable_layers]

      # Module.requires_grad_ applies to every parameter of the module, recursively.
      for module in layers[:freeze_before]:
          module.requires_grad_(False)

      return SSDFeatureExtractorVGG(layers, highres)
  @register_model()
  @handle_legacy_interface(
      weights=("pretrained", SSD300_VGG16_Weights.COCO_V1),
      weights_backbone=("pretrained_backbone", VGG16_Weights.IMAGENET1K_FEATURES),
  )
  def ssd300_vgg16(
      *,
      weights: Optional[SSD300_VGG16_Weights] = None,
      progress: bool = True,
      num_classes: Optional[int] = None,
      weights_backbone: Optional[VGG16_Weights] = VGG16_Weights.IMAGENET1K_FEATURES,
      trainable_backbone_layers: Optional[int] = None,
      **kwargs: Any,
  ) -> SSD:
      """The SSD300 model is based on the `SSD: Single Shot MultiBox Detector
      <https://arxiv.org/abs/1512.02325>`_ paper.

      .. betastatus:: detection module

      The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
      image, and should be in 0-1 range. Different images can have different sizes, but they will be resized
      to a fixed size (300x300) before being passed to the backbone.

      The behavior of the model changes depending on whether it is in training or evaluation mode.

      During training, the model expects both the input tensors and targets (list of dictionary),
      containing:

          - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
          - labels (Int64Tensor[N]): the class label for each ground-truth box

      The model returns a Dict[Tensor] during training, containing the classification and regression
      losses.

      During inference, the model requires only the input tensors, and returns the post-processed
      predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
      follows, where ``N`` is the number of detections:

          - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
          - labels (Int64Tensor[N]): the predicted labels for each detection
          - scores (Tensor[N]): the scores for each detection

      Example:

          >>> model = torchvision.models.detection.ssd300_vgg16(weights=SSD300_VGG16_Weights.DEFAULT)
          >>> model.eval()
          >>> x = [torch.rand(3, 300, 300), torch.rand(3, 500, 400)]
          >>> predictions = model(x)

      Args:
          weights (:class:`~torchvision.models.detection.SSD300_VGG16_Weights`, optional): The pretrained
                  weights to use. See
                  :class:`~torchvision.models.detection.SSD300_VGG16_Weights`
                  below for more details, and possible values. By default, no
                  pre-trained weights are used. When given, ``weights_backbone`` is ignored.
          progress (bool, optional): If True, displays a progress bar of the download to stderr.
              Default is True.
          num_classes (int, optional): number of output classes of the model (including the background).
              Defaults to the number of categories of ``weights`` if given, otherwise 91.
          weights_backbone (:class:`~torchvision.models.VGG16_Weights`, optional): The pretrained weights for the
              backbone
          trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
              Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
              passed (the default) this value is set to 4.
          **kwargs: parameters passed to the ``torchvision.models.detection.SSD``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/ssd.py>`_
              for more details about this class. ``size`` is not accepted: the model size is fixed, so a
              ``size`` argument only triggers a warning and is discarded.

      .. autoclass:: torchvision.models.detection.SSD300_VGG16_Weights
          :members:
      """
      weights = SSD300_VGG16_Weights.verify(weights)
      weights_backbone = VGG16_Weights.verify(weights_backbone)

      if "size" in kwargs:
          warnings.warn("The size of the model is already fixed; ignoring the parameter.")
          # Actually discard it: SSD below receives the fixed (300, 300) size positionally, so also
          # forwarding ``size=`` through **kwargs would raise "got multiple values for argument 'size'".
          kwargs.pop("size")

      if weights is not None:
          weights_backbone = None
          num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
      elif num_classes is None:
          num_classes = 91

      trainable_backbone_layers = _validate_trainable_layers(
          weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 4
      )

      # Use custom backbones more appropriate for SSD
      backbone = vgg16(weights=weights_backbone, progress=progress)
      backbone = _vgg_extractor(backbone, False, trainable_backbone_layers)
      # One aspect-ratio list, step and scale per feature map of the extractor (6 maps; 7 scales).
      anchor_generator = DefaultBoxGenerator(
          [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
          scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
          steps=[8, 16, 32, 64, 100, 300],
      )

      defaults = {
          # Rescale the input in a way compatible to the backbone
          "image_mean": [0.48235, 0.45882, 0.40784],
          "image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0],  # undo the 0-1 scaling of toTensor
      }
      # Caller-supplied kwargs override the defaults above.
      kwargs: Any = {**defaults, **kwargs}
      model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)

      if weights is not None:
          model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

      return model