|
- import glob
- import io
- import os
- import sys
- from pathlib import Path
- import numpy as np
- import pytest
- import torch
- import torchvision.transforms.functional as F
- from common_utils import assert_equal, needs_cuda
- from PIL import __version__ as PILLOW_VERSION, Image
- from torchvision.io.image import (
- _read_png_16,
- decode_image,
- decode_jpeg,
- decode_png,
- encode_jpeg,
- encode_png,
- ImageReadMode,
- read_file,
- read_image,
- write_file,
- write_jpeg,
- write_png,
- )
# Asset directories used by the tests, all resolved relative to this file.
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, "damaged_jpeg")
DAMAGED_PNG = os.path.join(IMAGE_ROOT, "damaged_png")
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
INTERLACED_PNG = os.path.join(IMAGE_ROOT, "interlaced_png")
TOOSMALL_PNG = os.path.join(IMAGE_ROOT, "toosmall_png")
IS_WINDOWS = sys.platform in ("win32", "cygwin")
# Rebind the imported version string as a tuple of ints so it can be compared
# against tuples such as (8, 3). Non-numeric components (e.g. the "dev0" of a
# development build) are skipped instead of letting int() raise ValueError.
PILLOW_VERSION = tuple(int(part) for part in PILLOW_VERSION.split(".") if part.isdecimal())
- def _get_safe_image_name(name):
- # Used when we need to change the pytest "id" for an "image path" parameter.
- # If we don't, the test id (i.e. its name) will contain the whole path to the image, which is machine-specific,
- # and this creates issues when the test is running in a different machine than where it was collected
- # (typically, in fb internal infra)
- return name.split(os.path.sep)[-1]
def get_images(directory, img_ext):
    """Yield paths under ``directory`` (searched recursively) ending in ``img_ext``.

    Files whose immediate parent directory is named ``damaged_jpeg`` or
    ``jpeg_write`` are skipped.
    """
    assert os.path.isdir(directory)
    excluded_parents = ("damaged_jpeg", "jpeg_write")
    for candidate in glob.glob(directory + f"/**/*{img_ext}", recursive=True):
        parent_name = candidate.split(os.sep)[-2]
        if parent_name in excluded_parents:
            continue
        yield candidate
def pil_read_image(img_path):
    """Open ``img_path`` with Pillow and return its pixel data as a torch tensor."""
    with Image.open(img_path) as pil_handle:
        pixel_array = np.array(pil_handle)
        return torch.from_numpy(pixel_array)
def normalize_dimensions(img_pil):
    """Rearrange a tensor so that a leading axis comes first.

    A 3-dimensional input has its last axis moved to the front
    (``permute(2, 0, 1)``); any other input gets a new leading axis of size 1.
    """
    if len(img_pil.shape) != 3:
        return img_pil.unsqueeze(0)
    return img_pil.permute(2, 0, 1)
@pytest.mark.parametrize(
    "img_path",
    [
        pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
        for jpeg_path in get_images(IMAGE_ROOT, ".jpg")
    ],
)
@pytest.mark.parametrize(
    "pil_mode, mode",
    [
        (None, ImageReadMode.UNCHANGED),
        ("L", ImageReadMode.GRAY),
        ("RGB", ImageReadMode.RGB),
    ],
)
def test_decode_jpeg(img_path, pil_mode, mode):
    """Compare ``decode_image`` on a JPEG file against Pillow's decoding of it."""
    with Image.open(img_path) as pil_img:
        source_is_cmyk = pil_img.mode == "CMYK"
        converted = pil_img if pil_mode is None else pil_img.convert(pil_mode)
        expected = torch.from_numpy(np.array(converted))

    if source_is_cmyk and mode == ImageReadMode.UNCHANGED:
        # flip the colors to match libjpeg
        expected = 255 - expected
    expected = normalize_dimensions(expected)

    decoded = decode_image(read_file(img_path), mode=mode)

    # Pillow and LibJPEG are different implementations, so a small variation
    # in pixel values is tolerated.
    mean_abs_error = (decoded.type(torch.float32) - expected).abs().mean().item()
    assert mean_abs_error < 2
def test_decode_jpeg_errors():
    """Check the error messages ``decode_jpeg`` raises for invalid inputs."""
    # (expected message, shape passed to torch.empty, dtype)
    invalid_inputs = (
        ("Expected a non empty 1-dimensional tensor", (100, 1), torch.uint8),
        ("Expected a torch.uint8 tensor", (100,), torch.float16),
        ("Not a JPEG file", (100), torch.uint8),
    )
    for message, shape, dtype in invalid_inputs:
        bad_input = torch.empty(shape, dtype=dtype)
        with pytest.raises(RuntimeError, match=message):
            decode_jpeg(bad_input)
def test_decode_bad_huffman_images():
    """Decoding the ``bad_huffman.jpg`` asset must not raise."""
    # sanity check: make sure we can decode the bad Huffman encoding
    bad_huffman_path = os.path.join(DAMAGED_JPEG, "bad_huffman.jpg")
    decode_jpeg(read_file(bad_huffman_path))
@pytest.mark.parametrize(
    "img_path",
    [
        pytest.param(truncated_image, id=_get_safe_image_name(truncated_image))
        for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, "corrupt*.jpg"))
    ],
)
def test_damaged_corrupt_images(img_path):
    """Truncated images should raise an exception when decoded."""
    data = read_file(img_path)
    # The "corrupt34" asset is expected to fail with a different message
    # than the other corrupt files.
    match_message = (
        "Image is incomplete or truncated" if "corrupt34" in img_path else "Unsupported marker type"
    )
    with pytest.raises(RuntimeError, match=match_message):
        decode_jpeg(data)
@pytest.mark.parametrize(
    "img_path",
    [
        pytest.param(png_path, id=_get_safe_image_name(png_path))
        for png_path in get_images(FAKEDATA_DIR, ".png")
    ],
)
@pytest.mark.parametrize(
    "pil_mode, mode",
    [
        (None, ImageReadMode.UNCHANGED),
        ("L", ImageReadMode.GRAY),
        ("LA", ImageReadMode.GRAY_ALPHA),
        ("RGB", ImageReadMode.RGB),
        ("RGBA", ImageReadMode.RGB_ALPHA),
    ],
)
def test_decode_png(img_path, pil_mode, mode):
    """Compare PNG decoding against Pillow's decoding for each read mode."""
    with Image.open(img_path) as pil_img:
        converted = pil_img if pil_mode is None else pil_img.convert(pil_mode)
        expected = torch.from_numpy(np.array(converted))
    expected = normalize_dimensions(expected)

    if not img_path.endswith("16.png"):
        decoded = decode_image(read_file(img_path), mode=mode)
    else:
        # 16 bits image decoding is supported, but only as a private API
        # FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
        with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
            data = read_file(img_path)
            decode_image(data, mode=mode)

        decoded = _read_png_16(img_path, mode=mode)
        assert decoded.dtype == torch.int32
        # PIL converts 16 bits pngs in uint8
        decoded = torch.round(decoded / (2**16 - 1) * 255).to(torch.uint8)

    # Exact match when Pillow did no mode conversion, off-by-one allowed otherwise.
    tol = 1 if pil_mode is not None else 0

    if pil_mode == "LA" and PILLOW_VERSION >= (8, 3):
        # Avoid checking the transparency channel until
        # https://github.com/python-pillow/Pillow/issues/5593#issuecomment-878244910
        # is fixed.
        # TODO: remove once fix is released in PIL. Should be > 8.3.1.
        decoded, expected = decoded[0], expected[0]

    torch.testing.assert_close(decoded, expected, atol=tol, rtol=0)
def test_decode_png_errors():
    """Check the error messages ``decode_png`` raises for invalid inputs."""
    # Inputs are built lazily so that each one is created inside its own
    # pytest.raises context.
    failing_inputs = (
        (
            "Expected a non empty 1-dimensional tensor",
            lambda: torch.empty((), dtype=torch.uint8),
        ),
        (
            "Content is not png",
            lambda: torch.randint(3, 5, (300,), dtype=torch.uint8),
        ),
        (
            "Out of bound read in decode_png",
            lambda: read_file(os.path.join(DAMAGED_PNG, "sigsegv.png")),
        ),
        (
            "Content is too small for png",
            lambda: read_file(os.path.join(TOOSMALL_PNG, "heapbof.png")),
        ),
    )
    for message, make_input in failing_inputs:
        with pytest.raises(RuntimeError, match=message):
            decode_png(make_input())
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
def test_encode_png(img_path):
    """Encode an image with ``encode_png`` and check Pillow reads back the same pixels."""
    # np.array() copies the pixel data, so the Pillow handle can be closed
    # right away instead of being left open for the rest of the test.
    with Image.open(img_path) as pil_image:
        img_pil = torch.from_numpy(np.array(pil_image))
    img_pil = img_pil.permute(2, 0, 1)

    png_buf = encode_png(img_pil, compression_level=6)

    with Image.open(io.BytesIO(bytes(png_buf.tolist()))) as rec_pil:
        rec_img = torch.from_numpy(np.array(rec_pil))
    rec_img = rec_img.permute(2, 0, 1)

    assert_equal(img_pil, rec_img)
def test_encode_png_errors():
    """Check the error messages ``encode_png`` raises for invalid arguments."""
    # (expected message, tensor shape, dtype, extra keyword arguments)
    invalid_calls = (
        ("Input tensor dtype should be uint8", (3, 100, 100), torch.float32, {}),
        ("Compression level should be between 0 and 9", (3, 100, 100), torch.uint8, {"compression_level": -1}),
        ("Compression level should be between 0 and 9", (3, 100, 100), torch.uint8, {"compression_level": 10}),
        ("The number of channels should be 1 or 3, got: 5", (5, 100, 100), torch.uint8, {}),
    )
    for message, shape, dtype, extra_kwargs in invalid_calls:
        with pytest.raises(RuntimeError, match=message):
            encode_png(torch.empty(shape, dtype=dtype), **extra_kwargs)
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
def test_write_png(img_path, tmpdir):
    """Write an image with ``write_png`` and check Pillow reads back the same pixels."""
    # np.array() copies the pixel data, so the Pillow handle can be closed
    # right away instead of being left open for the rest of the test.
    with Image.open(img_path) as pil_image:
        img_pil = torch.from_numpy(np.array(pil_image))
    img_pil = img_pil.permute(2, 0, 1)

    filename, _ = os.path.splitext(os.path.basename(img_path))
    torch_png = os.path.join(tmpdir, f"{filename}_torch.png")
    write_png(img_pil, torch_png, compression_level=6)

    with Image.open(torch_png) as saved_pil:
        saved_image = torch.from_numpy(np.array(saved_pil))
    saved_image = saved_image.permute(2, 0, 1)

    assert_equal(img_pil, saved_image)
def test_read_file(tmpdir):
    """``read_file`` returns the file bytes as a uint8 tensor and errors on a missing path."""
    payload = b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, "test1.bin")
    with open(fpath, "wb") as handle:
        handle.write(payload)

    data = read_file(fpath)
    expected = torch.tensor(list(payload), dtype=torch.uint8)
    # Remove the file before asserting, as the original test does.
    os.unlink(fpath)
    assert_equal(data, expected)

    with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
        read_file("tst")
def test_read_file_non_ascii(tmpdir):
    """``read_file`` works when the file name contains non-ASCII characters."""
    payload = b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, "日本語(Japanese).bin")
    with open(fpath, "wb") as handle:
        handle.write(payload)

    data = read_file(fpath)
    expected = torch.tensor(list(payload), dtype=torch.uint8)
    os.unlink(fpath)
    assert_equal(data, expected)
def test_write_file(tmpdir):
    """``write_file`` stores the bytes of a uint8 tensor unchanged on disk."""
    payload = b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, "test1.bin")
    payload_tensor = torch.tensor(list(payload), dtype=torch.uint8)
    write_file(fpath, payload_tensor)

    with open(fpath, "rb") as handle:
        written = handle.read()
    os.unlink(fpath)
    assert payload == written
def test_write_file_non_ascii(tmpdir):
    """``write_file`` works when the file name contains non-ASCII characters."""
    payload = b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, "日本語(Japanese).bin")
    payload_tensor = torch.tensor(list(payload), dtype=torch.uint8)
    write_file(fpath, payload_tensor)

    with open(fpath, "rb") as handle:
        written = handle.read()
    os.unlink(fpath)
    assert payload == written
@pytest.mark.parametrize(
    "shape",
    [(27, 27), (60, 60), (105, 105)],
)
def test_read_1_bit_png(shape, tmpdir):
    """A PNG saved from a random boolean array reads back as 0/255 uint8 values."""
    image_path = os.path.join(tmpdir, f"test_{shape}.png")
    # Fixed seed so the generated pixels are reproducible.
    bits = np.random.RandomState(0).rand(*shape) > 0.5
    Image.fromarray(bits).save(image_path)

    actual = read_image(image_path)
    expected = normalize_dimensions(torch.as_tensor(bits * 255, dtype=torch.uint8))
    assert_equal(actual, expected)
@pytest.mark.parametrize(
    "shape",
    [(27, 27), (60, 60), (105, 105)],
)
@pytest.mark.parametrize(
    "mode",
    [ImageReadMode.UNCHANGED, ImageReadMode.GRAY],
)
def test_read_1_bit_png_consistency(shape, mode, tmpdir):
    """Reading the same PNG (saved from a boolean array) twice gives equal tensors."""
    image_path = os.path.join(tmpdir, f"test_{shape}.png")
    # Fixed seed so the generated pixels are reproducible.
    bits = np.random.RandomState(0).rand(*shape) > 0.5
    Image.fromarray(bits).save(image_path)

    first_read = read_image(image_path, mode)
    second_read = read_image(image_path, mode)
    assert_equal(first_read, second_read)
def test_read_interlaced_png():
    """Two PNG assets with different interlace settings must decode to the same tensor."""
    imgs = list(get_images(INTERLACED_PNG, ".png"))
    with Image.open(imgs[0]) as im1, Image.open(imgs[1]) as im2:
        # Compare the interlace flags by value: `is` tests object identity,
        # which is not a reliable way to compare the values stored in `info`.
        assert im1.info.get("interlace") != im2.info.get("interlace")
    img1 = read_image(imgs[0])
    img2 = read_image(imgs[1])
    assert_equal(img1, img2)
@needs_cuda
@pytest.mark.parametrize(
    "img_path",
    [
        pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
        for jpeg_path in get_images(IMAGE_ROOT, ".jpg")
    ],
)
@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_jpeg_cuda(mode, img_path, scripted):
    """Compare ``decode_jpeg`` with ``device="cuda"`` against ``decode_image`` on the same data."""
    if "cmyk" in img_path:
        pytest.xfail("Decoding a CMYK jpeg isn't supported")

    data = read_file(img_path)
    reference = decode_image(data, mode=mode)

    if scripted:
        decoder = torch.jit.script(decode_jpeg)
    else:
        decoder = decode_jpeg
    cuda_result = decoder(data, mode=mode, device="cuda")

    # Some difference expected between jpeg implementations
    mean_abs_error = (reference.float() - cuda_result.cpu().float()).abs().mean()
    assert mean_abs_error < 2
@needs_cuda
def test_decode_image_cuda_raises():
    """``decode_image`` must raise ``RuntimeError`` when given a CUDA tensor of random bytes."""
    cuda_bytes = torch.randint(0, 127, size=(255,), device="cuda", dtype=torch.uint8)
    with pytest.raises(RuntimeError):
        decode_image(cuda_bytes)
@needs_cuda
@pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda")))
def test_decode_jpeg_cuda_device_param(cuda_device):
    """Make sure we can pass a string or a torch.device as device param"""
    # Use the first JPEG asset whose path does not contain "cmyk".
    non_cmyk_paths = (candidate for candidate in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in candidate)
    jpeg_bytes = read_file(next(non_cmyk_paths))
    decode_jpeg(jpeg_bytes, device=cuda_device)
@needs_cuda
def test_decode_jpeg_cuda_errors():
    """Check the error messages raised by the CUDA JPEG decoding path for invalid inputs."""
    data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
    # Each call is deferred so that its arguments are evaluated inside the
    # matching pytest.raises context.
    failing_calls = (
        (
            "Expected a non empty 1-dimensional tensor",
            lambda: decode_jpeg(data.reshape(-1, 1), device="cuda"),
        ),
        (
            "input tensor must be on CPU",
            lambda: decode_jpeg(data.to("cuda"), device="cuda"),
        ),
        (
            "Expected a torch.uint8 tensor",
            lambda: decode_jpeg(data.to(torch.float), device="cuda"),
        ),
        (
            "Expected a cuda device",
            lambda: torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, "cpu"),
        ),
    )
    for message, call in failing_calls:
        with pytest.raises(RuntimeError, match=message):
            call()
def test_encode_jpeg_errors():
    """Check the exception type and message ``encode_jpeg`` raises for invalid arguments."""
    quality_message = "Image quality should be a positive number between 1 and 100"
    ndim_message = "Input data should be a 3-dimensional tensor"
    # (exception type, expected message, tensor shape, dtype, extra keyword arguments)
    invalid_calls = (
        (RuntimeError, "Input tensor dtype should be uint8", (3, 100, 100), torch.float32, {}),
        (ValueError, quality_message, (3, 100, 100), torch.uint8, {"quality": -1}),
        (ValueError, quality_message, (3, 100, 100), torch.uint8, {"quality": 101}),
        (RuntimeError, "The number of channels should be 1 or 3, got: 5", (5, 100, 100), torch.uint8, {}),
        (RuntimeError, ndim_message, (1, 3, 100, 100), torch.uint8, {}),
        (RuntimeError, ndim_message, (100, 100), torch.uint8, {}),
    )
    for exc_type, message, shape, dtype, extra_kwargs in invalid_calls:
        with pytest.raises(exc_type, match=message):
            encode_jpeg(torch.empty(shape, dtype=dtype), **extra_kwargs)
@pytest.mark.parametrize(
    "img_path",
    [
        pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
        for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
    ],
)
def test_encode_jpeg(img_path):
    """``encode_jpeg`` at quality 75 must produce the same bytes as Pillow's JPEG encoder."""
    img = read_image(img_path)

    # Reference encoding produced by Pillow into an in-memory buffer.
    pil_buffer = io.BytesIO()
    F.to_pil_image(img).save(pil_buffer, format="JPEG", quality=75)
    expected_bytes = torch.frombuffer(pil_buffer.getvalue(), dtype=torch.uint8)

    # Encode both the tensor as read and its .contiguous() version.
    candidates = [img, img.contiguous()]
    for candidate in candidates:
        actual_bytes = encode_jpeg(candidate, quality=75)
        assert_equal(actual_bytes, expected_bytes)
@pytest.mark.parametrize(
    "img_path",
    [
        pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
        for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
    ],
)
def test_write_jpeg(img_path, tmpdir):
    """``write_jpeg`` at quality 75 must write the same file bytes as Pillow's ``save``."""
    out_dir = Path(tmpdir)
    torch_jpeg = str(out_dir / "torch.jpg")
    pil_jpeg = str(out_dir / "pil.jpg")

    img = read_image(img_path)
    pil_img = F.to_pil_image(img)

    write_jpeg(img, torch_jpeg, quality=75)
    pil_img.save(pil_jpeg, quality=75)

    torch_bytes = Path(torch_jpeg).read_bytes()
    pil_bytes = Path(pil_jpeg).read_bytes()
    assert_equal(torch_bytes, pil_bytes)
if __name__ == "__main__":
    # Forward pytest's exit code so that running this file as a script
    # reports test failures through the process exit status.
    sys.exit(pytest.main([__file__]))
|