# _functional_pil.py — PIL-backed low-level implementations of torchvision's
# functional image transforms.
import numbers
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from PIL import Image, ImageEnhance, ImageOps

# accimage is an optional, faster image backend; fall back to plain PIL
# when it is not installed.
try:
    import accimage
except ImportError:
    accimage = None
  10. @torch.jit.unused
  11. def _is_pil_image(img: Any) -> bool:
  12. if accimage is not None:
  13. return isinstance(img, (Image.Image, accimage.Image))
  14. else:
  15. return isinstance(img, Image.Image)
  16. @torch.jit.unused
  17. def get_dimensions(img: Any) -> List[int]:
  18. if _is_pil_image(img):
  19. if hasattr(img, "getbands"):
  20. channels = len(img.getbands())
  21. else:
  22. channels = img.channels
  23. width, height = img.size
  24. return [channels, height, width]
  25. raise TypeError(f"Unexpected type {type(img)}")
  26. @torch.jit.unused
  27. def get_image_size(img: Any) -> List[int]:
  28. if _is_pil_image(img):
  29. return list(img.size)
  30. raise TypeError(f"Unexpected type {type(img)}")
  31. @torch.jit.unused
  32. def get_image_num_channels(img: Any) -> int:
  33. if _is_pil_image(img):
  34. if hasattr(img, "getbands"):
  35. return len(img.getbands())
  36. else:
  37. return img.channels
  38. raise TypeError(f"Unexpected type {type(img)}")
  39. @torch.jit.unused
  40. def hflip(img: Image.Image) -> Image.Image:
  41. if not _is_pil_image(img):
  42. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  43. return img.transpose(Image.FLIP_LEFT_RIGHT)
  44. @torch.jit.unused
  45. def vflip(img: Image.Image) -> Image.Image:
  46. if not _is_pil_image(img):
  47. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  48. return img.transpose(Image.FLIP_TOP_BOTTOM)
  49. @torch.jit.unused
  50. def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image:
  51. if not _is_pil_image(img):
  52. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  53. enhancer = ImageEnhance.Brightness(img)
  54. img = enhancer.enhance(brightness_factor)
  55. return img
  56. @torch.jit.unused
  57. def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
  58. if not _is_pil_image(img):
  59. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  60. enhancer = ImageEnhance.Contrast(img)
  61. img = enhancer.enhance(contrast_factor)
  62. return img
  63. @torch.jit.unused
  64. def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image:
  65. if not _is_pil_image(img):
  66. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  67. enhancer = ImageEnhance.Color(img)
  68. img = enhancer.enhance(saturation_factor)
  69. return img
  70. @torch.jit.unused
  71. def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
  72. if not (-0.5 <= hue_factor <= 0.5):
  73. raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
  74. if not _is_pil_image(img):
  75. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  76. input_mode = img.mode
  77. if input_mode in {"L", "1", "I", "F"}:
  78. return img
  79. h, s, v = img.convert("HSV").split()
  80. np_h = np.array(h, dtype=np.uint8)
  81. # uint8 addition take cares of rotation across boundaries
  82. with np.errstate(over="ignore"):
  83. np_h += np.uint8(hue_factor * 255)
  84. h = Image.fromarray(np_h, "L")
  85. img = Image.merge("HSV", (h, s, v)).convert(input_mode)
  86. return img
  87. @torch.jit.unused
  88. def adjust_gamma(
  89. img: Image.Image,
  90. gamma: float,
  91. gain: float = 1.0,
  92. ) -> Image.Image:
  93. if not _is_pil_image(img):
  94. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  95. if gamma < 0:
  96. raise ValueError("Gamma should be a non-negative real number")
  97. input_mode = img.mode
  98. img = img.convert("RGB")
  99. gamma_map = [int((255 + 1 - 1e-3) * gain * pow(ele / 255.0, gamma)) for ele in range(256)] * 3
  100. img = img.point(gamma_map) # use PIL's point-function to accelerate this part
  101. img = img.convert(input_mode)
  102. return img
@torch.jit.unused
def pad(
    img: Image.Image,
    padding: Union[int, List[int], Tuple[int, ...]],
    fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
    padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> Image.Image:
    """Pad the given PIL image on its borders.

    Args:
        img: image to pad.
        padding: border size(s). An int pads all four sides equally; a
            2-element sequence is (left/right, top/bottom); a 4-element
            sequence is (left, top, right, bottom); a 1-element sequence
            behaves like an int. Negative values crop instead of pad
            (non-constant modes only).
        fill: pixel fill value, used only when ``padding_mode`` is "constant".
        padding_mode: "constant", "edge", "reflect" or "symmetric".

    Returns:
        The padded (or cropped, for negative padding) image.

    Raises:
        TypeError: if ``img``, ``padding``, ``fill`` or ``padding_mode`` has
            an inappropriate type.
        ValueError: if ``padding`` has a wrong length or ``padding_mode`` is
            not one of the four supported modes.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    if not isinstance(padding, (numbers.Number, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, list):
        padding = tuple(padding)

    if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]:
        raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")

    if isinstance(padding, tuple) and len(padding) == 1:
        # Compatibility with `functional_tensor.pad`
        padding = padding[0]

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    if padding_mode == "constant":
        opts = _parse_fill(fill, img, name="fill")
        if img.mode == "P":
            # ImageOps.expand drops the palette, so save it and re-attach it.
            palette = img.getpalette()
            image = ImageOps.expand(img, border=padding, **opts)
            image.putpalette(palette)
            return image
        return ImageOps.expand(img, border=padding, **opts)
    else:
        # Normalize padding into per-side amounts.
        if isinstance(padding, int):
            pad_left = pad_right = pad_top = pad_bottom = padding
        if isinstance(padding, tuple) and len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        if isinstance(padding, tuple) and len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]

        p = [pad_left, pad_top, pad_right, pad_bottom]
        # Negative padding means cropping: split into a crop step followed by
        # a non-negative padding step.
        cropping = -np.minimum(p, 0)
        if cropping.any():
            crop_left, crop_top, crop_right, crop_bottom = cropping
            img = img.crop((crop_left, crop_top, img.width - crop_right, img.height - crop_bottom))
        pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0)

        if img.mode == "P":
            # np.pad operates on the raw palette-index array; re-attach the
            # palette afterwards.
            palette = img.getpalette()
            img = np.asarray(img)
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode)
            img = Image.fromarray(img)
            img.putpalette(palette)
            return img

        img = np.asarray(img)
        # RGB image
        if len(img.shape) == 3:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
        # Grayscale image
        if len(img.shape) == 2:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)

        return Image.fromarray(img)
  167. @torch.jit.unused
  168. def crop(
  169. img: Image.Image,
  170. top: int,
  171. left: int,
  172. height: int,
  173. width: int,
  174. ) -> Image.Image:
  175. if not _is_pil_image(img):
  176. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  177. return img.crop((left, top, left + width, top + height))
  178. @torch.jit.unused
  179. def resize(
  180. img: Image.Image,
  181. size: Union[List[int], int],
  182. interpolation: int = Image.BILINEAR,
  183. ) -> Image.Image:
  184. if not _is_pil_image(img):
  185. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  186. if not (isinstance(size, list) and len(size) == 2):
  187. raise TypeError(f"Got inappropriate size arg: {size}")
  188. return img.resize(tuple(size[::-1]), interpolation)
  189. @torch.jit.unused
  190. def _parse_fill(
  191. fill: Optional[Union[float, List[float], Tuple[float, ...]]],
  192. img: Image.Image,
  193. name: str = "fillcolor",
  194. ) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]:
  195. # Process fill color for affine transforms
  196. num_channels = get_image_num_channels(img)
  197. if fill is None:
  198. fill = 0
  199. if isinstance(fill, (int, float)) and num_channels > 1:
  200. fill = tuple([fill] * num_channels)
  201. if isinstance(fill, (list, tuple)):
  202. if len(fill) == 1:
  203. fill = fill * num_channels
  204. elif len(fill) != num_channels:
  205. msg = "The number of elements in 'fill' does not match the number of channels of the image ({} != {})"
  206. raise ValueError(msg.format(len(fill), num_channels))
  207. fill = tuple(fill) # type: ignore[arg-type]
  208. if img.mode != "F":
  209. if isinstance(fill, (list, tuple)):
  210. fill = tuple(int(x) for x in fill)
  211. else:
  212. fill = int(fill)
  213. return {name: fill}
  214. @torch.jit.unused
  215. def affine(
  216. img: Image.Image,
  217. matrix: List[float],
  218. interpolation: int = Image.NEAREST,
  219. fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
  220. ) -> Image.Image:
  221. if not _is_pil_image(img):
  222. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  223. output_size = img.size
  224. opts = _parse_fill(fill, img)
  225. return img.transform(output_size, Image.AFFINE, matrix, interpolation, **opts)
  226. @torch.jit.unused
  227. def rotate(
  228. img: Image.Image,
  229. angle: float,
  230. interpolation: int = Image.NEAREST,
  231. expand: bool = False,
  232. center: Optional[Tuple[int, int]] = None,
  233. fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
  234. ) -> Image.Image:
  235. if not _is_pil_image(img):
  236. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  237. opts = _parse_fill(fill, img)
  238. return img.rotate(angle, interpolation, expand, center, **opts)
  239. @torch.jit.unused
  240. def perspective(
  241. img: Image.Image,
  242. perspective_coeffs: List[float],
  243. interpolation: int = Image.BICUBIC,
  244. fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
  245. ) -> Image.Image:
  246. if not _is_pil_image(img):
  247. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  248. opts = _parse_fill(fill, img)
  249. return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **opts)
  250. @torch.jit.unused
  251. def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
  252. if not _is_pil_image(img):
  253. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  254. if num_output_channels == 1:
  255. img = img.convert("L")
  256. elif num_output_channels == 3:
  257. img = img.convert("L")
  258. np_img = np.array(img, dtype=np.uint8)
  259. np_img = np.dstack([np_img, np_img, np_img])
  260. img = Image.fromarray(np_img, "RGB")
  261. else:
  262. raise ValueError("num_output_channels should be either 1 or 3")
  263. return img
  264. @torch.jit.unused
  265. def invert(img: Image.Image) -> Image.Image:
  266. if not _is_pil_image(img):
  267. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  268. return ImageOps.invert(img)
  269. @torch.jit.unused
  270. def posterize(img: Image.Image, bits: int) -> Image.Image:
  271. if not _is_pil_image(img):
  272. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  273. return ImageOps.posterize(img, bits)
  274. @torch.jit.unused
  275. def solarize(img: Image.Image, threshold: int) -> Image.Image:
  276. if not _is_pil_image(img):
  277. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  278. return ImageOps.solarize(img, threshold)
  279. @torch.jit.unused
  280. def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
  281. if not _is_pil_image(img):
  282. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  283. enhancer = ImageEnhance.Sharpness(img)
  284. img = enhancer.enhance(sharpness_factor)
  285. return img
  286. @torch.jit.unused
  287. def autocontrast(img: Image.Image) -> Image.Image:
  288. if not _is_pil_image(img):
  289. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  290. return ImageOps.autocontrast(img)
  291. @torch.jit.unused
  292. def equalize(img: Image.Image) -> Image.Image:
  293. if not _is_pil_image(img):
  294. raise TypeError(f"img should be PIL Image. Got {type(img)}")
  295. return ImageOps.equalize(img)