# _functional_tensor.py

import warnings
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn.functional import conv2d, grid_sample, interpolate, pad as torch_pad


def _is_tensor_a_torch_image(x: Tensor) -> bool:
    return x.ndim >= 2


def _assert_image_tensor(img: Tensor) -> None:
    if not _is_tensor_a_torch_image(img):
        raise TypeError("Tensor is not a torch image.")


def get_dimensions(img: Tensor) -> List[int]:
    _assert_image_tensor(img)
    channels = 1 if img.ndim == 2 else img.shape[-3]
    height, width = img.shape[-2:]
    return [channels, height, width]


def get_image_size(img: Tensor) -> List[int]:
    # Returns (w, h) of tensor image
    _assert_image_tensor(img)
    return [img.shape[-1], img.shape[-2]]


def get_image_num_channels(img: Tensor) -> int:
    _assert_image_tensor(img)
    if img.ndim == 2:
        return 1
    elif img.ndim > 2:
        return img.shape[-3]

    raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")


def _max_value(dtype: torch.dtype) -> int:
    if dtype == torch.uint8:
        return 255
    elif dtype == torch.int8:
        return 127
    elif dtype == torch.int16:
        return 32767
    elif dtype == torch.int32:
        return 2147483647
    elif dtype == torch.int64:
        return 9223372036854775807
    else:
        # This is only here for completeness. This value is implicitly assumed in a lot of places so changing it is not
        # easy.
        return 1


def _assert_channels(img: Tensor, permitted: List[int]) -> None:
    c = get_dimensions(img)[0]
    if c not in permitted:
        raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {c}")


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    if image.dtype == dtype:
        return image

    if image.is_floating_point():
        # TODO: replace with dtype.is_floating_point when torchscript supports it
        if torch.tensor(0, dtype=dtype).is_floating_point():
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
            raise RuntimeError(msg)

        # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # For data in the range 0-1, (float * 255).to(uint) is only 255
        # when float is exactly 1.0.
        # `max + 1 - epsilon` provides more evenly distributed mapping of
        # ranges of floats to ints.
        eps = 1e-3
        max_val = float(_max_value(dtype))
        result = image.mul(max_val + 1.0 - eps)
        return result.to(dtype)
    else:
        input_max = float(_max_value(image.dtype))

        # int to float
        # TODO: replace with dtype.is_floating_point when torchscript supports it
        if torch.tensor(0, dtype=dtype).is_floating_point():
            image = image.to(dtype)
            return image / input_max

        output_max = float(_max_value(dtype))

        # int to int
        if input_max > output_max:
            # factor should be forced to int for torch jit script
            # otherwise factor is a float and image // factor can produce different results
            factor = int((input_max + 1) // (output_max + 1))
            image = torch.div(image, factor, rounding_mode="floor")
            return image.to(dtype)
        else:
            # factor should be forced to int for torch jit script
            # otherwise factor is a float and image * factor can produce different results
            factor = int((output_max + 1) // (input_max + 1))
            image = image.to(dtype)
            return image * factor
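

# Illustrative sketch of the conversions above (an added example, not part of the
# upstream torchvision source): float -> int multiplies by `max + 1 - eps` so every
# float in [0, 1] lands in a valid integer bucket, while int -> float divides by the
# input dtype's maximum value.
#
#   convert_image_dtype(torch.tensor([0.0, 0.5, 1.0]), torch.uint8)
#   # -> tensor([  0, 127, 255], dtype=torch.uint8)
#   convert_image_dtype(torch.tensor([0, 128, 255], dtype=torch.uint8), torch.float32)
#   # -> approximately tensor([0.0000, 0.5020, 1.0000])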


def vflip(img: Tensor) -> Tensor:
    _assert_image_tensor(img)

    return img.flip(-2)


def hflip(img: Tensor) -> Tensor:
    _assert_image_tensor(img)

    return img.flip(-1)


def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    _assert_image_tensor(img)

    _, h, w = get_dimensions(img)
    right = left + width
    bottom = top + height

    if left < 0 or top < 0 or right > w or bottom > h:
        padding_ltrb = [
            max(-left + min(0, right), 0),
            max(-top + min(0, bottom), 0),
            max(right - max(w, left), 0),
            max(bottom - max(h, top), 0),
        ]
        return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0)
    return img[..., top:bottom, left:right]


def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    _assert_channels(img, [1, 3])

    if num_output_channels not in (1, 3):
        raise ValueError("num_output_channels should be either 1 or 3")

    if img.shape[-3] == 3:
        r, g, b = img.unbind(dim=-3)
        # This implementation closely follows the TF one:
        # https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
        l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
        l_img = l_img.unsqueeze(dim=-3)
    else:
        l_img = img.clone()

    if num_output_channels == 3:
        return l_img.expand(img.shape)

    return l_img
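

# Example (an added sketch, not part of the upstream source): with the BT.601-style
# weights used above, a pure-red float image maps to a grayscale value of ~0.2989.
#
#   red = torch.zeros(3, 2, 2)
#   red[0] = 1.0
#   rgb_to_grayscale(red)             # all values ~0.2989, shape (1, 2, 2)
#   rgb_to_grayscale(red, 3).shape    # torch.Size([3, 2, 2])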


def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    if brightness_factor < 0:
        raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")

    _assert_image_tensor(img)

    _assert_channels(img, [1, 3])

    return _blend(img, torch.zeros_like(img), brightness_factor)


def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    if contrast_factor < 0:
        raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")

    _assert_image_tensor(img)

    _assert_channels(img, [3, 1])
    c = get_dimensions(img)[0]
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    if c == 3:
        mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
    else:
        mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True)

    return _blend(img, mean, contrast_factor)


def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor image")

    _assert_image_tensor(img)

    _assert_channels(img, [1, 3])
    if get_dimensions(img)[0] == 1:  # Match PIL behaviour
        return img

    orig_dtype = img.dtype
    img = convert_image_dtype(img, torch.float32)

    img = _rgb2hsv(img)
    h, s, v = img.unbind(dim=-3)
    h = (h + hue_factor) % 1.0
    img = torch.stack((h, s, v), dim=-3)
    img_hue_adj = _hsv2rgb(img)

    return convert_image_dtype(img_hue_adj, orig_dtype)


def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    if saturation_factor < 0:
        raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")

    _assert_image_tensor(img)

    _assert_channels(img, [1, 3])

    if get_dimensions(img)[0] == 1:  # Match PIL behaviour
        return img

    return _blend(img, rgb_to_grayscale(img), saturation_factor)


def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    if not isinstance(img, torch.Tensor):
        raise TypeError("Input img should be a Tensor.")

    _assert_channels(img, [1, 3])

    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number")

    result = img
    dtype = img.dtype
    if not torch.is_floating_point(img):
        result = convert_image_dtype(result, torch.float32)

    result = (gain * result**gamma).clamp(0, 1)

    result = convert_image_dtype(result, dtype)
    return result


def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
    ratio = float(ratio)
    bound = _max_value(img1.dtype)
    return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)


def _rgb2hsv(img: Tensor) -> Tensor:
    r, g, b = img.unbind(dim=-3)

    # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
    # src/libImaging/Convert.c#L330
    maxc = torch.max(img, dim=-3).values
    minc = torch.min(img, dim=-3).values

    # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
    # from happening in the results, because
    # + S channel has division by `maxc`, which is zero only if `maxc = minc`
    # + H channel has division by `(maxc - minc)`.
    #
    # Instead of overwriting NaN afterwards, we just prevent it from occurring, so
    # we don't need to deal with it in case we save the NaN in a buffer in
    # backprop, if it is ever supported, but it doesn't hurt to do so.
    eqc = maxc == minc

    cr = maxc - minc
    # Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine.
    ones = torch.ones_like(maxc)
    s = cr / torch.where(eqc, ones, maxc)
    # Note that `eqc => maxc = minc = r = g = b`. So the following calculation
    # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
    # would not matter what values `rc`, `gc`, and `bc` have here, and thus
    # replacing denominator with 1 when `eqc` is fine.
    cr_divisor = torch.where(eqc, ones, cr)
    rc = (maxc - r) / cr_divisor
    gc = (maxc - g) / cr_divisor
    bc = (maxc - b) / cr_divisor

    hr = (maxc == r) * (bc - gc)
    hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
    hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
    h = hr + hg + hb
    h = torch.fmod((h / 6.0 + 1.0), 1.0)
    return torch.stack((h, s, maxc), dim=-3)


def _hsv2rgb(img: Tensor) -> Tensor:
    h, s, v = img.unbind(dim=-3)
    i = torch.floor(h * 6.0)
    f = (h * 6.0) - i
    i = i.to(dtype=torch.int32)

    p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
    q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
    t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
    i = i % 6

    mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)

    a1 = torch.stack((v, q, p, p, t, v), dim=-3)
    a2 = torch.stack((t, v, v, q, p, p), dim=-3)
    a3 = torch.stack((p, p, t, v, v, q), dim=-3)
    a4 = torch.stack((a1, a2, a3), dim=-4)

    return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4)
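

# Sanity-check sketch (an added example, not part of the upstream source): on float
# images in [0, 1], _rgb2hsv and _hsv2rgb are approximate inverses, which is what
# adjust_hue relies on when it shifts only the H channel.
#
#   x = torch.rand(3, 4, 4)
#   torch.allclose(_hsv2rgb(_rgb2hsv(x)), x, atol=1e-5)  # True, up to float error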


def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
    # padding is left, right, top, bottom

    # crop if needed
    if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
        neg_min_padding = [-min(x, 0) for x in padding]
        crop_left, crop_right, crop_top, crop_bottom = neg_min_padding
        img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right]
        padding = [max(x, 0) for x in padding]

    in_sizes = img.size()

    _x_indices = [i for i in range(in_sizes[-1])]  # [0, 1, 2, 3, ...]
    left_indices = [i for i in range(padding[0] - 1, -1, -1)]  # e.g. [3, 2, 1, 0]
    right_indices = [-(i + 1) for i in range(padding[1])]  # e.g. [-1, -2, -3]
    x_indices = torch.tensor(left_indices + _x_indices + right_indices, device=img.device)

    _y_indices = [i for i in range(in_sizes[-2])]
    top_indices = [i for i in range(padding[2] - 1, -1, -1)]
    bottom_indices = [-(i + 1) for i in range(padding[3])]
    y_indices = torch.tensor(top_indices + _y_indices + bottom_indices, device=img.device)

    ndim = img.ndim
    if ndim == 3:
        return img[:, y_indices[:, None], x_indices[None, :]]
    elif ndim == 4:
        return img[:, :, y_indices[:, None], x_indices[None, :]]
    else:
        raise RuntimeError("Symmetric padding of N-D tensors is not supported yet")
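

# Worked example of the index construction above (an added note, not part of the
# upstream source): with padding = [2, 1, 0, 0] and a row [a, b, c, d], the gathered
# x indices are [1, 0] + [0, 1, 2, 3] + [-1], so the row becomes [b, a, a, b, c, d, d].
# The border pixel itself is repeated in the mirror, which is what distinguishes
# "symmetric" padding from "reflect" padding.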


def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    if isinstance(padding, int):
        if torch.jit.is_scripting():
            # This may be unreachable
            raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif len(padding) == 1:
        pad_left = pad_right = pad_top = pad_bottom = padding[0]
    elif len(padding) == 2:
        pad_left = pad_right = padding[0]
        pad_top = pad_bottom = padding[1]
    else:
        pad_left = padding[0]
        pad_top = padding[1]
        pad_right = padding[2]
        pad_bottom = padding[3]

    return [pad_left, pad_right, pad_top, pad_bottom]


def pad(
    img: Tensor, padding: Union[int, List[int]], fill: Optional[Union[int, float]] = 0, padding_mode: str = "constant"
) -> Tensor:
    _assert_image_tensor(img)

    if fill is None:
        fill = 0

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)

    if isinstance(padding, list):
        # TODO: Jit is failing on loading this op when scripted and saved
        # https://github.com/pytorch/pytorch/issues/81100
        if len(padding) not in [1, 2, 4]:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    p = _parse_pad_padding(padding)

    if padding_mode == "edge":
        # remap padding_mode str
        padding_mode = "replicate"
    elif padding_mode == "symmetric":
        # route to another implementation
        return _pad_symmetric(img, p)

    need_squeeze = False
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
        # Here we temporarily cast the input tensor to float
        # until the pytorch issue is resolved:
        # https://github.com/pytorch/pytorch/issues/40763
        need_cast = True
        img = img.to(torch.float32)

    if padding_mode in ("reflect", "replicate"):
        img = torch_pad(img, p, mode=padding_mode)
    else:
        img = torch_pad(img, p, mode=padding_mode, value=float(fill))

    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        img = img.to(out_dtype)

    return img
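

# Usage sketch (an added example, not part of the upstream source): a 4-element
# padding list is interpreted as (left, top, right, bottom) by _parse_pad_padding.
#
#   pad(torch.zeros(1, 4, 4), [1, 2, 3, 4]).shape
#   # -> torch.Size([1, 10, 8])   (height: 2 + 4 + 4, width: 1 + 4 + 3)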


def resize(
    img: Tensor,
    size: List[int],
    interpolation: str = "bilinear",
    # TODO: in v0.17, change the default to True. This will be a private function
    # by then, so we don't care about warnings here.
    antialias: Optional[bool] = None,
) -> Tensor:
    _assert_image_tensor(img)

    if isinstance(size, tuple):
        size = list(size)

    if antialias is None:
        antialias = False

    if antialias and interpolation not in ["bilinear", "bicubic"]:
        # We manually set it to False to avoid an error downstream in interpolate()
        # This behaviour is documented: the parameter is irrelevant for modes
        # that are not bilinear or bicubic. We used to raise an error here, but
        # now we don't as True is the default.
        antialias = False

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])

    # Define align_corners to avoid warnings
    align_corners = False if interpolation in ["bilinear", "bicubic"] else None

    img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias)

    if interpolation == "bicubic" and out_dtype == torch.uint8:
        img = img.clamp(min=0, max=255)

    img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)

    return img


def _assert_grid_transform_inputs(
    img: Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: Optional[Union[int, float, List[float]]],
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:
    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor")

    _assert_image_tensor(img)

    if matrix is not None and not isinstance(matrix, list):
        raise TypeError("Argument matrix should be a list")

    if matrix is not None and len(matrix) != 6:
        raise ValueError("Argument matrix should have 6 float values")

    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None and not isinstance(fill, (int, float, tuple, list)):
        warnings.warn("Argument fill should be either int, float, tuple or list")

    # Check fill
    num_channels = get_dimensions(img)[0]
    if fill is not None and isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels:
        msg = (
            "The number of elements in 'fill' cannot broadcast to match the number of "
            "channels of the image ({} != {})"
        )
        raise ValueError(msg.format(len(fill), num_channels))

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")


def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
    need_squeeze = False
    # make image NCHW
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if out_dtype not in req_dtypes:
        need_cast = True
        req_dtype = req_dtypes[0]
        img = img.to(req_dtype)
    return img, need_cast, need_squeeze, out_dtype


def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor:
    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
            # it is better to round before cast
            img = torch.round(img)
        img = img.to(out_dtype)

    return img


def _apply_grid_transform(
    img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
) -> Tensor:
    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype])

    if img.shape[0] > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])

    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
    if fill is not None:
        mask = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
        img = torch.cat((img, mask), dim=1)

    img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    # Fill with required color
    if fill is not None:
        mask = img[:, -1:, :, :]  # N * 1 * H * W
        img = img[:, :-1, :, :]  # N * C * H * W
        mask = mask.expand_as(img)
        fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1)
        fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
        if mode == "nearest":
            mask = mask < 0.5
            img[mask] = fill_img[mask]
        else:  # 'bilinear'
            img = img * mask + (1.0 - mask) * fill_img

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img


def _gen_affine_grid(
    theta: Tensor,
    w: int,
    h: int,
    ow: int,
    oh: int,
) -> Tensor:
    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
    # AffineGridGenerator.cpp#L18
    # Difference with AffineGridGenerator is that:
    # 1) we normalize grid values after applying theta
    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
    d = 0.5
    base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
    x_grid = torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device)
    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
    return output_grid.view(1, oh, ow, 2)


def affine(
    img: Tensor,
    matrix: List[float],
    interpolation: str = "nearest",
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
    shape = img.shape
    # grid will be generated on the same device as theta and img
    grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
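

# Sanity-check sketch (an added example, not part of the upstream source): with the
# identity matrix [1, 0, 0, 0, 1, 0] the generated grid reproduces the
# align_corners=False pixel centers, so the image comes back (essentially) unchanged.
#
#   x = torch.rand(1, 3, 8, 8)
#   torch.allclose(affine(x, [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], "bilinear"), x, atol=1e-6)
#   # -> True, up to interpolation numerics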


def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
    # Inspired by the PIL implementation:
    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054

    # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
    # Points are shifted due to affine matrix torch convention about
    # the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
    pts = torch.tensor(
        [
            [-0.5 * w, -0.5 * h, 1.0],
            [-0.5 * w, 0.5 * h, 1.0],
            [0.5 * w, 0.5 * h, 1.0],
            [0.5 * w, -0.5 * h, 1.0],
        ]
    )
    theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
    new_pts = torch.matmul(pts, theta.T)
    min_vals, _ = new_pts.min(dim=0)
    max_vals, _ = new_pts.max(dim=0)

    # shift points to [0, w] and [0, h] interval to match PIL results
    min_vals += torch.tensor((w * 0.5, h * 0.5))
    max_vals += torch.tensor((w * 0.5, h * 0.5))

    # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
    tol = 1e-4
    cmax = torch.ceil((max_vals / tol).trunc_() * tol)
    cmin = torch.floor((min_vals / tol).trunc_() * tol)
    size = cmax - cmin
    return int(size[0]), int(size[1])  # w, h
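

# Worked example (an added note, not part of the upstream source): for a pure
# 90-degree rotation matrix [0., -1., 0., 1., 0., 0.], the four corners of a (w, h)
# image map onto an (h, w) bounding box, so this function returns the swapped size.
# This is what rotate(..., expand=True) relies on below.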


def rotate(
    img: Tensor,
    matrix: List[float],
    interpolation: str = "nearest",
    expand: bool = False,
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
    w, h = img.shape[-1], img.shape[-2]
    ow, oh = _compute_affine_output_size(matrix, w, h) if expand else (w, h)
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
    # grid will be generated on the same device as theta and img
    grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)

    return _apply_grid_transform(img, grid, interpolation, fill=fill)


def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> Tensor:
    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
    # src/libImaging/Geometry.c#L394
    #
    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #
    theta1 = torch.tensor(
        [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
    )
    theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)

    d = 0.5
    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
    output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
    output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))

    output_grid = output_grid1 / output_grid2 - 1.0
    return output_grid.view(1, oh, ow, 2)
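

# Sanity check (an added note, not part of the upstream source): with the neutral
# coefficients [1, 0, 0, 0, 1, 0, 0, 0] the denominator above is 1 everywhere, the
# formulas reduce to x_out = x and y_out = y, and perspective() becomes an
# (approximate) identity warp.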


def perspective(
    img: Tensor,
    perspective_coeffs: List[float],
    interpolation: str = "bilinear",
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor.")

    _assert_image_tensor(img)

    _assert_grid_transform_inputs(
        img,
        matrix=None,
        interpolation=interpolation,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=perspective_coeffs,
    )

    ow, oh = img.shape[-1], img.shape[-2]
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
    return _apply_grid_transform(img, grid, interpolation, fill=fill)


def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
    ksize_half = (kernel_size - 1) * 0.5

    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    kernel1d = pdf / pdf.sum()

    return kernel1d


def _get_gaussian_kernel2d(
    kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
    kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
    return kernel2d
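

# Example of the separable construction above (an added sketch, not part of the
# upstream source): the 2D kernel is the outer product of two 1D Gaussians, each of
# which sums to 1, so the 2D kernel sums to 1 as well.
#
#   k = _get_gaussian_kernel2d([3, 3], [1.0, 1.0], dtype=torch.float32, device=torch.device("cpu"))
#   k.shape          # torch.Size([3, 3])
#   float(k.sum())   # ~1.0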


def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
    if not (isinstance(img, torch.Tensor)):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    _assert_image_tensor(img)

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])

    # padding = (left, right, top, bottom)
    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
    img = torch_pad(img, padding, mode="reflect")
    img = conv2d(img, kernel, groups=img.shape[-3])

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img


def invert(img: Tensor) -> Tensor:
    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    return _max_value(img.dtype) - img


def posterize(img: Tensor, bits: int) -> Tensor:
    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    if img.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")

    _assert_channels(img, [1, 3])
    mask = -int(2 ** (8 - bits))  # JIT-friendly for: ~(2 ** (8 - bits) - 1)
    return img & mask
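

# Worked example of the mask above (an added note, not part of the upstream source):
# for bits=4 the mask is -16, i.e. 0b11110000 as uint8, so posterize keeps only the
# top 4 bits of each value, e.g. 201 (0b11001001) -> 192 (0b11000000).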


def solarize(img: Tensor, threshold: float) -> Tensor:
    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    if threshold > _max_value(img.dtype):
        raise TypeError("Threshold should be less than bound of img.")

    inverted_img = invert(img)
    return torch.where(img >= threshold, inverted_img, img)


def _blurred_degenerate_image(img: Tensor) -> Tensor:
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
    kernel[1, 1] = 5.0
    kernel /= kernel.sum()
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])
    result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3])
    result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)

    result = img.clone()
    result[..., 1:-1, 1:-1] = result_tmp

    return result


def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    if sharpness_factor < 0:
        raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")

    _assert_image_tensor(img)

    _assert_channels(img, [1, 3])

    if img.size(-1) <= 2 or img.size(-2) <= 2:
        return img

    return _blend(img, _blurred_degenerate_image(img), sharpness_factor)


def autocontrast(img: Tensor) -> Tensor:
    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    bound = _max_value(img.dtype)
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
    maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype)

    scale = bound / (maximum - minimum)
    eq_idxs = torch.isfinite(scale).logical_not()
    minimum[eq_idxs] = 0
    scale[eq_idxs] = 1

    return ((img - minimum) * scale).clamp(0, bound).to(img.dtype)


def _scale_channel(img_chan: Tensor) -> Tensor:
    # TODO: we should expect bincount to always be faster than histc, but this
    # isn't always the case. Once
    # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
    # block and only use bincount.
    if img_chan.is_cuda:
        hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
    else:
        hist = torch.bincount(img_chan.reshape(-1), minlength=256)

    nonzero_hist = hist[hist != 0]
    step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
    if step == 0:
        return img_chan

    lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
    lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)

    return lut[img_chan.to(torch.int64)].to(torch.uint8)


def _equalize_single_image(img: Tensor) -> Tensor:
    return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))])


def equalize(img: Tensor) -> Tensor:
    _assert_image_tensor(img)

    if not (3 <= img.ndim <= 4):
        raise TypeError(f"Input image tensor should have 3 or 4 dimensions, but found {img.ndim}")
    if img.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")

    _assert_channels(img, [1, 3])

    if img.ndim == 3:
        return _equalize_single_image(img)

    return torch.stack([_equalize_single_image(x) for x in img])


def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
    _assert_image_tensor(tensor)

    if not tensor.is_floating_point():
        raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")

    if tensor.ndim < 3:
        raise ValueError(
            f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
        )

    if not inplace:
        tensor = tensor.clone()

    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    if (std == 0).any():
        raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    return tensor.sub_(mean).div_(std)


def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
    _assert_image_tensor(img)

    if not inplace:
        img = img.clone()

    img[..., i : i + h, j : j + w] = v
    return img


def _create_identity_grid(size: List[int]) -> Tensor:
    hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size]
    grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij")
    return torch.stack([grid_x, grid_y], -1).unsqueeze(0)  # 1 x H x W x 2


def elastic_transform(
    img: Tensor,
    displacement: Tensor,
    interpolation: str = "bilinear",
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    if not (isinstance(img, torch.Tensor)):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    size = list(img.shape[-2:])
    displacement = displacement.to(img.device)

    identity_grid = _create_identity_grid(size)
    grid = identity_grid.to(img.device) + displacement
    return _apply_grid_transform(img, grid, interpolation, fill)
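

# Sanity-check sketch (an added example, not part of the upstream source): the
# identity grid matches the align_corners=False pixel centers, so a zero displacement
# leaves the image (essentially) unchanged.
#
#   x = torch.rand(1, 3, 8, 8)
#   torch.allclose(elastic_transform(x, torch.zeros(1, 8, 8, 2)), x, atol=1e-6)  # True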