12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899 |
- import math
- import warnings
- from collections import OrderedDict
- from functools import partial
- from typing import Any, Callable, Dict, List, Optional, Tuple
- import torch
- from torch import nn, Tensor
- from ...ops import boxes as box_ops, misc as misc_nn_ops, sigmoid_focal_loss
- from ...ops.feature_pyramid_network import LastLevelP6P7
- from ...transforms._presets import ObjectDetection
- from ...utils import _log_api_usage_once
- from .._api import register_model, Weights, WeightsEnum
- from .._meta import _COCO_CATEGORIES
- from .._utils import _ovewrite_value_param, handle_legacy_interface
- from ..resnet import resnet50, ResNet50_Weights
- from . import _utils as det_utils
- from ._utils import _box_loss, overwrite_eps
- from .anchor_utils import AnchorGenerator
- from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
- from .transform import GeneralizedRCNNTransform
- __all__ = [
- "RetinaNet",
- "RetinaNet_ResNet50_FPN_Weights",
- "RetinaNet_ResNet50_FPN_V2_Weights",
- "retinanet_resnet50_fpn",
- "retinanet_resnet50_fpn_v2",
- ]
def _sum(x: List[Tensor]) -> Tensor:
    """Add up a non-empty list of tensors, left to right.

    Args:
        x (List[Tensor]): tensors to accumulate. The first element seeds the
            running total, so the list must contain at least one entry.

    Returns:
        Tensor: ``x[0] + x[1] + ... + x[-1]``.
    """
    total = x[0]
    for term in x[1:]:
        total = total + term
    return total
def _v1_to_v2_weights(state_dict, prefix):
    """Rename legacy (v1) conv-tower keys of a head to the v2 key layout, in place.

    For each of the four tower convolutions, the entries
    ``{prefix}conv.{2 * i}.weight`` / ``.bias`` are moved to
    ``{prefix}conv.{i}.0.weight`` / ``.bias``. Keys that are absent are skipped,
    so calling this on an already-converted ``state_dict`` is a no-op.

    Args:
        state_dict (dict): mapping of parameter names to values; mutated in place.
        prefix (str): key prefix of the module whose ``conv`` tower is renamed.

    Returns:
        None
    """
    # NOTE(review): the even source indices (0, 2, 4, 6) presumably come from a v1 tower that
    # interleaved each conv with an activation in a flat nn.Sequential - confirm against v1 code.
    for i in range(4):
        # ``param_name`` rather than ``type`` so the builtin is not shadowed.
        for param_name in ("weight", "bias"):
            old_key = f"{prefix}conv.{2 * i}.{param_name}"
            new_key = f"{prefix}conv.{i}.0.{param_name}"
            if old_key in state_dict:
                state_dict[new_key] = state_dict.pop(old_key)
def _default_anchorgen():
    """Build the anchor generator used when the caller does not supply one.

    Five size groups are produced, one per base size in (32, 64, 128, 256, 512). Each group
    holds the base size plus the base scaled by ``2 ** (1 / 3)`` and ``2 ** (2 / 3)``
    (truncated to int). Every group uses the aspect ratios 0.5, 1.0 and 2.0.

    Returns:
        AnchorGenerator: generator configured with the sizes and ratios above.
    """
    sizes_per_level = []
    for base in (32, 64, 128, 256, 512):
        sizes_per_level.append((base, int(base * 2 ** (1.0 / 3)), int(base * 2 ** (2.0 / 3))))
    anchor_sizes = tuple(sizes_per_level)

    ratios_per_level = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    return AnchorGenerator(anchor_sizes, ratios_per_level)
class RetinaNetHead(nn.Module):
    """
    Combined RetinaNet head: a classification branch and a box-regression branch
    that are both run on the same list of feature maps.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    """

    def __init__(self, in_channels, num_anchors, num_classes, norm_layer: Optional[Callable[..., nn.Module]] = None):
        super().__init__()
        # Registration order (classification first) is part of the state_dict layout.
        self.classification_head = RetinaNetClassificationHead(
            in_channels,
            num_anchors,
            num_classes,
            norm_layer=norm_layer,
        )
        self.regression_head = RetinaNetRegressionHead(
            in_channels,
            num_anchors,
            norm_layer=norm_layer,
        )

    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor]
        """Return both branch losses under the keys ``classification`` and ``bbox_regression``."""
        cls_loss = self.classification_head.compute_loss(targets, head_outputs, matched_idxs)
        reg_loss = self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs)
        return {
            "classification": cls_loss,
            "bbox_regression": reg_loss,
        }

    def forward(self, x):
        # type: (List[Tensor]) -> Dict[str, Tensor]
        """Run both branches on ``x`` and return their outputs as ``cls_logits`` / ``bbox_regression``."""
        cls_logits = self.classification_head(x)
        bbox_regression = self.regression_head(x)
        return {"cls_logits": cls_logits, "bbox_regression": bbox_regression}
class RetinaNetClassificationHead(nn.Module):
    """
    A classification head for use in RetinaNet.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
        prior_probability (float): value ``p`` used to initialise the bias of the final
            ``cls_logits`` convolution to ``-log((1 - p) / p)``. Default: 0.01
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    """

    # Bumped to 2 when the conv tower key layout changed; see ``_load_from_state_dict``.
    _version = 2

    def __init__(
        self,
        in_channels,
        num_anchors,
        num_classes,
        prior_probability=0.01,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()

        # Tower of four conv blocks that keep the channel count unchanged.
        conv = []
        for _ in range(4):
            conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
        self.conv = nn.Sequential(*conv)

        for layer in self.conv.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)
                if layer.bias is not None:
                    torch.nn.init.constant_(layer.bias, 0)

        # Final conv emits one logit per (anchor, class) pair at every spatial location.
        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
        torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
        torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))

        self.num_classes = num_classes
        self.num_anchors = num_anchors

        # This is to fix using det_utils.Matcher.BETWEEN_THRESHOLDS in TorchScript.
        # TorchScript doesn't support class attributes.
        # https://github.com/pytorch/vision/pull/1697#issuecomment-630255584
        self.BETWEEN_THRESHOLDS = det_utils.Matcher.BETWEEN_THRESHOLDS

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Load parameters, first renaming conv-tower keys of checkpoints saved with version < 2."""
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            _v1_to_v2_weights(state_dict, prefix)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def compute_loss(self, targets, head_outputs, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor
        """
        Compute the focal classification loss averaged over the images of the batch.

        Args:
            targets (List[Dict[str, Tensor]]): per-image targets; only ``labels`` is read here.
            head_outputs (Dict[str, Tensor]): head outputs; only ``cls_logits`` is read here.
            matched_idxs (List[Tensor]): per-image matched ground-truth index for every anchor.
                Values ``>= 0`` are foreground; values equal to ``BETWEEN_THRESHOLDS`` are ignored.

        Returns:
            Tensor: sum of per-image losses (each divided by ``max(1, num_foreground)``)
            divided by the number of images.
        """
        losses = []

        cls_logits = head_outputs["cls_logits"]

        for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs):
            # determine only the foreground
            foreground_idxs_per_image = matched_idxs_per_image >= 0
            num_foreground = foreground_idxs_per_image.sum()

            # create the target classification: one-hot rows for foreground anchors, zeros elsewhere
            gt_classes_target = torch.zeros_like(cls_logits_per_image)
            gt_classes_target[
                foreground_idxs_per_image,
                targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]],
            ] = 1.0

            # find indices for which anchors should be ignored
            valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS

            # compute the classification loss
            losses.append(
                sigmoid_focal_loss(
                    cls_logits_per_image[valid_idxs_per_image],
                    gt_classes_target[valid_idxs_per_image],
                    reduction="sum",
                )
                / max(1, num_foreground)
            )

        # Same normalisation guard as RetinaNetRegressionHead.compute_loss.
        return _sum(losses) / max(1, len(targets))

    def forward(self, x):
        # type: (List[Tensor]) -> Tensor
        """
        Compute class logits for every anchor of every feature map.

        Args:
            x (List[Tensor]): feature maps, each of shape ``(N, in_channels, H, W)``.

        Returns:
            Tensor: logits of shape ``(N, sum(H * W * A), K)``, concatenated over the feature
            maps, where ``A`` is the number of anchors per location and ``K = num_classes``.
        """
        all_cls_logits = []

        for features in x:
            cls_logits = self.conv(features)
            cls_logits = self.cls_logits(cls_logits)

            # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
            N, _, H, W = cls_logits.shape
            cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
            cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
            cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, K)

            all_cls_logits.append(cls_logits)

        return torch.cat(all_cls_logits, dim=1)
class RetinaNetRegressionHead(nn.Module):
    """
    Box-regression branch of the RetinaNet head.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    """

    # Bumped to 2 when the conv tower key layout changed; see ``_load_from_state_dict``.
    _version = 2

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
    }

    def __init__(self, in_channels, num_anchors, norm_layer: Optional[Callable[..., nn.Module]] = None):
        super().__init__()

        # Tower of four conv blocks that keep the channel count unchanged.
        tower = [
            misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer) for _ in range(4)
        ]
        self.conv = nn.Sequential(*tower)

        # Final conv emits four regression values per anchor at every spatial location.
        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
        torch.nn.init.normal_(self.bbox_reg.weight, std=0.01)
        torch.nn.init.zeros_(self.bbox_reg.bias)

        for module in self.conv.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            torch.nn.init.normal_(module.weight, std=0.01)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
        # Loss name handed to ``_box_loss``; builders may overwrite it after construction.
        self._loss_type = "l1"

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Load parameters, first renaming conv-tower keys of checkpoints saved with version < 2."""
        version = local_metadata.get("version", None)
        is_legacy = version is None or version < 2
        if is_legacy:
            _v1_to_v2_weights(state_dict, prefix)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor
        """
        Compute the box-regression loss over foreground anchors, averaged over the images.

        Only ``targets[i]["boxes"]`` and ``head_outputs["bbox_regression"]`` are read. Anchors whose
        matched index is negative do not contribute.
        """
        per_image_losses = []
        all_regression = head_outputs["bbox_regression"]

        for tgt, regression, image_anchors, matches in zip(targets, all_regression, anchors, matched_idxs):
            # positions of the anchors that were matched to a ground-truth box
            fg_idxs = torch.where(matches >= 0)[0]
            fg_count = fg_idxs.numel()

            # keep only the foreground rows of every per-anchor quantity
            gt_boxes = tgt["boxes"][matches[fg_idxs]]
            fg_regression = regression[fg_idxs, :]
            fg_anchors = image_anchors[fg_idxs, :]

            image_loss = _box_loss(
                self._loss_type,
                self.box_coder,
                fg_anchors,
                gt_boxes,
                fg_regression,
            )
            per_image_losses.append(image_loss / max(1, fg_count))

        return _sum(per_image_losses) / max(1, len(targets))

    def forward(self, x):
        # type: (List[Tensor]) -> Tensor
        """Return box regression values of shape ``(N, sum(H * W * A), 4)`` for the feature maps in ``x``."""
        outputs = []

        for feature_map in x:
            regression = self.bbox_reg(self.conv(feature_map))

            # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
            batch, _, height, width = regression.shape
            regression = regression.view(batch, -1, 4, height, width)
            regression = regression.permute(0, 3, 4, 1, 2)
            regression = regression.reshape(batch, -1, 4)  # Size=(N, HWA, 4)

            outputs.append(regression)

        return torch.cat(outputs, dim=1)
class RetinaNet(nn.Module):
    """
    Implements RetinaNet.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
            Only used to build the default head, i.e. when ``head`` is None.
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
            on. Defaults to ``[0.485, 0.456, 0.406]`` when None.
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on.
            Defaults to ``[0.229, 0.224, 0.225]`` when None.
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps. Must be an ``AnchorGenerator`` or None; None selects the default generator.
        head (nn.Module): Module run on top of the feature pyramid.
            Defaults to a module containing a classification and regression module.
        proposal_matcher (det_utils.Matcher): object used to match anchors to ground-truth boxes during
            training. Defaults to a ``Matcher`` built from ``fg_iou_thresh`` and ``bg_iou_thresh``
            with ``allow_low_quality_matches=True``.
        score_thresh (float): Score threshold used for postprocessing the detections.
        nms_thresh (float): NMS threshold used for postprocessing the detections.
        detections_per_img (int): Number of best detections to keep after NMS.
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training.
        topk_candidates (int): Number of best detections to keep before NMS.
        **kwargs: forwarded unchanged to ``GeneralizedRCNNTransform``.

    Example:

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import RetinaNet
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> weights = torchvision.models.MobileNet_V2_Weights.DEFAULT
        >>> backbone = torchvision.models.mobilenet_v2(weights=weights).features
        >>> # RetinaNet needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the network generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(
        >>>     sizes=((32, 64, 128, 256, 512),),
        >>>     aspect_ratios=((0.5, 1.0, 2.0),)
        >>> )
        >>>
        >>> # put the pieces together inside a RetinaNet model
        >>> model = RetinaNet(backbone,
        >>>                   num_classes=2,
        >>>                   anchor_generator=anchor_generator)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    # NOTE(review): presumably declared so TorchScript can type these non-Module attributes - confirm.
    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
    }

    def __init__(
        self,
        backbone,
        num_classes,
        # transform parameters
        min_size=800,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # Anchor parameters
        anchor_generator=None,
        head=None,
        proposal_matcher=None,
        score_thresh=0.05,
        nms_thresh=0.5,
        detections_per_img=300,
        fg_iou_thresh=0.5,
        bg_iou_thresh=0.4,
        topk_candidates=1000,
        **kwargs,
    ):
        super().__init__()
        _log_api_usage_once(self)

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )
        self.backbone = backbone

        if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"anchor_generator should be of type AnchorGenerator or None instead of {type(anchor_generator)}"
            )

        if anchor_generator is None:
            anchor_generator = _default_anchorgen()
        self.anchor_generator = anchor_generator

        if head is None:
            # The default head is sized from the first entry of num_anchors_per_location() only.
            head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
        self.head = head

        if proposal_matcher is None:
            proposal_matcher = det_utils.Matcher(
                fg_iou_thresh,
                bg_iou_thresh,
                allow_low_quality_matches=True,
            )
        self.proposal_matcher = proposal_matcher

        # Coder used by postprocess_detections to turn regression outputs back into boxes.
        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img
        self.topk_candidates = topk_candidates

        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """Select the eager-mode return value: ``losses`` while training, ``detections`` otherwise."""
        if self.training:
            return losses

        return detections

    def compute_loss(self, targets, head_outputs, anchors):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor]
        """
        Match the anchors of every image to its ground-truth boxes and delegate to the head's loss.

        Args:
            targets (List[Dict[str, Tensor]]): per-image targets; ``boxes`` is read here.
            head_outputs (Dict[str, Tensor]): outputs of ``self.head``.
            anchors (List[Tensor]): per-image anchors.

        Returns:
            Dict[str, Tensor]: whatever ``self.head.compute_loss`` returns.
        """
        matched_idxs = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            if targets_per_image["boxes"].numel() == 0:
                # No ground truth in this image: mark every anchor with -1 and skip the matcher.
                matched_idxs.append(
                    torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
                )
                continue

            match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
            matched_idxs.append(self.proposal_matcher(match_quality_matrix))

        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)

    def postprocess_detections(self, head_outputs, anchors, image_shapes):
        # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
        """
        Turn per-level head outputs into final detections for every image.

        For each image and each level: sigmoid scores above ``score_thresh`` are kept, reduced to the
        best candidates (bounded via ``topk_candidates``), decoded against their anchors and clipped to
        the image. The levels are then concatenated, class-wise NMS is applied with ``nms_thresh``
        and at most ``detections_per_img`` results are kept.

        Args:
            head_outputs (Dict[str, List[Tensor]]): ``cls_logits`` and ``bbox_regression``, each a list
                with one batched tensor per level.
            anchors (List[List[Tensor]]): for every image, its anchors split per level.
            image_shapes (List[Tuple[int, int]]): per-image size used to clip the boxes.

        Returns:
            List[Dict[str, Tensor]]: one dict per image with keys ``boxes``, ``scores`` and ``labels``.
        """
        class_logits = head_outputs["cls_logits"]
        box_regression = head_outputs["bbox_regression"]

        num_images = len(image_shapes)

        detections: List[Dict[str, Tensor]] = []

        for index in range(num_images):
            # Slice image ``index`` out of every per-level batched tensor.
            box_regression_per_image = [br[index] for br in box_regression]
            logits_per_image = [cl[index] for cl in class_logits]
            anchors_per_image, image_shape = anchors[index], image_shapes[index]

            image_boxes = []
            image_scores = []
            image_labels = []

            for box_regression_per_level, logits_per_level, anchors_per_level in zip(
                box_regression_per_image, logits_per_image, anchors_per_image
            ):
                num_classes = logits_per_level.shape[-1]

                # remove low scoring boxes
                scores_per_level = torch.sigmoid(logits_per_level).flatten()
                keep_idxs = scores_per_level > self.score_thresh
                scores_per_level = scores_per_level[keep_idxs]
                topk_idxs = torch.where(keep_idxs)[0]

                # keep only topk scoring predictions
                num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
                scores_per_level, idxs = scores_per_level.topk(num_topk)
                topk_idxs = topk_idxs[idxs]

                # The scores were flattened over (anchor, class), so a flat index splits into
                # anchor = index // num_classes and label = index % num_classes.
                anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
                labels_per_level = topk_idxs % num_classes

                boxes_per_level = self.box_coder.decode_single(
                    box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
                )
                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)

                image_boxes.append(boxes_per_level)
                image_scores.append(scores_per_level)
                image_labels.append(labels_per_level)

            image_boxes = torch.cat(image_boxes, dim=0)
            image_scores = torch.cat(image_scores, dim=0)
            image_labels = torch.cat(image_labels, dim=0)

            # non-maximum suppression
            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
            keep = keep[: self.detections_per_img]

            detections.append(
                {
                    "boxes": image_boxes[keep],
                    "scores": image_scores[keep],
                    "labels": image_labels[keep],
                }
            )

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (dict[Tensor] or list[dict[Tensor]]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns a list[dict[Tensor]], one entry per image, with the
                fields ``boxes``, ``scores`` and ``labels``.
                When running under TorchScript, a ``(losses, detections)`` tuple is always returned.
        """
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                for target in targets:
                    boxes = target["boxes"]
                    torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
                    torch._assert(
                        len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                        "Expected target boxes to be a tensor of shape [N, 4].",
                    )

        # get the original image sizes
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        # transform the input
        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        # TODO: Move this to a function
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                # A box is degenerate when x2 <= x1 or y2 <= y1.
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    torch._assert(
                        False,
                        "All bounding boxes should have positive height and width."
                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
                    )

        # get the features from the backbone
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            # Wrap a single feature map so both backbone return types are handled alike below.
            features = OrderedDict([("0", features)])

        # TODO: Do we want a list or a dict?
        features = list(features.values())

        # compute the retinanet heads outputs using the features
        head_outputs = self.head(features)

        # create the set of anchors
        anchors = self.anchor_generator(images, features)

        losses = {}
        detections: List[Dict[str, Tensor]] = []
        if self.training:
            # NOTE(review): this repeats the None check from the top of forward, presumably so that
            # TorchScript can refine the Optional type of ``targets`` in this branch - confirm.
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                # compute the losses
                losses = self.compute_loss(targets, head_outputs, anchors)
        else:
            # recover level sizes
            num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
            HW = 0
            for v in num_anchors_per_level:
                HW += v
            HWA = head_outputs["cls_logits"].size(1)
            # Anchors per location, recovered from the total anchor count and the total H * W.
            A = HWA // HW
            num_anchors_per_level = [hw * A for hw in num_anchors_per_level]

            # split outputs per level
            split_head_outputs: Dict[str, List[Tensor]] = {}
            for k in head_outputs:
                split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
            split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]

            # compute the detections
            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        return self.eager_outputs(losses, detections)
# Metadata entries shared by every pretrained-weights definition in this file.
# NOTE(review): "min_size" is presumably the smallest supported (height, width) - confirm.
_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}
# Pretrained weights available for ``retinanet_resnet50_fpn``.
class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
    # Checkpoint with a recorded box mAP of 36.4 on COCO-val2017.
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 34014999,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 36.4,
                }
            },
            # NOTE(review): units are not stated here; presumably GFLOPs and MB - confirm.
            "_ops": 151.54,
            "_file_size": 130.267,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    # Alias used when callers ask for the default weights.
    DEFAULT = COCO_V1
# Pretrained weights available for ``retinanet_resnet50_fpn_v2``.
class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum):
    # Checkpoint with a recorded box mAP of 41.5 on COCO-val2017.
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 38198935,
            "recipe": "https://github.com/pytorch/vision/pull/5756",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 41.5,
                }
            },
            # NOTE(review): units are not stated here; presumably GFLOPs and MB - confirm.
            "_ops": 152.238,
            "_file_size": 146.037,
            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
        },
    )
    # Alias used when callers ask for the default weights.
    DEFAULT = COCO_V1
@register_model()
@handle_legacy_interface(
    weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def retinanet_resnet50_fpn(
    *,
    weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> RetinaNet:
    """
    Constructs a RetinaNet model with a ResNet-50-FPN backbone.

    .. betastatus:: detection module

    Reference: `Focal Loss for Dense Object Detection <https://arxiv.org/abs/1708.02002>`_.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detections:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the predicted labels for each detection
        - scores (``Tensor[N]``): the scores of each detection

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    Example::

        >>> model = torchvision.models.detection.retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
            the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
            passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights
        :members:
    """
    weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is None:
        if num_classes is None:
            num_classes = 91
    else:
        # Full-model weights are loaded below, so separate backbone weights are not fetched.
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))

    is_trained = not (weights is None and weights_backbone is None)
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    # Frozen batch-norm whenever any pretrained weights are involved.
    if is_trained:
        norm_layer = misc_nn_ops.FrozenBatchNorm2d
    else:
        norm_layer = nn.BatchNorm2d

    resnet = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    # skip P2 because it generates too many anchors (according to their paper)
    extra_blocks = LastLevelP6P7(256, 256)
    backbone = _resnet_fpn_extractor(
        resnet, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=extra_blocks
    )
    model = RetinaNet(backbone, num_classes, **kwargs)

    if weights is None:
        return model

    model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
    if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1:
        overwrite_eps(model, 0.0)
    return model
@register_model()
@handle_legacy_interface(
    weights=("pretrained", RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def retinanet_resnet50_fpn_v2(
    *,
    weights: Optional[RetinaNet_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> RetinaNet:
    """
    Constructs an improved RetinaNet model with a ResNet-50-FPN backbone.

    .. betastatus:: detection module

    Reference: `Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection
    <https://arxiv.org/abs/1912.02424>`_.

    See :func:`~torchvision.models.detection.retinanet_resnet50_fpn` for more details.

    Args:
        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
            the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
            passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights
        :members:
    """
    weights = RetinaNet_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is None:
        if num_classes is None:
            num_classes = 91
    else:
        # Full-model weights are loaded below, so separate backbone weights are not fetched.
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))

    is_trained = not (weights is None and weights_backbone is None)
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    resnet = resnet50(weights=weights_backbone, progress=progress)
    extra_blocks = LastLevelP6P7(2048, 256)
    backbone = _resnet_fpn_extractor(
        resnet, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=extra_blocks
    )

    anchor_generator = _default_anchorgen()
    anchors_per_location = anchor_generator.num_anchors_per_location()[0]
    head = RetinaNetHead(
        backbone.out_channels,
        anchors_per_location,
        num_classes,
        norm_layer=partial(nn.GroupNorm, 32),
    )
    # The v2 variant switches the regression branch to the "giou" loss.
    head.regression_head._loss_type = "giou"
    model = RetinaNet(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs)

    if weights is None:
        return model

    model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
    return model
|