test_transforms_v2_utils.py 4.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import PIL.Image
  2. import pytest
  3. import torch
  4. import torchvision.transforms.v2._utils
  5. from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_mask, make_image
  6. from torchvision import tv_tensors
  7. from torchvision.transforms.v2._utils import has_all, has_any
  8. from torchvision.transforms.v2.functional import to_pil_image
  # Sample inputs shared by the parametrized test cases in this module. They are
  # built once, at import time, from the common_utils factories, in the order
  # image -> bounding boxes -> mask.
  IMAGE, BOUNDING_BOX, MASK = (
      make_image(DEFAULT_SIZE, color_space="RGB"),
      make_bounding_boxes(DEFAULT_SIZE, format=tv_tensors.BoundingBoxFormat.XYXY),
      make_detection_mask(DEFAULT_SIZE),
  )
  # Sample containing one of each input kind; most has_any cases probe it.
  _ANY_SAMPLE = (IMAGE, BOUNDING_BOX, MASK)
  # Matchers for three image representations: the tv_tensors wrapper class, the
  # PIL image class, and the is_pure_tensor check from transforms.v2._utils.
  _ANY_IMAGE_MATCHERS = (
      tv_tensors.Image,
      PIL.Image.Image,
      torchvision.transforms.v2._utils.is_pure_tensor,
  )


  @pytest.mark.parametrize(
      "sample, types, expected",
      [
          # A single class that is present in the sample.
          (_ANY_SAMPLE, (tv_tensors.Image,), True),
          (_ANY_SAMPLE, (tv_tensors.BoundingBoxes,), True),
          (_ANY_SAMPLE, (tv_tensors.Mask,), True),
          # Two classes, both present.
          (_ANY_SAMPLE, (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
          (_ANY_SAMPLE, (tv_tensors.Image, tv_tensors.Mask), True),
          (_ANY_SAMPLE, (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
          # Two classes, neither present.
          ((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
          ((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask), False),
          ((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
          # All three classes, against the full sample and the empty sample.
          (_ANY_SAMPLE, (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
          ((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
          # Predicate callables in place of classes.
          (_ANY_SAMPLE, (lambda obj: isinstance(obj, tv_tensors.Image),), True),
          (_ANY_SAMPLE, (lambda _: False,), False),
          (_ANY_SAMPLE, (lambda _: True,), True),
          # One image per representation (tv_tensors.Image, a plain torch.Tensor
          # wrapper, a PIL image), each checked against all three matchers.
          ((IMAGE,), _ANY_IMAGE_MATCHERS, True),
          ((torch.Tensor(IMAGE),), _ANY_IMAGE_MATCHERS, True),
          ((to_pil_image(IMAGE),), _ANY_IMAGE_MATCHERS, True),
      ],
  )
  def test_has_any(sample, types, expected):
      """Check ``has_any`` for one parametrized case.

      ``types`` holds classes and/or predicate callables. The cases expect True
      when at least one of them matches some element of ``sample`` and False
      otherwise, including for an empty ``sample``.
      """
      outcome = has_any(sample, *types)
      assert outcome is expected
  @pytest.mark.parametrize(
      ("sample", "types", "expected"),
      [
          # Every requested class is present (extra sample elements are allowed).
          ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
          ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
          ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
          ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
          ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
          ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
          (
              (IMAGE, BOUNDING_BOX, MASK),
              (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
              True,
          ),
          # One of two requested classes is missing from the sample.
          ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
          ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), False),
          ((IMAGE, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
          # One of three requested classes is missing from the sample.
          # (A verbatim duplicate of the all-present three-class case above was removed.)
          ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
          ((IMAGE, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
          ((IMAGE, BOUNDING_BOX), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
          # Empty sample: no requested class can be matched (mirrors test_has_any).
          ((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
          # Predicate callables in place of classes.
          (
              (IMAGE, BOUNDING_BOX, MASK),
              (lambda obj: isinstance(obj, (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask)),),
              True,
          ),
          # A predicate needs to match only one element, not every element.
          ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, tv_tensors.Image),), True),
          ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
          ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
      ],
  )
  def test_has_all(sample, types, expected):
      """Check ``has_all`` for one parametrized case.

      ``types`` holds classes and/or predicate callables. The cases expect True
      only when every entry matches at least one element of ``sample``. The
      sample may contain extra, unmatched elements, and an empty sample matches
      nothing.
      """
      assert has_all(sample, *types) is expected