test_models_detection_anchor_utils.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import pytest
  2. import torch
  3. from common_utils import assert_equal
  4. from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator
  5. from torchvision.models.detection.image_list import ImageList
  6. class Tester:
  7. def test_incorrect_anchors(self):
  8. incorrect_sizes = (
  9. (2, 4, 8),
  10. (32, 8),
  11. )
  12. incorrect_aspects = (0.5, 1.0)
  13. anc = AnchorGenerator(incorrect_sizes, incorrect_aspects)
  14. image1 = torch.randn(3, 800, 800)
  15. image_list = ImageList(image1, [(800, 800)])
  16. feature_maps = [torch.randn(1, 50)]
  17. pytest.raises(AssertionError, anc, image_list, feature_maps)
  18. def _init_test_anchor_generator(self):
  19. anchor_sizes = ((10,),)
  20. aspect_ratios = ((1,),)
  21. anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
  22. return anchor_generator
  23. def _init_test_defaultbox_generator(self):
  24. aspect_ratios = [[2]]
  25. dbox_generator = DefaultBoxGenerator(aspect_ratios)
  26. return dbox_generator
  27. def get_features(self, images):
  28. s0, s1 = images.shape[-2:]
  29. features = [torch.rand(2, 8, s0 // 5, s1 // 5)]
  30. return features
  31. def test_anchor_generator(self):
  32. images = torch.randn(2, 3, 15, 15)
  33. features = self.get_features(images)
  34. image_shapes = [i.shape[-2:] for i in images]
  35. images = ImageList(images, image_shapes)
  36. model = self._init_test_anchor_generator()
  37. model.eval()
  38. anchors = model(images, features)
  39. # Estimate the number of target anchors
  40. grid_sizes = [f.shape[-2:] for f in features]
  41. num_anchors_estimated = 0
  42. for sizes, num_anchors_per_loc in zip(grid_sizes, model.num_anchors_per_location()):
  43. num_anchors_estimated += sizes[0] * sizes[1] * num_anchors_per_loc
  44. anchors_output = torch.tensor(
  45. [
  46. [-5.0, -5.0, 5.0, 5.0],
  47. [0.0, -5.0, 10.0, 5.0],
  48. [5.0, -5.0, 15.0, 5.0],
  49. [-5.0, 0.0, 5.0, 10.0],
  50. [0.0, 0.0, 10.0, 10.0],
  51. [5.0, 0.0, 15.0, 10.0],
  52. [-5.0, 5.0, 5.0, 15.0],
  53. [0.0, 5.0, 10.0, 15.0],
  54. [5.0, 5.0, 15.0, 15.0],
  55. ]
  56. )
  57. assert num_anchors_estimated == 9
  58. assert len(anchors) == 2
  59. assert tuple(anchors[0].shape) == (9, 4)
  60. assert tuple(anchors[1].shape) == (9, 4)
  61. assert_equal(anchors[0], anchors_output)
  62. assert_equal(anchors[1], anchors_output)
  63. def test_defaultbox_generator(self):
  64. images = torch.zeros(2, 3, 15, 15)
  65. features = [torch.zeros(2, 8, 1, 1)]
  66. image_shapes = [i.shape[-2:] for i in images]
  67. images = ImageList(images, image_shapes)
  68. model = self._init_test_defaultbox_generator()
  69. model.eval()
  70. dboxes = model(images, features)
  71. dboxes_output = torch.tensor(
  72. [
  73. [6.3750, 6.3750, 8.6250, 8.6250],
  74. [4.7443, 4.7443, 10.2557, 10.2557],
  75. [5.9090, 6.7045, 9.0910, 8.2955],
  76. [6.7045, 5.9090, 8.2955, 9.0910],
  77. ]
  78. )
  79. assert len(dboxes) == 2
  80. assert tuple(dboxes[0].shape) == (4, 4)
  81. assert tuple(dboxes[1].shape) == (4, 4)
  82. torch.testing.assert_close(dboxes[0], dboxes_output, rtol=1e-5, atol=1e-8)
  83. torch.testing.assert_close(dboxes[1], dboxes_output, rtol=1e-5, atol=1e-8)