test_models_detection_negative_samples.py
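"""Regression tests for negative samples in torchvision's detection models.

A negative sample is an image whose target contains no ground-truth boxes.
These tests check that target assignment labels everything as background and
that the losses tied to box regression (and masks/keypoints where applicable)
come out exactly zero for such images.
"""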

import pytest
import torch
import torchvision.models
from common_utils import assert_equal
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead
from torchvision.ops import MultiScaleRoIAlign


class TestModelsDetectionNegativeSamples:
    def _make_empty_sample(self, add_masks=False, add_keypoints=False):
        images = [torch.rand((3, 100, 100), dtype=torch.float32)]
        boxes = torch.zeros((0, 4), dtype=torch.float32)
        negative_target = {
            "boxes": boxes,
            "labels": torch.zeros(0, dtype=torch.int64),
            "image_id": 4,
            "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            "iscrowd": torch.zeros((0,), dtype=torch.int64),
        }

        if add_masks:
            negative_target["masks"] = torch.zeros(0, 100, 100, dtype=torch.uint8)

        if add_keypoints:
            negative_target["keypoints"] = torch.zeros(17, 0, 3, dtype=torch.float32)

        targets = [negative_target]
        return images, targets
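
    # A negative target has an empty (0, 4) "boxes" tensor and matching empty
    # "labels", "area" and "iscrowd" entries (plus empty "masks"/"keypoints"
    # when requested). The unit tests below feed such a target directly to the
    # RPN and RoI-head target-assignment routines.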

    def test_targets_to_anchors(self):
        _, targets = self._make_empty_sample()
        anchors = [torch.randint(-50, 50, (3, 4), dtype=torch.float32)]

        anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
        aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
        rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
        rpn_head = RPNHead(4, rpn_anchor_generator.num_anchors_per_location()[0])

        # Positional args: fg/bg IoU thresholds, batch size per image, positive
        # fraction, pre-/post-NMS top-n, NMS threshold, score threshold.
        head = RegionProposalNetwork(rpn_anchor_generator, rpn_head, 0.5, 0.3, 256, 0.5, 2000, 2000, 0.7, 0.05)

        labels, matched_gt_boxes = head.assign_targets_to_anchors(anchors, targets)

        # With no ground truth, every anchor is background: all-zero labels and
        # all-zero matched boxes, one entry per anchor.
        assert labels[0].sum() == 0
        assert labels[0].shape == torch.Size([anchors[0].shape[0]])
        assert labels[0].dtype == torch.float32

        assert matched_gt_boxes[0].sum() == 0
        assert matched_gt_boxes[0].shape == anchors[0].shape
        assert matched_gt_boxes[0].dtype == torch.float32

    def test_assign_targets_to_proposals(self):
        proposals = [torch.randint(-50, 50, (20, 4), dtype=torch.float32)]
        gt_boxes = [torch.zeros((0, 4), dtype=torch.float32)]
        gt_labels = [torch.tensor([[0]], dtype=torch.int64)]

        box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)

        resolution = box_roi_pool.output_size[0]
        representation_size = 1024
        box_head = TwoMLPHead(4 * resolution**2, representation_size)

        representation_size = 1024
        box_predictor = FastRCNNPredictor(representation_size, 2)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            0.5,  # fg_iou_thresh
            0.5,  # bg_iou_thresh
            512,  # batch_size_per_image
            0.25,  # positive_fraction
            None,  # bbox_reg_weights
            0.05,  # score_thresh
            0.5,  # nms_thresh
            100,  # detections_per_img
        )

        matched_idxs, labels = roi_heads.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)

        # Every proposal should be assigned to the background: index 0, label 0.
        assert matched_idxs[0].sum() == 0
        assert matched_idxs[0].shape == torch.Size([proposals[0].shape[0]])
        assert matched_idxs[0].dtype == torch.int64

        assert labels[0].sum() == 0
        assert labels[0].shape == torch.Size([proposals[0].shape[0]])
        assert labels[0].dtype == torch.int64
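
    # End-to-end checks: run each detection model's forward pass in training
    # mode (the default for a freshly constructed nn.Module) on a single
    # negative sample; the model then returns a dict of losses, and the
    # box-regression entries should be exactly zero. Two-stage R-CNN models
    # report "loss_box_reg" and "loss_rpn_box_reg" (plus "loss_mask" or
    # "loss_keypoint" for the Mask/Keypoint variants).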

    @pytest.mark.parametrize(
        "name",
        [
            "fasterrcnn_resnet50_fpn",
            "fasterrcnn_mobilenet_v3_large_fpn",
            "fasterrcnn_mobilenet_v3_large_320_fpn",
        ],
    )
    def test_forward_negative_sample_frcnn(self, name):
        model = torchvision.models.get_model(
            name, weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100
        )

        images, targets = self._make_empty_sample()
        loss_dict = model(images, targets)

        assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.0))
        assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.0))

    def test_forward_negative_sample_mrcnn(self):
        model = torchvision.models.detection.maskrcnn_resnet50_fpn(
            weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100
        )

        images, targets = self._make_empty_sample(add_masks=True)
        loss_dict = model(images, targets)

        assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.0))
        assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.0))
        assert_equal(loss_dict["loss_mask"], torch.tensor(0.0))

    def test_forward_negative_sample_krcnn(self):
        model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
            weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100
        )

        images, targets = self._make_empty_sample(add_keypoints=True)
        loss_dict = model(images, targets)

        assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.0))
        assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.0))
        assert_equal(loss_dict["loss_keypoint"], torch.tensor(0.0))
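
    # Single-stage detectors use different loss keys: RetinaNet, FCOS and SSD
    # report "bbox_regression", and FCOS additionally reports "bbox_ctrness".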

    def test_forward_negative_sample_retinanet(self):
        model = torchvision.models.detection.retinanet_resnet50_fpn(
            weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100
        )

        images, targets = self._make_empty_sample()
        loss_dict = model(images, targets)

        assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0))

    def test_forward_negative_sample_fcos(self):
        model = torchvision.models.detection.fcos_resnet50_fpn(
            weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100
        )

        images, targets = self._make_empty_sample()
        loss_dict = model(images, targets)

        assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0))
        assert_equal(loss_dict["bbox_ctrness"], torch.tensor(0.0))

    def test_forward_negative_sample_ssd(self):
        model = torchvision.models.detection.ssd300_vgg16(weights=None, weights_backbone=None, num_classes=2)

        images, targets = self._make_empty_sample()
        loss_dict = model(images, targets)

        assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0))


if __name__ == "__main__":
    pytest.main([__file__])