# test_models_detection_utils.py (4.1 KB)
  1. import copy
  2. import pytest
  3. import torch
  4. from common_utils import assert_equal
  5. from torchvision.models.detection import _utils, backbone_utils
  6. from torchvision.models.detection.transform import GeneralizedRCNNTransform
  7. class TestModelsDetectionUtils:
  8. def test_balanced_positive_negative_sampler(self):
  9. sampler = _utils.BalancedPositiveNegativeSampler(4, 0.25)
  10. # keep all 6 negatives first, then add 3 positives, last two are ignore
  11. matched_idxs = [torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, -1, -1])]
  12. pos, neg = sampler(matched_idxs)
  13. # we know the number of elements that should be sampled for the positive (1)
  14. # and the negative (3), and their location. Let's make sure that they are
  15. # there
  16. assert pos[0].sum() == 1
  17. assert pos[0][6:9].sum() == 1
  18. assert neg[0].sum() == 3
  19. assert neg[0][0:6].sum() == 3
  20. def test_box_linear_coder(self):
  21. box_coder = _utils.BoxLinearCoder(normalize_by_size=True)
  22. # Generate a random 10x4 boxes tensor, with coordinates < 50.
  23. boxes = torch.rand(10, 4) * 50
  24. boxes.clamp_(min=1.0) # tiny boxes cause numerical instability in box regression
  25. boxes[:, 2:] += boxes[:, :2]
  26. proposals = torch.tensor([0, 0, 101, 101] * 10).reshape(10, 4).float()
  27. rel_codes = box_coder.encode(boxes, proposals)
  28. pred_boxes = box_coder.decode(rel_codes, boxes)
  29. torch.allclose(proposals, pred_boxes)
  30. @pytest.mark.parametrize("train_layers, exp_froz_params", [(0, 53), (1, 43), (2, 24), (3, 11), (4, 1), (5, 0)])
  31. def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params):
  32. # we know how many initial layers and parameters of the network should
  33. # be frozen for each trainable_backbone_layers parameter value
  34. # i.e. all 53 params are frozen if trainable_backbone_layers=0
  35. # ad first 24 params are frozen if trainable_backbone_layers=2
  36. model = backbone_utils.resnet_fpn_backbone("resnet50", weights=None, trainable_layers=train_layers)
  37. # boolean list that is true if the param at that index is frozen
  38. is_frozen = [not parameter.requires_grad for _, parameter in model.named_parameters()]
  39. # check that expected initial number of layers are frozen
  40. assert all(is_frozen[:exp_froz_params])
  41. def test_validate_resnet_inputs_detection(self):
  42. # default number of backbone layers to train
  43. ret = backbone_utils._validate_trainable_layers(
  44. is_trained=True, trainable_backbone_layers=None, max_value=5, default_value=3
  45. )
  46. assert ret == 3
  47. # can't go beyond 5
  48. with pytest.raises(ValueError, match=r"Trainable backbone layers should be in the range"):
  49. ret = backbone_utils._validate_trainable_layers(
  50. is_trained=True, trainable_backbone_layers=6, max_value=5, default_value=3
  51. )
  52. # if not trained, should use all trainable layers and warn
  53. with pytest.warns(UserWarning):
  54. ret = backbone_utils._validate_trainable_layers(
  55. is_trained=False, trainable_backbone_layers=0, max_value=5, default_value=3
  56. )
  57. assert ret == 5
  58. def test_transform_copy_targets(self):
  59. transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3))
  60. image = [torch.rand(3, 200, 300), torch.rand(3, 200, 200)]
  61. targets = [{"boxes": torch.rand(3, 4)}, {"boxes": torch.rand(2, 4)}]
  62. targets_copy = copy.deepcopy(targets)
  63. out = transform(image, targets) # noqa: F841
  64. assert_equal(targets[0]["boxes"], targets_copy[0]["boxes"])
  65. assert_equal(targets[1]["boxes"], targets_copy[1]["boxes"])
  66. def test_not_float_normalize(self):
  67. transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3))
  68. image = [torch.randint(0, 255, (3, 200, 300), dtype=torch.uint8)]
  69. targets = [{"boxes": torch.rand(3, 4)}]
  70. with pytest.raises(TypeError):
  71. out = transform(image, targets) # noqa: F841
  72. if __name__ == "__main__":
  73. pytest.main([__file__])