generalized_rcnn.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. """
  2. Implements the Generalized R-CNN framework
  3. """
  4. import warnings
  5. from collections import OrderedDict
  6. from typing import Dict, List, Optional, Tuple, Union
  7. import torch
  8. from torch import nn, Tensor
  9. from ...utils import _log_api_usage_once
  def _validate_training_targets(targets: Optional[List[Dict[str, Tensor]]]) -> None:
      """
      Check the ground-truth targets passed to ``GeneralizedRCNN.forward`` in training mode.

      Args:
          targets (Optional[list[Dict[str, Tensor]]]): per-image target dicts; each one must
              contain a ``"boxes"`` entry.

      Raises:
          AssertionError: (via ``torch._assert`` in eager mode) if ``targets`` is None, if a
              ``"boxes"`` entry is not a Tensor, or if it is not 2-D with a last dimension of 4.
          KeyError: (eager mode) if a target dict has no ``"boxes"`` key.
      """
      if targets is None:
          torch._assert(False, "targets should not be none when in training mode")
      else:
          for target in targets:
              boxes = target["boxes"]
              if isinstance(boxes, torch.Tensor):
                  torch._assert(
                      len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                      f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                  )
              else:
                  torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")


  def _check_degenerate_boxes(targets: List[Dict[str, Tensor]]) -> None:
      """
      Reject target boxes with non-positive height or width.

      A box counts as degenerate when either of its last two coordinates is not strictly
      greater than the corresponding one of its first two coordinates (this assumes an
      ``(x1, y1, x2, y2)`` layout -- TODO confirm against the transform in use).

      Args:
          targets (list[Dict[str, Tensor]]): per-image target dicts holding a ``"boxes"``
              Tensor of shape [N, 4].

      Raises:
          AssertionError: (via ``torch._assert`` in eager mode) naming the first degenerate
              box of the first target that contains one, together with that target's index.
      """
      for target_idx, target in enumerate(targets):
          boxes = target["boxes"]
          degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
          if degenerate_boxes.any():
              # Report only the first offending box to keep the message readable.
              bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
              # The explicit annotation is required for ``tolist()`` under TorchScript.
              degen_bb: List[float] = boxes[bb_idx].tolist()
              torch._assert(
                  False,
                  "All bounding boxes should have positive height and width."
                  f" Found invalid box {degen_bb} for target at index {target_idx}.",
              )


  class GeneralizedRCNN(nn.Module):
      """
      Main class for Generalized R-CNN.

      Args:
          backbone (nn.Module): applied to the batched image tensor produced by ``transform``.
              If it returns a single Tensor, that Tensor is wrapped as
              ``OrderedDict([("0", tensor)])``; any other output is handed to ``rpn`` and
              ``roi_heads`` unchanged (presumably a dict of named feature maps).
          rpn (nn.Module): called as ``rpn(images, features, targets)``; returns the region
              proposals and a dict of RPN losses.
          roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
              detections / masks from it.
          transform (nn.Module): performs the data transformation from the inputs to feed into
              the model. Its ``postprocess`` method is also called with the transformed and the
              original image sizes, presumably to map detections back to the input resolution.
      """

      def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
          super().__init__()
          _log_api_usage_once(self)
          self.transform = transform
          self.backbone = backbone
          self.rpn = rpn
          self.roi_heads = roi_heads
          # used only on torchscript mode: makes the "always returns a tuple" warning fire once
          self._has_warned = False

      @torch.jit.unused
      def eager_outputs(self, losses, detections):
          # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
          """Return ``losses`` in training mode and ``detections`` otherwise (eager mode only)."""
          if self.training:
              return losses
          return detections

      def forward(self, images, targets=None):
          # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
          """
          Run the detector on a batch of images.

          Args:
              images (list[Tensor]): images to be processed; the last two dimensions of each
                  one are taken as (H, W).
              targets (list[Dict[str, Tensor]]): ground-truth per image (optional). Required in
                  training mode, where each dict must hold a ``"boxes"`` Tensor of shape [N, 4].

          Returns:
              result (list[Dict[str, Tensor]] or dict[str, Tensor]): the output from the model.
                  In eager mode, training returns a dict[str, Tensor] of losses, and evaluation
                  returns a list of dicts (presumably one per image) with fields like `scores`,
                  `labels` and `mask` (for Mask R-CNN models), as produced by ``roi_heads``.
                  When scripted, a ``(losses, detections)`` tuple is always returned and a
                  warning is emitted the first time.

          Raises:
              AssertionError: (via ``torch._assert`` in eager mode) on missing or malformed
                  training targets, on images with fewer than two dimensions, or on
                  degenerate target boxes.
          """
          if self.training:
              _validate_training_targets(targets)

          # Record each input's (H, W) before ``self.transform`` (which may change image sizes),
          # so ``postprocess`` can relate detections to the original inputs.
          original_image_sizes: List[Tuple[int, int]] = []
          for img in images:
              val = img.shape[-2:]
              torch._assert(
                  len(val) == 2,
                  f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
              )
              original_image_sizes.append((val[0], val[1]))

          images, targets = self.transform(images, targets)

          # Validate boxes after the transform, since that is what the heads will consume.
          if targets is not None:
              _check_degenerate_boxes(targets)

          features = self.backbone(images.tensors)
          if isinstance(features, torch.Tensor):
              # Wrap single-tensor outputs under key "0" so downstream modules always get a dict.
              features = OrderedDict([("0", features)])
          proposals, proposal_losses = self.rpn(images, features, targets)
          detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
          detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]

          # Detector losses are merged first and RPN losses second, so on a key clash the RPN
          # value wins.
          losses: Dict[str, Tensor] = {}
          losses.update(detector_losses)
          losses.update(proposal_losses)

          if torch.jit.is_scripting():
              # TorchScript needs a single static return type, hence the fixed tuple.
              if not self._has_warned:
                  warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                  self._has_warned = True
              return losses, detections
          else:
              return self.eager_outputs(losses, detections)