test_tv_tensors.py

from copy import deepcopy

import pytest
import torch
from common_utils import assert_equal, make_bounding_boxes, make_image, make_segmentation_mask, make_video
from PIL import Image
from torchvision import tv_tensors


@pytest.fixture(autouse=True)
def restore_tensor_return_type():
    # This is a safety net; each test should already be restoring the default return type manually
    # (at least at the time of writing...)
    yield
    tv_tensors.set_return_type("Tensor")
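
# tv_tensors.set_return_type() is used in two ways throughout this file: as a plain call that
# changes the return type globally, and as a context manager that restores the previous value on
# exit (see test_set_return_type below). The autouse fixture above resets the global default back
# to "Tensor" after every test in case a test forgets to clean up after itself.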


@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
def test_image_instance(data):
    image = tv_tensors.Image(data)
    assert isinstance(image, torch.Tensor)
    assert image.ndim == 3 and image.shape[0] == 3


@pytest.mark.parametrize("data", [torch.randint(0, 10, size=(1, 32, 32)), Image.new("L", (32, 32), color=2)])
def test_mask_instance(data):
    mask = tv_tensors.Mask(data)
    assert isinstance(mask, torch.Tensor)
    assert mask.ndim == 3 and mask.shape[0] == 1


@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]])
@pytest.mark.parametrize(
    "format", ["XYXY", "CXCYWH", tv_tensors.BoundingBoxFormat.XYXY, tv_tensors.BoundingBoxFormat.XYWH]
)
def test_bbox_instance(data, format):
    bboxes = tv_tensors.BoundingBoxes(data, format=format, canvas_size=(32, 32))
    assert isinstance(bboxes, torch.Tensor)
    assert bboxes.ndim == 2 and bboxes.shape[1] == 4
    if isinstance(format, str):
        format = tv_tensors.BoundingBoxFormat[format.upper()]
    assert bboxes.format == format


def test_bbox_dim_error():
    data_3d = [[[1, 2, 3, 4]]]
    with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):
        tv_tensors.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32))
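
# Note: a single box given as a flat 1D sequence is still accepted: test_bbox_instance above
# passes [1, 2, 3, 4] and checks that the result comes out 2D with four columns. Only input that
# is neither 1D nor 2D, like the 3D data here, is rejected with a ValueError.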


@pytest.mark.parametrize(
    ("data", "input_requires_grad", "expected_requires_grad"),
    [
        # requires_grad=None inherits the flag from the input (False for plain Python lists),
        # while an explicit True/False always overrides it.
        ([[[0.0, 1.0], [0.0, 1.0]]], None, False),
        ([[[0.0, 1.0], [0.0, 1.0]]], False, False),
        ([[[0.0, 1.0], [0.0, 1.0]]], True, True),
        (torch.rand(3, 16, 16, requires_grad=False), None, False),
        (torch.rand(3, 16, 16, requires_grad=False), False, False),
        (torch.rand(3, 16, 16, requires_grad=False), True, True),
        (torch.rand(3, 16, 16, requires_grad=True), None, True),
        (torch.rand(3, 16, 16, requires_grad=True), False, False),
        (torch.rand(3, 16, 16, requires_grad=True), True, True),
    ],
)
def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
    tv_tensor = tv_tensors.Image(data, requires_grad=input_requires_grad)
    assert tv_tensor.requires_grad is expected_requires_grad


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
def test_isinstance(make_input):
    assert isinstance(make_input(), torch.Tensor)


def test_wrapping_no_copy():
    tensor = torch.rand(3, 16, 16)
    image = tv_tensors.Image(tensor)

    assert image.data_ptr() == tensor.data_ptr()
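
# Because the Image shares the input tensor's storage (identical data_ptr), constructing it is
# cheap and in-place changes to `tensor` remain visible through `image`.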


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
def test_to_wrapping(make_input):
    dp = make_input()

    dp_to = dp.to(torch.float64)

    assert type(dp_to) is type(dp)
    assert dp_to.dtype is torch.float64


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_to_tv_tensor_reference(make_input, return_type):
    tensor = torch.rand((3, 16, 16), dtype=torch.float64)
    dp = make_input()

    with tv_tensors.set_return_type(return_type):
        tensor_to = tensor.to(dp)

    assert type(tensor_to) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
    assert tensor_to.dtype is dp.dtype
    assert type(tensor) is torch.Tensor
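
# Here `tensor.to(dp)` uses a TVTensor as the "reference" tensor: the result takes dp's dtype (and
# device), and with the "TVTensor" return type it also takes dp's TVTensor class, while the plain
# input tensor itself is left untouched.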


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_clone_wrapping(make_input, return_type):
    dp = make_input()

    with tv_tensors.set_return_type(return_type):
        dp_clone = dp.clone()

    assert type(dp_clone) is type(dp)
    assert dp_clone.data_ptr() != dp.data_ptr()


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_requires_grad__wrapping(make_input, return_type):
    dp = make_input(dtype=torch.float)

    assert not dp.requires_grad

    with tv_tensors.set_return_type(return_type):
        dp_requires_grad = dp.requires_grad_(True)

    assert type(dp_requires_grad) is type(dp)
    assert dp.requires_grad
    assert dp_requires_grad.requires_grad


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_detach_wrapping(make_input, return_type):
    dp = make_input(dtype=torch.float).requires_grad_(True)

    with tv_tensors.set_return_type(return_type):
        dp_detached = dp.detach()

    assert type(dp_detached) is type(dp)


@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_force_subclass_with_metadata(return_type):
    # Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and tv_tensors with metadata.
    # Largely the same as above; we additionally check that the metadata is preserved.
    format, canvas_size = tv_tensors.BoundingBoxFormat.XYXY, (32, 32)
    bbox = tv_tensors.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)

    tv_tensors.set_return_type(return_type)
    bbox = bbox.clone()
    if return_type == "TVTensor":
        # Compare as a tuple so that both attributes are actually asserted.
        assert (bbox.format, bbox.canvas_size) == (format, canvas_size)

    bbox = bbox.to(torch.float64)
    if return_type == "TVTensor":
        assert (bbox.format, bbox.canvas_size) == (format, canvas_size)

    bbox = bbox.detach()
    if return_type == "TVTensor":
        assert (bbox.format, bbox.canvas_size) == (format, canvas_size)

    assert not bbox.requires_grad
    bbox.requires_grad_(True)
    if return_type == "TVTensor":
        assert (bbox.format, bbox.canvas_size) == (format, canvas_size)
        assert bbox.requires_grad

    tv_tensors.set_return_type("tensor")
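
# clone, to, detach and requires_grad_ are among the ops in _FORCE_TORCHFUNCTION_SUBCLASS, so they
# keep returning the TVTensor subclass even under the default "Tensor" return type (see
# test_clone_wrapping and friends above). The test above additionally verifies that the
# BoundingBoxes metadata survives those calls.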


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_other_op_no_wrapping(make_input, return_type):
    dp = make_input()

    with tv_tensors.set_return_type(return_type):
        # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
        output = dp * 2

    assert type(output) is (type(dp) if return_type == "TVTensor" else torch.Tensor)


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize(
    "op",
    [
        lambda t: t.numpy(),  # returns a numpy array
        lambda t: t.tolist(),  # returns a plain Python list
        lambda t: t.max(dim=-1),  # returns a (values, indices) named tuple
    ],
)
def test_no_tensor_output_op_no_wrapping(make_input, op):
    dp = make_input()

    output = op(dp)

    assert type(output) is not type(dp)


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_inplace_op_no_wrapping(make_input, return_type):
    dp = make_input()
    original_type = type(dp)

    with tv_tensors.set_return_type(return_type):
        output = dp.add_(0)

    assert type(output) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
    assert type(dp) is original_type


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
def test_wrap(make_input):
    dp = make_input()

    # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
    output = dp * 2

    dp_new = tv_tensors.wrap(output, like=dp)

    assert type(dp_new) is type(dp)
    assert dp_new.data_ptr() == output.data_ptr()
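
# tv_tensors.wrap() re-attaches the TVTensor class of `like` to a plain tensor without copying it
# (same data_ptr, as asserted above). For BoundingBoxes, the format and canvas_size metadata are
# taken from the `like` reference.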


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("requires_grad", [False, True])
def test_deepcopy(make_input, requires_grad):
    dp = make_input(dtype=torch.float)
    dp.requires_grad_(requires_grad)

    dp_deepcopied = deepcopy(dp)

    assert dp_deepcopied is not dp
    assert dp_deepcopied.data_ptr() != dp.data_ptr()
    assert_equal(dp_deepcopied, dp)

    assert type(dp_deepcopied) is type(dp)
    assert dp_deepcopied.requires_grad is requires_grad


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
@pytest.mark.parametrize(
    "op",
    (
        lambda dp: dp + torch.rand(*dp.shape),
        lambda dp: torch.rand(*dp.shape) + dp,
        lambda dp: dp * torch.rand(*dp.shape),
        lambda dp: torch.rand(*dp.shape) * dp,
        lambda dp: dp + 3,
        lambda dp: 3 + dp,
        lambda dp: dp + dp,
        lambda dp: dp.sum(),
        lambda dp: dp.reshape(-1),
        lambda dp: dp.int(),
        lambda dp: torch.stack([dp, dp]),
        lambda dp: torch.chunk(dp, 2)[0],
        lambda dp: torch.unbind(dp)[0],
    ),
)
def test_usual_operations(make_input, return_type, op):
    dp = make_input()
    with tv_tensors.set_return_type(return_type):
        out = op(dp)
        assert type(out) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
        if isinstance(dp, tv_tensors.BoundingBoxes) and return_type == "TVTensor":
            assert hasattr(out, "format")
            assert hasattr(out, "canvas_size")


def test_subclasses():
    img = make_image()
    masks = make_segmentation_mask()

    with pytest.raises(TypeError, match="unsupported operand"):
        img + masks


def test_set_return_type():
    img = make_image()

    assert type(img + 3) is torch.Tensor

    with tv_tensors.set_return_type("TVTensor"):
        assert type(img + 3) is tv_tensors.Image
    assert type(img + 3) is torch.Tensor

    tv_tensors.set_return_type("TVTensor")
    assert type(img + 3) is tv_tensors.Image

    with tv_tensors.set_return_type("tensor"):
        assert type(img + 3) is torch.Tensor
        with tv_tensors.set_return_type("TVTensor"):
            assert type(img + 3) is tv_tensors.Image
            tv_tensors.set_return_type("tensor")
            assert type(img + 3) is torch.Tensor
        assert type(img + 3) is torch.Tensor
    # Exiting a context manager will restore the return type as it was prior to entering it,
    # regardless of whether the "global" tv_tensors.set_return_type() was called within the context manager.
    assert type(img + 3) is tv_tensors.Image

    tv_tensors.set_return_type("tensor")
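
# In other words, set_return_type() used as a context manager snapshots the current return type on
# entry and unconditionally restores that snapshot on exit, so a "global" set_return_type() call
# made inside the block only lasts until the enclosing context manager exits.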


def test_return_type_input():
    img = make_image()

    # Case-insensitive
    with tv_tensors.set_return_type("tvtensor"):
        assert type(img + 3) is tv_tensors.Image

    with pytest.raises(ValueError, match="return_type must be"):
        tv_tensors.set_return_type("typo")

    tv_tensors.set_return_type("tensor")