test_utils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. import os
  2. import re
  3. import sys
  4. import tempfile
  5. from io import BytesIO
  6. import numpy as np
  7. import pytest
  8. import torch
  9. import torchvision.transforms.functional as F
  10. import torchvision.utils as utils
  11. from common_utils import assert_equal, cpu_and_cuda
  12. from PIL import __version__ as PILLOW_VERSION, Image, ImageColor
  # Installed Pillow version as a tuple of ints, e.g. "9.1.0" -> (9, 1, 0), used for
  # comparisons such as `PILLOW_VERSION >= (8, 2)`. Only the leading numeric release
  # segment is parsed, so dev/pre-release builds like "10.1.0.dev0" no longer make
  # int() raise ValueError while this module is being imported.
  PILLOW_VERSION = tuple(int(x) for x in re.match(r"\d+(?:\.\d+)*", PILLOW_VERSION).group().split("."))

  # Shared box fixture: four boxes, the second one degenerate (zero area).
  # Presumably in (xmin, ymin, xmax, ymax) order — TODO confirm against draw_bounding_boxes docs.
  boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
  # Shared keypoint fixture of shape (2, 3, 2): two instances with three keypoints each,
  # presumably stored as (x, y) pairs.
  keypoints = torch.tensor([[[10, 10], [5, 5], [2, 2]], [[20, 20], [30, 30], [3, 3]]], dtype=torch.float)
  def test_make_grid_not_inplace():
      """make_grid must leave its input tensor untouched for every normalization mode."""
      images = torch.rand(5, 3, 10, 10)
      snapshot = images.clone()

      option_sets = (
          {"normalize": False},
          {"normalize": True, "scale_each": False},
          {"normalize": True, "scale_each": True},
      )
      for options in option_sets:
          utils.make_grid(images, **options)
          assert_equal(images, snapshot, msg="make_grid modified tensor in-place")
  def test_normalize_in_make_grid():
      """With normalize=True the grid's max and min round to 1.0 and 0.0 at one decimal."""
      images = torch.rand(5, 3, 10, 10) * 255
      grid = utils.make_grid(images, normalize=True)

      # Round to one decimal place so tiny floating-point error does not matter.
      scale = 10**1

      def round_one_decimal(value):
          return torch.round(value * scale) / scale

      assert_equal(torch.tensor(1.0), round_one_decimal(torch.max(grid)), msg="Normalized max is not equal to 1")
      assert_equal(torch.tensor(0.0), round_one_decimal(torch.min(grid)), msg="Normalized min is not equal to 0")
  @pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
  def test_save_image():
      """save_image must write a readable PNG file for a batch of images.

      NamedTemporaryFile already creates the file on disk, so checking only that the
      path exists passes even if save_image wrote nothing. The file is therefore also
      checked to be non-empty and to open as a PNG.
      """
      with tempfile.NamedTemporaryFile(suffix=".png") as f:
          t = torch.rand(2, 3, 64, 64)
          utils.save_image(t, f.name)
          assert os.path.exists(f.name), "The image is not present after save"
          assert os.path.getsize(f.name) > 0, "The image is empty after save"
          with Image.open(f.name) as img:
              assert img.format == "PNG", "The saved image is not a PNG"
  @pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
  def test_save_image_single_pixel():
      """save_image must write a readable 1x1 PNG for a single one-pixel image.

      The temporary file exists before save_image runs, so an existence check alone
      is vacuous. Open the result and check its format and size instead.
      """
      with tempfile.NamedTemporaryFile(suffix=".png") as f:
          t = torch.rand(1, 3, 1, 1)
          utils.save_image(t, f.name)
          assert os.path.exists(f.name), "The pixel image is not present after save"
          with Image.open(f.name) as img:
              assert img.format == "PNG", "The saved pixel image is not a PNG"
              # A batch of one image is expected to be saved without grid padding.
              assert img.size == (1, 1), "The saved pixel image is not 1x1"
  @pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
  def test_save_image_file_object():
      """Saving into an in-memory buffer must yield the same pixels as saving to a path."""
      with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
          images = torch.rand(2, 3, 64, 64)

          utils.save_image(images, tmp.name)
          from_disk = F.pil_to_tensor(Image.open(tmp.name))

          buffer = BytesIO()
          utils.save_image(images, buffer, format="png")
          from_buffer = F.pil_to_tensor(Image.open(buffer))

          assert_equal(from_disk, from_buffer, msg="Image not stored in file object")
  @pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
  def test_save_image_single_pixel_file_object():
      """A one-pixel image saved to a buffer must match the same image saved to a path."""
      with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
          pixel = torch.rand(1, 3, 1, 1)

          utils.save_image(pixel, tmp.name)
          from_disk = F.pil_to_tensor(Image.open(tmp.name))

          buffer = BytesIO()
          utils.save_image(pixel, buffer, format="png")
          from_buffer = F.pil_to_tensor(Image.open(buffer))

          assert_equal(from_disk, from_buffer, msg="Image not stored in file object")
  def test_draw_boxes():
      """Filled, labelled boxes with mixed color specs must match the reference image."""
      canvas = torch.full((3, 100, 100), 255, dtype=torch.uint8)
      canvas_before, boxes_before = canvas.clone(), boxes.clone()

      drawn = utils.draw_bounding_boxes(
          canvas,
          boxes,
          labels=["a", "b", "c", "d"],
          colors=["green", "#FF00FF", (0, 255, 0), "red"],
          fill=True,
      )

      assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata")
      reference_path = os.path.join(assets_dir, "draw_boxes_util.png")
      if not os.path.exists(reference_path):
          # No reference on disk yet: record the current output as the reference.
          Image.fromarray(drawn.permute(1, 2, 0).contiguous().numpy()).save(reference_path)

      # The stored reference image is only valid for Pillow >= 8.2.
      if PILLOW_VERSION >= (8, 2):
          reference = torch.as_tensor(np.array(Image.open(reference_path))).permute(2, 0, 1)
          assert_equal(drawn, reference)

      # Neither input may be modified in place.
      assert_equal(boxes, boxes_before)
      assert_equal(canvas, canvas_before)
  @pytest.mark.parametrize(
      "colors",
      [
          None,
          ["red", "blue", "#FF00FF", (1, 34, 122)],
          "red",
          "#FF00FF",
          (1, 34, 122),
      ],
  )
  def test_draw_boxes_colors(colors):
      """Every supported `colors` spec is accepted; an empty color list is rejected."""
      black = torch.full((3, 100, 100), 0, dtype=torch.uint8)
      utils.draw_bounding_boxes(black, boxes, fill=False, width=7, colors=colors)

      too_few_colors = "Number of colors must be equal or larger than the number of objects"
      with pytest.raises(ValueError, match=too_few_colors):
          utils.draw_bounding_boxes(image=black, boxes=boxes, colors=[])
  def test_draw_boxes_vanilla():
      """Unfilled white boxes on a black image must match the reference image."""
      canvas = torch.full((3, 100, 100), 0, dtype=torch.uint8)
      canvas_before, boxes_before = canvas.clone(), boxes.clone()

      drawn = utils.draw_bounding_boxes(canvas, boxes, fill=False, width=7, colors="white")

      assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata")
      reference_path = os.path.join(assets_dir, "draw_boxes_vanilla.png")
      if not os.path.exists(reference_path):
          # No reference on disk yet: record the current output as the reference.
          Image.fromarray(drawn.permute(1, 2, 0).contiguous().numpy()).save(reference_path)

      reference = torch.as_tensor(np.array(Image.open(reference_path))).permute(2, 0, 1)
      assert_equal(drawn, reference)

      # Neither input may be modified in place.
      assert_equal(boxes, boxes_before)
      assert_equal(canvas, canvas_before)
  def test_draw_boxes_grayscale():
      """Drawing on a single-channel image must return a 3-channel image."""
      gray = torch.full((1, 4, 4), fill_value=255, dtype=torch.uint8)
      single_box = torch.tensor([[0, 0, 3, 3]], dtype=torch.int64)
      drawn = utils.draw_bounding_boxes(image=gray, boxes=single_box, colors=["#1BBC9B"])
      assert drawn.size(0) == 3
  def test_draw_invalid_boxes():
      """draw_bounding_boxes rejects bad images, mismatched labels/colors and malformed boxes."""
      valid_boxes = torch.tensor(
          [[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
      )
      good_img = torch.zeros((3, 10, 10), dtype=torch.uint8)
      float_img = torch.full((3, 5, 5), 255, dtype=torch.float)
      batched_img = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
      # Second corner lies before the first one in both boxes.
      inverted_boxes = torch.tensor([[10, 10, 4, 5], [30, 20, 10, 5]], dtype=torch.float)

      # (expected exception, message pattern, positional args, keyword args)
      bad_calls = [
          (TypeError, "Tensor expected", (((1, 1, 1), (1, 2, 3)), valid_boxes), {}),
          (ValueError, "Tensor uint8 expected", (float_img, valid_boxes), {}),
          (ValueError, "Pass individual images, not batches", (batched_img, valid_boxes), {}),
          (ValueError, "Only grayscale and RGB images are supported", (batched_img[0][:2], valid_boxes), {}),
          (ValueError, "Number of boxes", (good_img, valid_boxes, ["one", "two"]), {}),
          (ValueError, "Number of colors", (good_img, valid_boxes), {"colors": ["pink", "blue"]}),
          (ValueError, "Boxes need to be in", (good_img, inverted_boxes), {}),
      ]
      for exc_type, pattern, args, kwargs in bad_calls:
          with pytest.raises(exc_type, match=pattern):
              utils.draw_bounding_boxes(*args, **kwargs)
  def test_draw_boxes_warning():
      """Passing font_size without a font must emit a UserWarning."""
      white = torch.full((3, 100, 100), 255, dtype=torch.uint8)
      expected = re.escape("Argument 'font_size' will be ignored since 'font' is not set.")
      with pytest.warns(UserWarning, match=expected):
          utils.draw_bounding_boxes(white, boxes, font_size=11)
  def test_draw_no_boxes():
      """An empty box tensor must warn and return an image equal to the input."""
      black = torch.full((3, 100, 100), 0, dtype=torch.uint8)
      no_boxes = torch.full((0, 4), 0, dtype=torch.float)
      expected = re.escape("boxes doesn't contain any box. No box was drawn")
      with pytest.warns(UserWarning, match=expected):
          drawn = utils.draw_bounding_boxes(black, no_boxes)
      # The output must be pixel-identical to the input.
      assert drawn.eq(black).all()
  @pytest.mark.parametrize(
      "colors",
      [
          None,
          "blue",
          "#FF00FF",
          (1, 34, 122),
          ["red", "blue"],
          ["#FF00FF", (1, 34, 122)],
      ],
  )
  @pytest.mark.parametrize("alpha", (0, 0.5, 0.7, 1))
  @pytest.mark.parametrize("device", cpu_and_cuda())
  def test_draw_segmentation_masks(colors, alpha, device):
      """This test makes sure that masks draw their corresponding color where they should"""
      num_masks, h, w = 2, 100, 100
      dtype = torch.uint8
      img = torch.randint(0, 256, size=(3, h, w), dtype=dtype, device=device)
      masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool, device=device)

      # For testing we enforce that there's no overlap between the masks. The
      # current behaviour is that the last mask's color will take priority when
      # masks overlap, but this makes testing slightly harder, so we don't really
      # care
      overlap = masks[0] & masks[1]
      masks[:, overlap] = False

      out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha)
      assert out.dtype == dtype
      assert out is not img

      # Make sure the image didn't change where there's no mask
      masked_pixels = masks[0] | masks[1]
      assert_equal(img[:, ~masked_pixels], out[:, ~masked_pixels])

      if colors is None:
          colors = utils._generate_color_palette(num_masks)
      elif isinstance(colors, (str, tuple)):
          # A single color spec is expected to be applied to every mask. Repeat it
          # once per mask. Wrapping it as [colors] made zip() below stop after the
          # first mask, so the remaining masks were never checked.
          colors = [colors] * num_masks

      # Make sure each mask draws with its own color
      for mask, color in zip(masks, colors):
          if isinstance(color, str):
              color = ImageColor.getrgb(color)
          color = torch.tensor(color, dtype=dtype, device=device)

          if alpha == 1:
              assert (out[:, mask] == color[:, None]).all()
          elif alpha == 0:
              assert (out[:, mask] == img[:, mask]).all()

          # General case: out == img * (1 - alpha) + color * alpha, within 1 unit of rounding.
          interpolated_color = (img[:, mask] * (1 - alpha) + color[:, None] * alpha).to(dtype)
          torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0)
  @pytest.mark.parametrize("device", cpu_and_cuda())
  def test_draw_segmentation_masks_errors(device):
      """draw_segmentation_masks validates image, masks and colors with descriptive errors."""
      h, w = 10, 10
      masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool, device=device)
      img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8, device=device)

      # (expected exception, message pattern, keyword arguments for the call)
      bad_calls = [
          (
              TypeError,
              "The image must be a tensor",
              {"image": "Not A Tensor Image", "masks": masks},
          ),
          (
              ValueError,
              "The image dtype must be",
              {"image": torch.randint(0, 256, size=(3, h, w), dtype=torch.int64), "masks": masks},
          ),
          (
              ValueError,
              "Pass individual images, not batches",
              {"image": torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8), "masks": masks},
          ),
          (
              ValueError,
              "Pass an RGB image",
              {"image": torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8), "masks": masks},
          ),
          (
              ValueError,
              "The masks must be of dtype bool",
              {"image": img, "masks": torch.randint(0, 2, size=(h, w), dtype=torch.float)},
          ),
          (
              ValueError,
              "masks must be of shape",
              {"image": img, "masks": torch.randint(0, 2, size=(3, 2, h, w), dtype=torch.bool)},
          ),
          (
              ValueError,
              "must have the same height and width",
              {"image": img, "masks": torch.randint(0, 2, size=(h + 4, w), dtype=torch.bool)},
          ),
          (
              ValueError,
              "Number of colors must be equal or larger than the number of objects",
              {"image": img, "masks": masks, "colors": []},
          ),
          (
              # colors given as a numpy array; it should be a list
              ValueError,
              "`colors` must be a tuple or a string, or a list thereof",
              {"image": img, "masks": masks, "colors": np.array(["red", "blue"])},
          ),
          (
              # a tuple of color names is not an RGB triplet; it should be a list
              ValueError,
              "If passed as tuple, colors should be an RGB triplet",
              {"image": img, "masks": masks, "colors": ("red", "blue")},
          ),
      ]
      for exc_type, pattern, kwargs in bad_calls:
          with pytest.raises(exc_type, match=pattern):
              utils.draw_segmentation_masks(**kwargs)
  @pytest.mark.parametrize("device", cpu_and_cuda())
  def test_draw_no_segmention_mask(device):
      """An empty mask tensor must warn and return an image equal to the input."""
      blank = torch.full((3, 100, 100), 0, dtype=torch.uint8, device=device)
      no_masks = torch.full((0, 100, 100), 0, dtype=torch.bool, device=device)
      expected = re.escape("masks doesn't contain any mask. No mask was drawn")
      with pytest.warns(UserWarning, match=expected):
          drawn = utils.draw_segmentation_masks(blank, no_masks)
      # The output must be pixel-identical to the input.
      assert drawn.eq(blank).all()
  def test_draw_keypoints_vanilla():
      """Red keypoints with one connection, drawn on black, must match the reference image."""
      # `keypoints` is the module-level fixture.
      keypoints_before = keypoints.clone()
      canvas = torch.full((3, 100, 100), 0, dtype=torch.uint8)
      canvas_before = canvas.clone()

      drawn = utils.draw_keypoints(canvas, keypoints, colors="red", connectivity=[(0, 1)])

      assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata")
      reference_path = os.path.join(assets_dir, "draw_keypoint_vanilla.png")
      if not os.path.exists(reference_path):
          # No reference on disk yet: record the current output as the reference.
          Image.fromarray(drawn.permute(1, 2, 0).contiguous().numpy()).save(reference_path)

      reference = torch.as_tensor(np.array(Image.open(reference_path))).permute(2, 0, 1)
      assert_equal(drawn, reference)

      # Neither the keypoints nor the image may be modified in place.
      assert_equal(keypoints, keypoints_before)
      assert_equal(canvas, canvas_before)
  @pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)])
  def test_draw_keypoints_colored(colors):
      """Keypoints accept named, hex and RGB-tuple colors without mutating their inputs."""
      # `keypoints` is the module-level fixture.
      keypoints_before = keypoints.clone()
      canvas = torch.full((3, 100, 100), 0, dtype=torch.uint8)
      canvas_before = canvas.clone()

      drawn = utils.draw_keypoints(canvas, keypoints, colors=colors, connectivity=[(0, 1)])

      assert drawn.size(0) == 3
      assert_equal(keypoints, keypoints_before)
      assert_equal(canvas, canvas_before)
  def test_draw_keypoints_errors():
      """draw_keypoints validates the image and keypoints with descriptive errors."""
      h, w = 10, 10
      img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
      flat_keypoints = torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]], dtype=torch.float)

      # (expected exception, message pattern, image argument, keypoints argument)
      bad_calls = [
          (TypeError, "The image must be a tensor", "Not A Tensor Image", keypoints),
          (ValueError, "The image dtype must be", torch.full((3, h, w), 0, dtype=torch.int64), keypoints),
          (
              ValueError,
              "Pass individual images, not batches",
              torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8),
              keypoints,
          ),
          (ValueError, "Pass an RGB image", torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8), keypoints),
          (ValueError, "keypoints must be of shape", img, flat_keypoints),
      ]
      for exc_type, pattern, image_arg, keypoints_arg in bad_calls:
          with pytest.raises(exc_type, match=pattern):
              utils.draw_keypoints(image=image_arg, keypoints=keypoints_arg)
  @pytest.mark.parametrize("batch", (True, False))
  def test_flow_to_image(batch):
      """flow_to_image on a synthetic flow field must match the stored reference image."""
      h, w = 100, 100
      # Per-pixel (x, y) coordinates (meshgrid "ij" gives (rows, cols); reversed to
      # (cols, rows)), shifted so the vectors point away from the image centre.
      flow = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
      flow = torch.stack(flow[::-1], dim=0).float()
      flow[0] -= h / 2
      flow[1] -= w / 2

      if batch:
          flow = torch.stack([flow, flow])

      img = utils.flow_to_image(flow)
      # Compute the expected shape first. `assert a == b if c else d` parses as
      # `assert (a == b) if c else d`, which asserts the truthy tuple `d` and so
      # silently passes whenever batch is False.
      expected_shape = (2, 3, h, w) if batch else (3, h, w)
      assert img.shape == expected_shape

      path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "expected_flow.pt")
      expected_img = torch.load(path, map_location="cpu")

      if batch:
          expected_img = torch.stack([expected_img, expected_img])

      assert_equal(expected_img, img)
  @pytest.mark.parametrize(
      "input_flow, match",
      [
          (torch.full(bad_shape, 0, dtype=torch.float), "Input flow should have shape")
          for bad_shape in [(3, 10, 10), (5, 3, 10, 10), (2, 10), (5, 2, 10)]
      ]
      + [(torch.full((2, 10, 30), 0, dtype=torch.int), "Flow should be of dtype torch.float")],
  )
  def test_flow_to_image_errors(input_flow, match):
      """flow_to_image rejects flows with an invalid shape or a non-float dtype."""
      with pytest.raises(ValueError, match=match):
          utils.flow_to_image(flow=input_flow)
  if __name__ == "__main__":
      # Run this module's tests directly and propagate pytest's exit status, so a
      # failing run is reported as a non-zero process exit code.
      sys.exit(pytest.main([__file__]))