test_image.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
  1. import glob
  2. import io
  3. import os
  4. import sys
  5. from pathlib import Path
  6. import numpy as np
  7. import pytest
  8. import torch
  9. import torchvision.transforms.functional as F
  10. from common_utils import assert_equal, needs_cuda
  11. from PIL import __version__ as PILLOW_VERSION, Image
  12. from torchvision.io.image import (
  13. _read_png_16,
  14. decode_image,
  15. decode_jpeg,
  16. decode_png,
  17. encode_jpeg,
  18. encode_png,
  19. ImageReadMode,
  20. read_file,
  21. read_image,
  22. write_file,
  23. write_jpeg,
  24. write_png,
  25. )
  26. IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
  27. FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
  28. IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
  29. DAMAGED_JPEG = os.path.join(IMAGE_ROOT, "damaged_jpeg")
  30. DAMAGED_PNG = os.path.join(IMAGE_ROOT, "damaged_png")
  31. ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
  32. INTERLACED_PNG = os.path.join(IMAGE_ROOT, "interlaced_png")
  33. TOOSMALL_PNG = os.path.join(IMAGE_ROOT, "toosmall_png")
  34. IS_WINDOWS = sys.platform in ("win32", "cygwin")
  35. PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))
  36. def _get_safe_image_name(name):
  37. # Used when we need to change the pytest "id" for an "image path" parameter.
  38. # If we don't, the test id (i.e. its name) will contain the whole path to the image, which is machine-specific,
  39. # and this creates issues when the test is running in a different machine than where it was collected
  40. # (typically, in fb internal infra)
  41. return name.split(os.path.sep)[-1]
  42. def get_images(directory, img_ext):
  43. assert os.path.isdir(directory)
  44. image_paths = glob.glob(directory + f"/**/*{img_ext}", recursive=True)
  45. for path in image_paths:
  46. if path.split(os.sep)[-2] not in ["damaged_jpeg", "jpeg_write"]:
  47. yield path
  48. def pil_read_image(img_path):
  49. with Image.open(img_path) as img:
  50. return torch.from_numpy(np.array(img))
  51. def normalize_dimensions(img_pil):
  52. if len(img_pil.shape) == 3:
  53. img_pil = img_pil.permute(2, 0, 1)
  54. else:
  55. img_pil = img_pil.unsqueeze(0)
  56. return img_pil
  57. @pytest.mark.parametrize(
  58. "img_path",
  59. [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
  60. )
  61. @pytest.mark.parametrize(
  62. "pil_mode, mode",
  63. [
  64. (None, ImageReadMode.UNCHANGED),
  65. ("L", ImageReadMode.GRAY),
  66. ("RGB", ImageReadMode.RGB),
  67. ],
  68. )
  69. def test_decode_jpeg(img_path, pil_mode, mode):
  70. with Image.open(img_path) as img:
  71. is_cmyk = img.mode == "CMYK"
  72. if pil_mode is not None:
  73. img = img.convert(pil_mode)
  74. img_pil = torch.from_numpy(np.array(img))
  75. if is_cmyk and mode == ImageReadMode.UNCHANGED:
  76. # flip the colors to match libjpeg
  77. img_pil = 255 - img_pil
  78. img_pil = normalize_dimensions(img_pil)
  79. data = read_file(img_path)
  80. img_ljpeg = decode_image(data, mode=mode)
  81. # Permit a small variation on pixel values to account for implementation
  82. # differences between Pillow and LibJPEG.
  83. abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
  84. assert abs_mean_diff < 2
  85. def test_decode_jpeg_errors():
  86. with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
  87. decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
  88. with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
  89. decode_jpeg(torch.empty((100,), dtype=torch.float16))
  90. with pytest.raises(RuntimeError, match="Not a JPEG file"):
  91. decode_jpeg(torch.empty((100), dtype=torch.uint8))
  92. def test_decode_bad_huffman_images():
  93. # sanity check: make sure we can decode the bad Huffman encoding
  94. bad_huff = read_file(os.path.join(DAMAGED_JPEG, "bad_huffman.jpg"))
  95. decode_jpeg(bad_huff)
  96. @pytest.mark.parametrize(
  97. "img_path",
  98. [
  99. pytest.param(truncated_image, id=_get_safe_image_name(truncated_image))
  100. for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, "corrupt*.jpg"))
  101. ],
  102. )
  103. def test_damaged_corrupt_images(img_path):
  104. # Truncated images should raise an exception
  105. data = read_file(img_path)
  106. if "corrupt34" in img_path:
  107. match_message = "Image is incomplete or truncated"
  108. else:
  109. match_message = "Unsupported marker type"
  110. with pytest.raises(RuntimeError, match=match_message):
  111. decode_jpeg(data)
  112. @pytest.mark.parametrize(
  113. "img_path",
  114. [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(FAKEDATA_DIR, ".png")],
  115. )
  116. @pytest.mark.parametrize(
  117. "pil_mode, mode",
  118. [
  119. (None, ImageReadMode.UNCHANGED),
  120. ("L", ImageReadMode.GRAY),
  121. ("LA", ImageReadMode.GRAY_ALPHA),
  122. ("RGB", ImageReadMode.RGB),
  123. ("RGBA", ImageReadMode.RGB_ALPHA),
  124. ],
  125. )
  126. def test_decode_png(img_path, pil_mode, mode):
  127. with Image.open(img_path) as img:
  128. if pil_mode is not None:
  129. img = img.convert(pil_mode)
  130. img_pil = torch.from_numpy(np.array(img))
  131. img_pil = normalize_dimensions(img_pil)
  132. if img_path.endswith("16.png"):
  133. # 16 bits image decoding is supported, but only as a private API
  134. # FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
  135. with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
  136. data = read_file(img_path)
  137. img_lpng = decode_image(data, mode=mode)
  138. img_lpng = _read_png_16(img_path, mode=mode)
  139. assert img_lpng.dtype == torch.int32
  140. # PIL converts 16 bits pngs in uint8
  141. img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8)
  142. else:
  143. data = read_file(img_path)
  144. img_lpng = decode_image(data, mode=mode)
  145. tol = 0 if pil_mode is None else 1
  146. if PILLOW_VERSION >= (8, 3) and pil_mode == "LA":
  147. # Avoid checking the transparency channel until
  148. # https://github.com/python-pillow/Pillow/issues/5593#issuecomment-878244910
  149. # is fixed.
  150. # TODO: remove once fix is released in PIL. Should be > 8.3.1.
  151. img_lpng, img_pil = img_lpng[0], img_pil[0]
  152. torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0)
  153. def test_decode_png_errors():
  154. with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
  155. decode_png(torch.empty((), dtype=torch.uint8))
  156. with pytest.raises(RuntimeError, match="Content is not png"):
  157. decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
  158. with pytest.raises(RuntimeError, match="Out of bound read in decode_png"):
  159. decode_png(read_file(os.path.join(DAMAGED_PNG, "sigsegv.png")))
  160. with pytest.raises(RuntimeError, match="Content is too small for png"):
  161. decode_png(read_file(os.path.join(TOOSMALL_PNG, "heapbof.png")))
  162. @pytest.mark.parametrize(
  163. "img_path",
  164. [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
  165. )
  166. def test_encode_png(img_path):
  167. pil_image = Image.open(img_path)
  168. img_pil = torch.from_numpy(np.array(pil_image))
  169. img_pil = img_pil.permute(2, 0, 1)
  170. png_buf = encode_png(img_pil, compression_level=6)
  171. rec_img = Image.open(io.BytesIO(bytes(png_buf.tolist())))
  172. rec_img = torch.from_numpy(np.array(rec_img))
  173. rec_img = rec_img.permute(2, 0, 1)
  174. assert_equal(img_pil, rec_img)
  175. def test_encode_png_errors():
  176. with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
  177. encode_png(torch.empty((3, 100, 100), dtype=torch.float32))
  178. with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
  179. encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=-1)
  180. with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
  181. encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=10)
  182. with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
  183. encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))
  184. @pytest.mark.parametrize(
  185. "img_path",
  186. [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
  187. )
  188. def test_write_png(img_path, tmpdir):
  189. pil_image = Image.open(img_path)
  190. img_pil = torch.from_numpy(np.array(pil_image))
  191. img_pil = img_pil.permute(2, 0, 1)
  192. filename, _ = os.path.splitext(os.path.basename(img_path))
  193. torch_png = os.path.join(tmpdir, f"{filename}_torch.png")
  194. write_png(img_pil, torch_png, compression_level=6)
  195. saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
  196. saved_image = saved_image.permute(2, 0, 1)
  197. assert_equal(img_pil, saved_image)
  198. def test_read_file(tmpdir):
  199. fname, content = "test1.bin", b"TorchVision\211\n"
  200. fpath = os.path.join(tmpdir, fname)
  201. with open(fpath, "wb") as f:
  202. f.write(content)
  203. data = read_file(fpath)
  204. expected = torch.tensor(list(content), dtype=torch.uint8)
  205. os.unlink(fpath)
  206. assert_equal(data, expected)
  207. with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
  208. read_file("tst")
  209. def test_read_file_non_ascii(tmpdir):
  210. fname, content = "日本語(Japanese).bin", b"TorchVision\211\n"
  211. fpath = os.path.join(tmpdir, fname)
  212. with open(fpath, "wb") as f:
  213. f.write(content)
  214. data = read_file(fpath)
  215. expected = torch.tensor(list(content), dtype=torch.uint8)
  216. os.unlink(fpath)
  217. assert_equal(data, expected)
  218. def test_write_file(tmpdir):
  219. fname, content = "test1.bin", b"TorchVision\211\n"
  220. fpath = os.path.join(tmpdir, fname)
  221. content_tensor = torch.tensor(list(content), dtype=torch.uint8)
  222. write_file(fpath, content_tensor)
  223. with open(fpath, "rb") as f:
  224. saved_content = f.read()
  225. os.unlink(fpath)
  226. assert content == saved_content
  227. def test_write_file_non_ascii(tmpdir):
  228. fname, content = "日本語(Japanese).bin", b"TorchVision\211\n"
  229. fpath = os.path.join(tmpdir, fname)
  230. content_tensor = torch.tensor(list(content), dtype=torch.uint8)
  231. write_file(fpath, content_tensor)
  232. with open(fpath, "rb") as f:
  233. saved_content = f.read()
  234. os.unlink(fpath)
  235. assert content == saved_content
  236. @pytest.mark.parametrize(
  237. "shape",
  238. [
  239. (27, 27),
  240. (60, 60),
  241. (105, 105),
  242. ],
  243. )
  244. def test_read_1_bit_png(shape, tmpdir):
  245. np_rng = np.random.RandomState(0)
  246. image_path = os.path.join(tmpdir, f"test_{shape}.png")
  247. pixels = np_rng.rand(*shape) > 0.5
  248. img = Image.fromarray(pixels)
  249. img.save(image_path)
  250. img1 = read_image(image_path)
  251. img2 = normalize_dimensions(torch.as_tensor(pixels * 255, dtype=torch.uint8))
  252. assert_equal(img1, img2)
  253. @pytest.mark.parametrize(
  254. "shape",
  255. [
  256. (27, 27),
  257. (60, 60),
  258. (105, 105),
  259. ],
  260. )
  261. @pytest.mark.parametrize(
  262. "mode",
  263. [
  264. ImageReadMode.UNCHANGED,
  265. ImageReadMode.GRAY,
  266. ],
  267. )
  268. def test_read_1_bit_png_consistency(shape, mode, tmpdir):
  269. np_rng = np.random.RandomState(0)
  270. image_path = os.path.join(tmpdir, f"test_{shape}.png")
  271. pixels = np_rng.rand(*shape) > 0.5
  272. img = Image.fromarray(pixels)
  273. img.save(image_path)
  274. img1 = read_image(image_path, mode)
  275. img2 = read_image(image_path, mode)
  276. assert_equal(img1, img2)
  277. def test_read_interlaced_png():
  278. imgs = list(get_images(INTERLACED_PNG, ".png"))
  279. with Image.open(imgs[0]) as im1, Image.open(imgs[1]) as im2:
  280. assert not (im1.info.get("interlace") is im2.info.get("interlace"))
  281. img1 = read_image(imgs[0])
  282. img2 = read_image(imgs[1])
  283. assert_equal(img1, img2)
  284. @needs_cuda
  285. @pytest.mark.parametrize(
  286. "img_path",
  287. [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
  288. )
  289. @pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
  290. @pytest.mark.parametrize("scripted", (False, True))
  291. def test_decode_jpeg_cuda(mode, img_path, scripted):
  292. if "cmyk" in img_path:
  293. pytest.xfail("Decoding a CMYK jpeg isn't supported")
  294. data = read_file(img_path)
  295. img = decode_image(data, mode=mode)
  296. f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
  297. img_nvjpeg = f(data, mode=mode, device="cuda")
  298. # Some difference expected between jpeg implementations
  299. assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2
  300. @needs_cuda
  301. def test_decode_image_cuda_raises():
  302. data = torch.randint(0, 127, size=(255,), device="cuda", dtype=torch.uint8)
  303. with pytest.raises(RuntimeError):
  304. decode_image(data)
  305. @needs_cuda
  306. @pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda")))
  307. def test_decode_jpeg_cuda_device_param(cuda_device):
  308. """Make sure we can pass a string or a torch.device as device param"""
  309. path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)
  310. data = read_file(path)
  311. decode_jpeg(data, device=cuda_device)
  312. @needs_cuda
  313. def test_decode_jpeg_cuda_errors():
  314. data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
  315. with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
  316. decode_jpeg(data.reshape(-1, 1), device="cuda")
  317. with pytest.raises(RuntimeError, match="input tensor must be on CPU"):
  318. decode_jpeg(data.to("cuda"), device="cuda")
  319. with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
  320. decode_jpeg(data.to(torch.float), device="cuda")
  321. with pytest.raises(RuntimeError, match="Expected a cuda device"):
  322. torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, "cpu")
  323. def test_encode_jpeg_errors():
  324. with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
  325. encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))
  326. with pytest.raises(ValueError, match="Image quality should be a positive number between 1 and 100"):
  327. encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)
  328. with pytest.raises(ValueError, match="Image quality should be a positive number between 1 and 100"):
  329. encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)
  330. with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
  331. encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))
  332. with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
  333. encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8))
  334. with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
  335. encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))
  336. @pytest.mark.parametrize(
  337. "img_path",
  338. [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
  339. )
  340. def test_encode_jpeg(img_path):
  341. img = read_image(img_path)
  342. pil_img = F.to_pil_image(img)
  343. buf = io.BytesIO()
  344. pil_img.save(buf, format="JPEG", quality=75)
  345. encoded_jpeg_pil = torch.frombuffer(buf.getvalue(), dtype=torch.uint8)
  346. for src_img in [img, img.contiguous()]:
  347. encoded_jpeg_torch = encode_jpeg(src_img, quality=75)
  348. assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)
  349. @pytest.mark.parametrize(
  350. "img_path",
  351. [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
  352. )
  353. def test_write_jpeg(img_path, tmpdir):
  354. tmpdir = Path(tmpdir)
  355. img = read_image(img_path)
  356. pil_img = F.to_pil_image(img)
  357. torch_jpeg = str(tmpdir / "torch.jpg")
  358. pil_jpeg = str(tmpdir / "pil.jpg")
  359. write_jpeg(img, torch_jpeg, quality=75)
  360. pil_img.save(pil_jpeg, quality=75)
  361. with open(torch_jpeg, "rb") as f:
  362. torch_bytes = f.read()
  363. with open(pil_jpeg, "rb") as f:
  364. pil_bytes = f.read()
  365. assert_equal(torch_bytes, pil_bytes)
  366. if __name__ == "__main__":
  367. pytest.main([__file__])