123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264 |
- from enum import Enum
- from warnings import warn
- import torch
- from ..extension import _load_library
- from ..utils import _log_api_usage_once
# Try to load the compiled "image" extension at import time. A failure is
# downgraded to a warning so that the rest of the package stays importable
# for users who never touch the image I/O functions.
try:
    _load_library("image")
except (ImportError, OSError) as e:
    warn(
        # The trailing space keeps the error text from running straight into
        # the next sentence of the message.
        f"Failed to load image Python extension: '{e}' "
        "If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. "
        "Otherwise, there might be something wrong with your environment. "
        "Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?"
    )
class ImageReadMode(Enum):
    """
    Conversion modes that can be requested when reading an image.

    The integer ``value`` of the selected member is what the decoding
    functions in this module forward to the underlying image ops.

    Members:
        UNCHANGED: load the image as-is.
        GRAY: convert to grayscale.
        GRAY_ALPHA: grayscale with transparency.
        RGB: convert to RGB.
        RGB_ALPHA: RGB with transparency.
    """

    UNCHANGED = 0
    GRAY = 1
    GRAY_ALPHA = 2
    RGB = 3
    RGB_ALPHA = 4
def read_file(path: str) -> torch.Tensor:
    """
    Load the raw bytes of a file into a one-dimensional uint8 Tensor.

    Args:
        path (str): location of the file to read

    Returns:
        data (Tensor): the file contents, as produced by
        ``torch.ops.image.read_file``
    """
    # API usage is only recorded when neither scripting nor tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_file)
    return torch.ops.image.read_file(path)
def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Dump a one-dimensional uint8 tensor to a file on disk.

    The actual write is delegated to ``torch.ops.image.write_file``.

    Args:
        filename (str): destination path of the file
        data (Tensor): the bytes to store in the output file
    """
    # API usage is only recorded when neither scripting nor tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_file)
    torch.ops.image.write_file(filename, data)
def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Decode the bytes of a PNG image into a 3 dimensional RGB or grayscale
    Tensor, optionally converting it to the requested format.

    Output values are uint8 in the range [0, 255].

    Args:
        input (Tensor[1]): one dimensional uint8 tensor holding the raw
            bytes of the PNG image.
        mode (ImageReadMode): read mode used to optionally convert the
            image; see ``ImageReadMode`` for the available modes.
            Default: ``ImageReadMode.UNCHANGED``.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # API usage is only recorded when neither scripting nor tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_png)
    # The trailing ``False`` is the flag that ``_read_png_16`` sets to True
    # (presumably 16-bit support -- confirm against the op definition).
    return torch.ops.image.decode_png(input, mode.value, False)
def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Encode a CHW image tensor as PNG and return the resulting file contents
    as a buffer.

    Args:
        input (Tensor[channels, image_height, image_width]): 8-bit image
            tensor of ``c`` channels, where ``c`` must be 3 or 1.
            (Presumably uint8, like the tensors the decoders return --
            confirm against the op definition.)
        compression_level (int): compression factor for the resulting file;
            it must be a number between 0 and 9. Default: 6

    Returns:
        Tensor[1]: one dimensional 8-bit tensor containing the raw bytes of
        the PNG file.
    """
    # API usage is only recorded when neither scripting nor tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_png)
    return torch.ops.image.encode_png(input, compression_level)
def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
    """
    Save an image tensor in CHW layout (or HW for grayscale images) as a
    PNG file.

    The tensor is first encoded with ``encode_png`` and the resulting bytes
    are then stored with ``write_file``.

    Args:
        input (Tensor[channels, image_height, image_width]): 8-bit image
            tensor of ``c`` channels, where ``c`` must be 1 or 3.
        filename (str): path the image is saved to.
        compression_level (int): compression factor for the resulting file;
            it must be a number between 0 and 9. Default: 6
    """
    # API usage is only recorded when neither scripting nor tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_png)
    png_bytes = encode_png(input, compression_level)
    write_file(filename, png_bytes)
def decode_jpeg(
    input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, device: str = "cpu"
) -> torch.Tensor:
    """
    Decode the bytes of a JPEG image into a 3 dimensional RGB or grayscale
    Tensor, optionally converting it to the requested format.

    Output values are uint8 between 0 and 255.

    Args:
        input (Tensor[1]): one dimensional uint8 tensor holding the raw
            bytes of the JPEG image. It must live on CPU, whatever the
            value of ``device``.
        mode (ImageReadMode): read mode used to optionally convert the
            image. Supported modes: ``ImageReadMode.UNCHANGED``,
            ``ImageReadMode.GRAY`` and ``ImageReadMode.RGB``.
            Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` for details on the available modes.
        device (str or torch.device): device on which the decoded image is
            stored. When a cuda device is given, decoding goes through
            `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
            supported for CUDA version >= 10.1

            .. betastatus:: device parameter

            .. warning::
                There is a memory leak in the nvjpeg library for CUDA versions < 11.6.
                Make sure to rely on CUDA 11.6 or above before using ``device="cuda"``.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # API usage is only recorded when neither scripting nor tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_jpeg)
    # Normalise to a torch.device so that its ``type`` can be inspected.
    device = torch.device(device)
    if device.type == "cuda":
        return torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
    return torch.ops.image.decode_jpeg(input, mode.value)
def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Encode a CHW image tensor as JPEG and return the resulting file contents
    as a buffer.

    Args:
        input (Tensor[channels, image_height, image_width]): 8-bit image
            tensor of ``c`` channels, where ``c`` must be 1 or 3.
            (Presumably uint8, like the tensors the decoders return --
            confirm against the op definition.)
        quality (int): quality of the resulting JPEG file; it must be a
            number between 1 and 100. Default: 75

    Returns:
        output (Tensor[1]): one dimensional 8-bit tensor containing the raw
        bytes of the JPEG file.

    Raises:
        ValueError: if ``quality`` lies outside the range [1, 100].
    """
    # API usage is only recorded when neither scripting nor tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_jpeg)
    # Reject out-of-range quality values before handing over to the op.
    if quality > 100 or quality < 1:
        raise ValueError("Image quality should be a positive number between 1 and 100")
    return torch.ops.image.encode_jpeg(input, quality)
def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
    """
    Save an image tensor in CHW layout as a JPEG file.

    The tensor is first encoded with ``encode_jpeg`` and the resulting bytes
    are then stored with ``write_file``.

    Args:
        input (Tensor[channels, image_height, image_width]): 8-bit image
            tensor of ``c`` channels, where ``c`` must be 1 or 3.
        filename (str): path the image is saved to.
        quality (int): quality of the resulting JPEG file; it must be a
            number between 1 and 100. Default: 75
    """
    # API usage is only recorded when neither scripting nor tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_jpeg)
    jpeg_bytes = encode_jpeg(input, quality)
    write_file(filename, jpeg_bytes)
def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Decode the bytes of a JPEG or PNG image into a 3 dimensional RGB or
    grayscale Tensor, detecting which of the two formats it is and
    optionally converting it to the requested format.

    Output values are uint8 in the range [0, 255].

    Args:
        input (Tensor): one dimensional uint8 tensor holding the raw bytes
            of the PNG or JPEG image.
        mode (ImageReadMode): read mode used to optionally convert the
            image; see ``ImageReadMode`` for the available modes.
            Default: ``ImageReadMode.UNCHANGED``.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # API usage is only recorded when neither scripting nor tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_image)
    return torch.ops.image.decode_image(input, mode.value)
def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Load a JPEG or PNG file from disk into a 3 dimensional RGB or grayscale
    Tensor, optionally converting it to the requested format.

    The file is read with ``read_file`` and its bytes are then passed to
    ``decode_image``. Output values are uint8 in the range [0, 255].

    Args:
        path (str): path of the JPEG or PNG image.
        mode (ImageReadMode): read mode used to optionally convert the
            image; see ``ImageReadMode`` for the available modes.
            Default: ``ImageReadMode.UNCHANGED``.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # API usage is only recorded when neither scripting nor tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_image)
    return decode_image(read_file(path), mode)
def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Private helper: read the file at ``path`` and decode it as a PNG with
    the decode op's third argument set to ``True`` (the public
    ``decode_png`` passes ``False``).

    NOTE(review): judging by the name, this presumably enables 16-bit PNG
    decoding -- confirm against the op definition.

    Args:
        path (str): path of the PNG image.
        mode (ImageReadMode): read mode whose ``value`` is forwarded to the
            decode op. Default: ``ImageReadMode.UNCHANGED``.

    Returns:
        Tensor: the result of ``torch.ops.image.decode_png``.
    """
    raw_bytes = read_file(path)
    return torch.ops.image.decode_png(raw_bytes, mode.value, True)
|