123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118 |
- """
- Implements the Generalized R-CNN framework
- """
- import warnings
- from collections import OrderedDict
- from typing import Dict, List, Optional, Tuple, Union
- import torch
- from torch import nn, Tensor
- from ...utils import _log_api_usage_once
class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN.

    Args:
        backbone (nn.Module): feature extractor. It is called on the batched image tensor
            (``images.tensors``) and may return either a single Tensor or an ordered mapping
            of feature maps; a single Tensor is wrapped as ``OrderedDict([("0", tensor)])``.
        rpn (nn.Module): region proposal network. Called as ``rpn(images, features, targets)``
            and expected to return ``(proposals, proposal_losses)``.
        roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
            detections / masks from it. Called as
            ``roi_heads(features, proposals, images.image_sizes, targets)`` and expected to
            return ``(detections, detector_losses)``.
        transform (nn.Module): performs the data transformation from the inputs to feed into
            the model. Called as ``transform(images, targets)`` and must also provide
            ``postprocess(detections, image_sizes, original_image_sizes)``, which receives
            both the transformed and the original (pre-transform) image sizes.
    """

    def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.transform = transform
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_heads
        # used only on torchscript mode: ensures the "always returns a tuple" warning is emitted once
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """Select the eager-mode output: the losses in training mode, the detections otherwise."""
        if self.training:
            return losses

        return detections

    def _check_degenerate_boxes(self, targets: List[Dict[str, Tensor]]) -> None:
        """Assert that every target box has strictly positive width and height.

        A box is considered degenerate when either of its last two coordinates is less
        than or equal to the corresponding one of its first two
        (``boxes[:, 2:] <= boxes[:, :2]``). Only the first degenerate box is reported,
        together with the index of the target that contains it.
        """
        for target_idx, target in enumerate(targets):
            boxes = target["boxes"]
            degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
            if degenerate_boxes.any():
                # print the first degenerate box
                bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                degen_bb: List[float] = boxes[bb_idx].tolist()
                torch._assert(
                    False,
                    "All bounding boxes should have positive height and width."
                    f" Found invalid box {degen_bb} for target at index {target_idx}.",
                )

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Args:
            images (list[Tensor]): images to be processed; the last two dimensions of each
                tensor are taken as its (H, W) size.
            targets (list[Dict[str, Tensor]]): ground-truth annotations, one dict per image
                (optional). Required in training mode, where each dict must contain a
                ``"boxes"`` Tensor of shape [N, 4].

        Returns:
            result (list[Dict[str, Tensor]] or dict[str, Tensor]): the output from the model.
            During training, it returns a dict[str, Tensor] which contains the losses.
            During testing, it returns a list[Dict[str, Tensor]] with one dict of
            post-processed detections per image; the fields depend on ``roi_heads``
            (e.g. ``scores``, ``labels`` and, for Mask R-CNN models, masks).
            When running under TorchScript, it always returns a ``(losses, detections)`` tuple.

        Raises:
            AssertionError: (eager mode) if ``targets`` is missing or has malformed boxes in
                training mode, if an image tensor has fewer than two dimensions, or if any
                target box has non-positive width or height after ``transform``.
        """
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                for target in targets:
                    boxes = target["boxes"]
                    if isinstance(boxes, torch.Tensor):
                        torch._assert(
                            len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                            f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                        )
                    else:
                        torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")

        # Sizes are recorded before `self.transform` runs and handed to `transform.postprocess` below.
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        images, targets = self.transform(images, targets)

        # Boxes are validated after the transform, since the transform may modify them.
        if targets is not None:
            self._check_degenerate_boxes(targets)

        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])
        proposals, proposal_losses = self.rpn(images, features, targets)
        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        else:
            return self.eager_outputs(losses, detections)
|