functional.py 68 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617
  1. import math
  2. import numbers
  3. import warnings
  4. from enum import Enum
  5. from typing import Any, List, Optional, Tuple, Union
  6. import numpy as np
  7. import torch
  8. from PIL import Image
  9. from torch import Tensor
  10. try:
  11. import accimage
  12. except ImportError:
  13. accimage = None
  14. from ..utils import _log_api_usage_once
  15. from . import _functional_pil as F_pil, _functional_tensor as F_t
class InterpolationMode(Enum):
    """Interpolation modes

    Available interpolation methods are ``nearest``, ``nearest-exact``, ``bilinear``, ``bicubic``, ``box``, ``hamming``,
    and ``lanczos``.

    Each member's value is the lowercase string name of the method; ``resize`` in this
    module hands that string (``interpolation.value``) to the tensor backend.
    """

    NEAREST = "nearest"
    NEAREST_EXACT = "nearest-exact"
    BILINEAR = "bilinear"
    BICUBIC = "bicubic"
    # For PIL compatibility
    BOX = "box"
    HAMMING = "hamming"
    LANCZOS = "lanczos"
  29. # TODO: Once torchscript supports Enums with staticmethod
  30. # this can be put into InterpolationMode as staticmethod
  31. def _interpolation_modes_from_int(i: int) -> InterpolationMode:
  32. inverse_modes_mapping = {
  33. 0: InterpolationMode.NEAREST,
  34. 2: InterpolationMode.BILINEAR,
  35. 3: InterpolationMode.BICUBIC,
  36. 4: InterpolationMode.BOX,
  37. 5: InterpolationMode.HAMMING,
  38. 1: InterpolationMode.LANCZOS,
  39. }
  40. return inverse_modes_mapping[i]
# Integer interpolation code passed to the PIL backend for each InterpolationMode
# (``resize`` looks the code up here before calling ``F_pil.resize``).
pil_modes_mapping = {
    InterpolationMode.NEAREST: 0,
    InterpolationMode.BILINEAR: 2,
    InterpolationMode.BICUBIC: 3,
    # NEAREST_EXACT has no code of its own: it shares 0 with NEAREST.
    InterpolationMode.NEAREST_EXACT: 0,
    InterpolationMode.BOX: 4,
    InterpolationMode.HAMMING: 5,
    InterpolationMode.LANCZOS: 1,
}

# Module-level alias of the PIL backend's image type check.
_is_pil_image = F_pil._is_pil_image
  51. def get_dimensions(img: Tensor) -> List[int]:
  52. """Returns the dimensions of an image as [channels, height, width].
  53. Args:
  54. img (PIL Image or Tensor): The image to be checked.
  55. Returns:
  56. List[int]: The image dimensions.
  57. """
  58. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  59. _log_api_usage_once(get_dimensions)
  60. if isinstance(img, torch.Tensor):
  61. return F_t.get_dimensions(img)
  62. return F_pil.get_dimensions(img)
  63. def get_image_size(img: Tensor) -> List[int]:
  64. """Returns the size of an image as [width, height].
  65. Args:
  66. img (PIL Image or Tensor): The image to be checked.
  67. Returns:
  68. List[int]: The image size.
  69. """
  70. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  71. _log_api_usage_once(get_image_size)
  72. if isinstance(img, torch.Tensor):
  73. return F_t.get_image_size(img)
  74. return F_pil.get_image_size(img)
  75. def get_image_num_channels(img: Tensor) -> int:
  76. """Returns the number of channels of an image.
  77. Args:
  78. img (PIL Image or Tensor): The image to be checked.
  79. Returns:
  80. int: The number of channels.
  81. """
  82. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  83. _log_api_usage_once(get_image_num_channels)
  84. if isinstance(img, torch.Tensor):
  85. return F_t.get_image_num_channels(img)
  86. return F_pil.get_image_num_channels(img)
@torch.jit.unused
def _is_numpy(img: Any) -> bool:
    """Return ``True`` when ``img`` is a ``numpy.ndarray`` instance."""
    return isinstance(img, np.ndarray)
  90. @torch.jit.unused
  91. def _is_numpy_image(img: Any) -> bool:
  92. return img.ndim in {2, 3}
  93. def to_tensor(pic) -> Tensor:
  94. """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
  95. This function does not support torchscript.
  96. See :class:`~torchvision.transforms.ToTensor` for more details.
  97. Args:
  98. pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
  99. Returns:
  100. Tensor: Converted image.
  101. """
  102. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  103. _log_api_usage_once(to_tensor)
  104. if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
  105. raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")
  106. if _is_numpy(pic) and not _is_numpy_image(pic):
  107. raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")
  108. default_float_dtype = torch.get_default_dtype()
  109. if isinstance(pic, np.ndarray):
  110. # handle numpy array
  111. if pic.ndim == 2:
  112. pic = pic[:, :, None]
  113. img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
  114. # backward compatibility
  115. if isinstance(img, torch.ByteTensor):
  116. return img.to(dtype=default_float_dtype).div(255)
  117. else:
  118. return img
  119. if accimage is not None and isinstance(pic, accimage.Image):
  120. nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
  121. pic.copyto(nppic)
  122. return torch.from_numpy(nppic).to(dtype=default_float_dtype)
  123. # handle PIL Image
  124. mode_to_nptype = {"I": np.int32, "I;16": np.int16, "F": np.float32}
  125. img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True))
  126. if pic.mode == "1":
  127. img = 255 * img
  128. img = img.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic))
  129. # put it from HWC to CHW format
  130. img = img.permute((2, 0, 1)).contiguous()
  131. if isinstance(img, torch.ByteTensor):
  132. return img.to(dtype=default_float_dtype).div(255)
  133. else:
  134. return img
  135. def pil_to_tensor(pic: Any) -> Tensor:
  136. """Convert a ``PIL Image`` to a tensor of the same type.
  137. This function does not support torchscript.
  138. See :class:`~torchvision.transforms.PILToTensor` for more details.
  139. .. note::
  140. A deep copy of the underlying array is performed.
  141. Args:
  142. pic (PIL Image): Image to be converted to tensor.
  143. Returns:
  144. Tensor: Converted image.
  145. """
  146. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  147. _log_api_usage_once(pil_to_tensor)
  148. if not F_pil._is_pil_image(pic):
  149. raise TypeError(f"pic should be PIL Image. Got {type(pic)}")
  150. if accimage is not None and isinstance(pic, accimage.Image):
  151. # accimage format is always uint8 internally, so always return uint8 here
  152. nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.uint8)
  153. pic.copyto(nppic)
  154. return torch.as_tensor(nppic)
  155. # handle PIL Image
  156. img = torch.as_tensor(np.array(pic, copy=True))
  157. img = img.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic))
  158. # put it from HWC to CHW format
  159. img = img.permute((2, 0, 1))
  160. return img
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly

    This function does not support PIL Image. After the type check below, the actual
    conversion is delegated to the tensor backend (``F_t.convert_image_dtype``).

    Args:
        image (torch.Tensor): Image to be converted
        dtype (torch.dtype): Desired data type of the output

    Returns:
        Tensor: Converted image

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        TypeError: When ``image`` is not a :class:`torch.Tensor`.
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(convert_image_dtype)
    # Reject non-tensor input (e.g. PIL images) before handing off to the tensor backend.
    if not isinstance(image, torch.Tensor):
        raise TypeError("Input img should be Tensor Image")

    return F_t.convert_image_dtype(image, dtype)
  183. def to_pil_image(pic, mode=None):
  184. """Convert a tensor or an ndarray to PIL Image. This function does not support torchscript.
  185. See :class:`~torchvision.transforms.ToPILImage` for more details.
  186. Args:
  187. pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
  188. mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
  189. .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
  190. Returns:
  191. PIL Image: Image converted to PIL Image.
  192. """
  193. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  194. _log_api_usage_once(to_pil_image)
  195. if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
  196. raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")
  197. elif isinstance(pic, torch.Tensor):
  198. if pic.ndimension() not in {2, 3}:
  199. raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndimension()} dimensions.")
  200. elif pic.ndimension() == 2:
  201. # if 2D image, add channel dimension (CHW)
  202. pic = pic.unsqueeze(0)
  203. # check number of channels
  204. if pic.shape[-3] > 4:
  205. raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-3]} channels.")
  206. elif isinstance(pic, np.ndarray):
  207. if pic.ndim not in {2, 3}:
  208. raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")
  209. elif pic.ndim == 2:
  210. # if 2D image, add channel dimension (HWC)
  211. pic = np.expand_dims(pic, 2)
  212. # check number of channels
  213. if pic.shape[-1] > 4:
  214. raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")
  215. npimg = pic
  216. if isinstance(pic, torch.Tensor):
  217. if pic.is_floating_point() and mode != "F":
  218. pic = pic.mul(255).byte()
  219. npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))
  220. if not isinstance(npimg, np.ndarray):
  221. raise TypeError("Input pic must be a torch.Tensor or NumPy ndarray, not {type(npimg)}")
  222. if npimg.shape[2] == 1:
  223. expected_mode = None
  224. npimg = npimg[:, :, 0]
  225. if npimg.dtype == np.uint8:
  226. expected_mode = "L"
  227. elif npimg.dtype == np.int16:
  228. expected_mode = "I;16"
  229. elif npimg.dtype == np.int32:
  230. expected_mode = "I"
  231. elif npimg.dtype == np.float32:
  232. expected_mode = "F"
  233. if mode is not None and mode != expected_mode:
  234. raise ValueError(f"Incorrect mode ({mode}) supplied for input type {np.dtype}. Should be {expected_mode}")
  235. mode = expected_mode
  236. elif npimg.shape[2] == 2:
  237. permitted_2_channel_modes = ["LA"]
  238. if mode is not None and mode not in permitted_2_channel_modes:
  239. raise ValueError(f"Only modes {permitted_2_channel_modes} are supported for 2D inputs")
  240. if mode is None and npimg.dtype == np.uint8:
  241. mode = "LA"
  242. elif npimg.shape[2] == 4:
  243. permitted_4_channel_modes = ["RGBA", "CMYK", "RGBX"]
  244. if mode is not None and mode not in permitted_4_channel_modes:
  245. raise ValueError(f"Only modes {permitted_4_channel_modes} are supported for 4D inputs")
  246. if mode is None and npimg.dtype == np.uint8:
  247. mode = "RGBA"
  248. else:
  249. permitted_3_channel_modes = ["RGB", "YCbCr", "HSV"]
  250. if mode is not None and mode not in permitted_3_channel_modes:
  251. raise ValueError(f"Only modes {permitted_3_channel_modes} are supported for 3D inputs")
  252. if mode is None and npimg.dtype == np.uint8:
  253. mode = "RGB"
  254. if mode is None:
  255. raise TypeError(f"Input type {npimg.dtype} is not supported")
  256. return Image.fromarray(npimg, mode=mode)
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
    """Normalize a float tensor image with mean and standard deviation.

    This transform does not support PIL Image. After the type check below, the work is
    delegated to the tensor backend (``F_t.normalize``).

    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    See :class:`~torchvision.transforms.Normalize` for more details.

    Args:
        tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool,optional): Bool to make this operation inplace.

    Returns:
        Tensor: Normalized Tensor image.

    Raises:
        TypeError: If ``tensor`` is not a :class:`torch.Tensor`.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(normalize)
    # Reject non-tensor input before handing off to the tensor backend.
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"img should be Tensor Image. Got {type(tensor)}")

    return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
  276. def _compute_resized_output_size(
  277. image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
  278. ) -> List[int]:
  279. if len(size) == 1: # specified size only for the smallest edge
  280. h, w = image_size
  281. short, long = (w, h) if w <= h else (h, w)
  282. requested_new_short = size if isinstance(size, int) else size[0]
  283. new_short, new_long = requested_new_short, int(requested_new_short * long / short)
  284. if max_size is not None:
  285. if max_size <= requested_new_short:
  286. raise ValueError(
  287. f"max_size = {max_size} must be strictly greater than the requested "
  288. f"size for the smaller edge size = {size}"
  289. )
  290. if new_long > max_size:
  291. new_short, new_long = int(max_size * new_short / new_long), max_size
  292. new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
  293. else: # specified both h and w
  294. new_w, new_h = size[1], size[0]
  295. return [new_h, new_w]
def resize(
    img: Tensor,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> Tensor:
    r"""Resize the input image to the given size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. warning::
        The output image might be different depending on its type: when downsampling, the interpolation of PIL images
        and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
        in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
        types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors
        closer.

    Args:
        img (PIL Image or Tensor): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio. i.e, if height > width, then image will be rescaled to
            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
            ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
            supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        max_size (int, optional): The maximum allowed for the longer edge of
            the resized image. If the longer edge of the image is greater
            than ``max_size`` after being resized according to ``size``,
            ``size`` will be overruled so that the longer edge is equal to
            ``max_size``.
            As a result, the smaller edge may be shorter than ``size``. This
            is only supported if ``size`` is an int (or a sequence of length
            1 in torchscript mode).
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True``: will apply antialiasing for bilinear or bicubic modes.
              Other modes aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL doesn't support no antialias.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The current default is ``None`` **but will change to** ``True`` **in
            v0.17** for the PIL and Tensor backends to be consistent.

    Returns:
        PIL Image or Tensor: Resized image. The input object itself is returned when it
        already has the computed output size.

    Raises:
        TypeError: If ``interpolation`` is neither an ``InterpolationMode`` nor an int.
        ValueError: If ``size`` is a sequence whose length is not 1 or 2, or if ``max_size``
            is given together with a ``size`` that is a sequence of length other than 1.
    """
    # NOTE(review): the signature default for ``antialias`` is the string "warn", not ``None``;
    # how "warn" is resolved is decided by ``_check_antialias``, which is defined outside this block.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(resize)

    # Accept integer interpolation codes by translating them to the enum first.
    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if isinstance(size, (list, tuple)):
        if len(size) not in [1, 2]:
            raise ValueError(
                f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
            )
        if max_size is not None and len(size) != 1:
            raise ValueError(
                "max_size should only be passed if size specifies the length of the smaller edge, "
                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
            )

    _, image_height, image_width = get_dimensions(img)
    # Normalize a bare int to a one-element list so the size helper sees a sequence.
    if isinstance(size, int):
        size = [size]
    output_size = _compute_resized_output_size((image_height, image_width), size, max_size)

    # Nothing to do: the image already has the requested size.
    if [image_height, image_width] == output_size:
        return img

    antialias = _check_antialias(img, antialias, interpolation)

    if not isinstance(img, torch.Tensor):
        # PIL path: the antialias flag cannot be honoured, so an explicit False only triggers a warning.
        if antialias is False:
            warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.resize(img, size=output_size, interpolation=pil_interpolation)

    # Tensor path: the backend takes the interpolation mode as its string value.
    return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias)
def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Tensor:
    r"""Pad the given image on all sides with the given "pad" value.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
    at most 3 leading dimensions for mode edge,
    and an arbitrary number of leading dimensions for mode constant

    Args:
        img (PIL Image or Tensor): Image to be padded.
        padding (int or sequence): Padding on each border. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        fill (number or tuple): Pixel fill value for constant fill. Default is 0.
            If a tuple of length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Returns:
        PIL Image or Tensor: Padded image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(pad)
    # Dispatch on input type: non-tensors go to the PIL backend, tensors to the tensor backend.
    if not isinstance(img, torch.Tensor):
        return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)

    return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
  424. def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
  425. """Crop the given image at specified location and output size.
  426. If the image is torch Tensor, it is expected
  427. to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
  428. If image size is smaller than output size along any edge, image is padded with 0 and then cropped.
  429. Args:
  430. img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
  431. top (int): Vertical component of the top left corner of the crop box.
  432. left (int): Horizontal component of the top left corner of the crop box.
  433. height (int): Height of the crop box.
  434. width (int): Width of the crop box.
  435. Returns:
  436. PIL Image or Tensor: Cropped image.
  437. """
  438. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  439. _log_api_usage_once(crop)
  440. if not isinstance(img, torch.Tensor):
  441. return F_pil.crop(img, top, left, height, width)
  442. return F_t.crop(img, top, left, height, width)
  443. def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
  444. """Crops the given image at the center.
  445. If the image is torch Tensor, it is expected
  446. to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
  447. If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
  448. Args:
  449. img (PIL Image or Tensor): Image to be cropped.
  450. output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
  451. it is used for both directions.
  452. Returns:
  453. PIL Image or Tensor: Cropped image.
  454. """
  455. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  456. _log_api_usage_once(center_crop)
  457. if isinstance(output_size, numbers.Number):
  458. output_size = (int(output_size), int(output_size))
  459. elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
  460. output_size = (output_size[0], output_size[0])
  461. _, image_height, image_width = get_dimensions(img)
  462. crop_height, crop_width = output_size
  463. if crop_width > image_width or crop_height > image_height:
  464. padding_ltrb = [
  465. (crop_width - image_width) // 2 if crop_width > image_width else 0,
  466. (crop_height - image_height) // 2 if crop_height > image_height else 0,
  467. (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
  468. (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
  469. ]
  470. img = pad(img, padding_ltrb, fill=0) # PIL uses fill value 0
  471. _, image_height, image_width = get_dimensions(img)
  472. if crop_width == image_width and crop_height == image_height:
  473. return img
  474. crop_top = int(round((image_height - crop_height) / 2.0))
  475. crop_left = int(round((image_width - crop_width) / 2.0))
  476. return crop(img, crop_top, crop_left, crop_height, crop_width)
  477. def resized_crop(
  478. img: Tensor,
  479. top: int,
  480. left: int,
  481. height: int,
  482. width: int,
  483. size: List[int],
  484. interpolation: InterpolationMode = InterpolationMode.BILINEAR,
  485. antialias: Optional[Union[str, bool]] = "warn",
  486. ) -> Tensor:
  487. """Crop the given image and resize it to desired size.
  488. If the image is torch Tensor, it is expected
  489. to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
  490. Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.
  491. Args:
  492. img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
  493. top (int): Vertical component of the top left corner of the crop box.
  494. left (int): Horizontal component of the top left corner of the crop box.
  495. height (int): Height of the crop box.
  496. width (int): Width of the crop box.
  497. size (sequence or int): Desired output size. Same semantics as ``resize``.
  498. interpolation (InterpolationMode): Desired interpolation enum defined by
  499. :class:`torchvision.transforms.InterpolationMode`.
  500. Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
  501. ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
  502. supported.
  503. The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
  504. antialias (bool, optional): Whether to apply antialiasing.
  505. It only affects **tensors** with bilinear or bicubic modes and it is
  506. ignored otherwise: on PIL images, antialiasing is always applied on
  507. bilinear or bicubic modes; on other modes (for PIL images and
  508. tensors), antialiasing makes no sense and this parameter is ignored.
  509. Possible values are:
  510. - ``True``: will apply antialiasing for bilinear or bicubic modes.
  511. Other mode aren't affected. This is probably what you want to use.
  512. - ``False``: will not apply antialiasing for tensors on any mode. PIL
  513. images are still antialiased on bilinear or bicubic modes, because
  514. PIL doesn't support no antialias.
  515. - ``None``: equivalent to ``False`` for tensors and ``True`` for
  516. PIL images. This value exists for legacy reasons and you probably
  517. don't want to use it unless you really know what you are doing.
  518. The current default is ``None`` **but will change to** ``True`` **in
  519. v0.17** for the PIL and Tensor backends to be consistent.
  520. Returns:
  521. PIL Image or Tensor: Cropped image.
  522. """
  523. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  524. _log_api_usage_once(resized_crop)
  525. img = crop(img, top, left, height, width)
  526. img = resize(img, size, interpolation, antialias=antialias)
  527. return img
  528. def hflip(img: Tensor) -> Tensor:
  529. """Horizontally flip the given image.
  530. Args:
  531. img (PIL Image or Tensor): Image to be flipped. If img
  532. is a Tensor, it is expected to be in [..., H, W] format,
  533. where ... means it can have an arbitrary number of leading
  534. dimensions.
  535. Returns:
  536. PIL Image or Tensor: Horizontally flipped image.
  537. """
  538. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  539. _log_api_usage_once(hflip)
  540. if not isinstance(img, torch.Tensor):
  541. return F_pil.hflip(img)
  542. return F_t.hflip(img)
  543. def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[int]]) -> List[float]:
  544. """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
  545. In Perspective Transform each pixel (x, y) in the original image gets transformed as,
  546. (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )
  547. Args:
  548. startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
  549. ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
  550. endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
  551. ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
  552. Returns:
  553. octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
  554. """
  555. a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float)
  556. for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
  557. a_matrix[2 * i, :] = torch.tensor([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
  558. a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
  559. b_matrix = torch.tensor(startpoints, dtype=torch.float).view(8)
  560. res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution
  561. output: List[float] = res.tolist()
  562. return output
  563. def perspective(
  564. img: Tensor,
  565. startpoints: List[List[int]],
  566. endpoints: List[List[int]],
  567. interpolation: InterpolationMode = InterpolationMode.BILINEAR,
  568. fill: Optional[List[float]] = None,
  569. ) -> Tensor:
  570. """Perform perspective transform of the given image.
  571. If the image is torch Tensor, it is expected
  572. to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
  573. Args:
  574. img (PIL Image or Tensor): Image to be transformed.
  575. startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
  576. ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
  577. endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
  578. ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
  579. interpolation (InterpolationMode): Desired interpolation enum defined by
  580. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
  581. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
  582. The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
  583. fill (sequence or number, optional): Pixel fill value for the area outside the transformed
  584. image. If given a number, the value is used for all bands respectively.
  585. .. note::
  586. In torchscript mode single int/float value is not supported, please use a sequence
  587. of length 1: ``[value, ]``.
  588. Returns:
  589. PIL Image or Tensor: transformed Image.
  590. """
  591. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  592. _log_api_usage_once(perspective)
  593. coeffs = _get_perspective_coeffs(startpoints, endpoints)
  594. if isinstance(interpolation, int):
  595. interpolation = _interpolation_modes_from_int(interpolation)
  596. elif not isinstance(interpolation, InterpolationMode):
  597. raise TypeError(
  598. "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
  599. )
  600. if not isinstance(img, torch.Tensor):
  601. pil_interpolation = pil_modes_mapping[interpolation]
  602. return F_pil.perspective(img, coeffs, interpolation=pil_interpolation, fill=fill)
  603. return F_t.perspective(img, coeffs, interpolation=interpolation.value, fill=fill)
  604. def vflip(img: Tensor) -> Tensor:
  605. """Vertically flip the given image.
  606. Args:
  607. img (PIL Image or Tensor): Image to be flipped. If img
  608. is a Tensor, it is expected to be in [..., H, W] format,
  609. where ... means it can have an arbitrary number of leading
  610. dimensions.
  611. Returns:
  612. PIL Image or Tensor: Vertically flipped image.
  613. """
  614. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  615. _log_api_usage_once(vflip)
  616. if not isinstance(img, torch.Tensor):
  617. return F_pil.vflip(img)
  618. return F_t.vflip(img)
  619. def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
  620. """Crop the given image into four corners and the central crop.
  621. If the image is torch Tensor, it is expected
  622. to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
  623. .. Note::
  624. This transform returns a tuple of images and there may be a
  625. mismatch in the number of inputs and targets your ``Dataset`` returns.
  626. Args:
  627. img (PIL Image or Tensor): Image to be cropped.
  628. size (sequence or int): Desired output size of the crop. If size is an
  629. int instead of sequence like (h, w), a square crop (size, size) is
  630. made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
  631. Returns:
  632. tuple: tuple (tl, tr, bl, br, center)
  633. Corresponding top left, top right, bottom left, bottom right and center crop.
  634. """
  635. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  636. _log_api_usage_once(five_crop)
  637. if isinstance(size, numbers.Number):
  638. size = (int(size), int(size))
  639. elif isinstance(size, (tuple, list)) and len(size) == 1:
  640. size = (size[0], size[0])
  641. if len(size) != 2:
  642. raise ValueError("Please provide only two dimensions (h, w) for size.")
  643. _, image_height, image_width = get_dimensions(img)
  644. crop_height, crop_width = size
  645. if crop_width > image_width or crop_height > image_height:
  646. msg = "Requested crop size {} is bigger than input size {}"
  647. raise ValueError(msg.format(size, (image_height, image_width)))
  648. tl = crop(img, 0, 0, crop_height, crop_width)
  649. tr = crop(img, 0, image_width - crop_width, crop_height, crop_width)
  650. bl = crop(img, image_height - crop_height, 0, crop_height, crop_width)
  651. br = crop(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
  652. center = center_crop(img, [crop_height, crop_width])
  653. return tl, tr, bl, br, center
  654. def ten_crop(
  655. img: Tensor, size: List[int], vertical_flip: bool = False
  656. ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
  657. """Generate ten cropped images from the given image.
  658. Crop the given image into four corners and the central crop plus the
  659. flipped version of these (horizontal flipping is used by default).
  660. If the image is torch Tensor, it is expected
  661. to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
  662. .. Note::
  663. This transform returns a tuple of images and there may be a
  664. mismatch in the number of inputs and targets your ``Dataset`` returns.
  665. Args:
  666. img (PIL Image or Tensor): Image to be cropped.
  667. size (sequence or int): Desired output size of the crop. If size is an
  668. int instead of sequence like (h, w), a square crop (size, size) is
  669. made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
  670. vertical_flip (bool): Use vertical flipping instead of horizontal
  671. Returns:
  672. tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
  673. Corresponding top left, top right, bottom left, bottom right and
  674. center crop and same for the flipped image.
  675. """
  676. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  677. _log_api_usage_once(ten_crop)
  678. if isinstance(size, numbers.Number):
  679. size = (int(size), int(size))
  680. elif isinstance(size, (tuple, list)) and len(size) == 1:
  681. size = (size[0], size[0])
  682. if len(size) != 2:
  683. raise ValueError("Please provide only two dimensions (h, w) for size.")
  684. first_five = five_crop(img, size)
  685. if vertical_flip:
  686. img = vflip(img)
  687. else:
  688. img = hflip(img)
  689. second_five = five_crop(img, size)
  690. return first_five + second_five
  691. def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
  692. """Adjust brightness of an image.
  693. Args:
  694. img (PIL Image or Tensor): Image to be adjusted.
  695. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
  696. where ... means it can have an arbitrary number of leading dimensions.
  697. brightness_factor (float): How much to adjust the brightness. Can be
  698. any non-negative number. 0 gives a black image, 1 gives the
  699. original image while 2 increases the brightness by a factor of 2.
  700. Returns:
  701. PIL Image or Tensor: Brightness adjusted image.
  702. """
  703. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  704. _log_api_usage_once(adjust_brightness)
  705. if not isinstance(img, torch.Tensor):
  706. return F_pil.adjust_brightness(img, brightness_factor)
  707. return F_t.adjust_brightness(img, brightness_factor)
  708. def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
  709. """Adjust contrast of an image.
  710. Args:
  711. img (PIL Image or Tensor): Image to be adjusted.
  712. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
  713. where ... means it can have an arbitrary number of leading dimensions.
  714. contrast_factor (float): How much to adjust the contrast. Can be any
  715. non-negative number. 0 gives a solid gray image, 1 gives the
  716. original image while 2 increases the contrast by a factor of 2.
  717. Returns:
  718. PIL Image or Tensor: Contrast adjusted image.
  719. """
  720. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  721. _log_api_usage_once(adjust_contrast)
  722. if not isinstance(img, torch.Tensor):
  723. return F_pil.adjust_contrast(img, contrast_factor)
  724. return F_t.adjust_contrast(img, contrast_factor)
  725. def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
  726. """Adjust color saturation of an image.
  727. Args:
  728. img (PIL Image or Tensor): Image to be adjusted.
  729. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
  730. where ... means it can have an arbitrary number of leading dimensions.
  731. saturation_factor (float): How much to adjust the saturation. 0 will
  732. give a black and white image, 1 will give the original image while
  733. 2 will enhance the saturation by a factor of 2.
  734. Returns:
  735. PIL Image or Tensor: Saturation adjusted image.
  736. """
  737. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  738. _log_api_usage_once(adjust_saturation)
  739. if not isinstance(img, torch.Tensor):
  740. return F_pil.adjust_saturation(img, saturation_factor)
  741. return F_t.adjust_saturation(img, saturation_factor)
  742. def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
  743. """Adjust hue of an image.
  744. The image hue is adjusted by converting the image to HSV and
  745. cyclically shifting the intensities in the hue channel (H).
  746. The image is then converted back to original image mode.
  747. `hue_factor` is the amount of shift in H channel and must be in the
  748. interval `[-0.5, 0.5]`.
  749. See `Hue`_ for more details.
  750. .. _Hue: https://en.wikipedia.org/wiki/Hue
  751. Args:
  752. img (PIL Image or Tensor): Image to be adjusted.
  753. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
  754. where ... means it can have an arbitrary number of leading dimensions.
  755. If img is PIL Image mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
  756. Note: the pixel values of the input image has to be non-negative for conversion to HSV space;
  757. thus it does not work if you normalize your image to an interval with negative values,
  758. or use an interpolation that generates negative values before using this function.
  759. hue_factor (float): How much to shift the hue channel. Should be in
  760. [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
  761. HSV space in positive and negative direction respectively.
  762. 0 means no shift. Therefore, both -0.5 and 0.5 will give an image
  763. with complementary colors while 0 gives the original image.
  764. Returns:
  765. PIL Image or Tensor: Hue adjusted image.
  766. """
  767. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  768. _log_api_usage_once(adjust_hue)
  769. if not isinstance(img, torch.Tensor):
  770. return F_pil.adjust_hue(img, hue_factor)
  771. return F_t.adjust_hue(img, hue_factor)
  772. def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
  773. r"""Perform gamma correction on an image.
  774. Also known as Power Law Transform. Intensities in RGB mode are adjusted
  775. based on the following equation:
  776. .. math::
  777. I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}
  778. See `Gamma Correction`_ for more details.
  779. .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
  780. Args:
  781. img (PIL Image or Tensor): PIL Image to be adjusted.
  782. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
  783. where ... means it can have an arbitrary number of leading dimensions.
  784. If img is PIL Image, modes with transparency (alpha channel) are not supported.
  785. gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
  786. gamma larger than 1 make the shadows darker,
  787. while gamma smaller than 1 make dark regions lighter.
  788. gain (float): The constant multiplier.
  789. Returns:
  790. PIL Image or Tensor: Gamma correction adjusted image.
  791. """
  792. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  793. _log_api_usage_once(adjust_gamma)
  794. if not isinstance(img, torch.Tensor):
  795. return F_pil.adjust_gamma(img, gamma, gain)
  796. return F_t.adjust_gamma(img, gamma, gain)
  797. def _get_inverse_affine_matrix(
  798. center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
  799. ) -> List[float]:
  800. # Helper method to compute inverse matrix for affine transformation
  801. # Pillow requires inverse affine transformation matrix:
  802. # Affine matrix is : M = T * C * RotateScaleShear * C^-1
  803. #
  804. # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
  805. # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
  806. # RotateScaleShear is rotation with scale and shear matrix
  807. #
  808. # RotateScaleShear(a, s, (sx, sy)) =
  809. # = R(a) * S(s) * SHy(sy) * SHx(sx)
  810. # = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
  811. # [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
  812. # [ 0 , 0 , 1 ]
  813. # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
  814. # SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0]
  815. # [0, 1 ] [-tan(s), 1]
  816. #
  817. # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1
  818. rot = math.radians(angle)
  819. sx = math.radians(shear[0])
  820. sy = math.radians(shear[1])
  821. cx, cy = center
  822. tx, ty = translate
  823. # RSS without scaling
  824. a = math.cos(rot - sy) / math.cos(sy)
  825. b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
  826. c = math.sin(rot - sy) / math.cos(sy)
  827. d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)
  828. if inverted:
  829. # Inverted rotation matrix with scale and shear
  830. # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
  831. matrix = [d, -b, 0.0, -c, a, 0.0]
  832. matrix = [x / scale for x in matrix]
  833. # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
  834. matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
  835. matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
  836. # Apply center translation: C * RSS^-1 * C^-1 * T^-1
  837. matrix[2] += cx
  838. matrix[5] += cy
  839. else:
  840. matrix = [a, b, 0.0, c, d, 0.0]
  841. matrix = [x * scale for x in matrix]
  842. # Apply inverse of center translation: RSS * C^-1
  843. matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy)
  844. matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy)
  845. # Apply translation and center : T * C * RSS * C^-1
  846. matrix[2] += cx + tx
  847. matrix[5] += cy + ty
  848. return matrix
  849. def rotate(
  850. img: Tensor,
  851. angle: float,
  852. interpolation: InterpolationMode = InterpolationMode.NEAREST,
  853. expand: bool = False,
  854. center: Optional[List[int]] = None,
  855. fill: Optional[List[float]] = None,
  856. ) -> Tensor:
  857. """Rotate the image by angle.
  858. If the image is torch Tensor, it is expected
  859. to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
  860. Args:
  861. img (PIL Image or Tensor): image to be rotated.
  862. angle (number): rotation angle value in degrees, counter-clockwise.
  863. interpolation (InterpolationMode): Desired interpolation enum defined by
  864. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
  865. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
  866. The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
  867. expand (bool, optional): Optional expansion flag.
  868. If true, expands the output image to make it large enough to hold the entire rotated image.
  869. If false or omitted, make the output image the same size as the input image.
  870. Note that the expand flag assumes rotation around the center and no translation.
  871. center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
  872. Default is the center of the image.
  873. fill (sequence or number, optional): Pixel fill value for the area outside the transformed
  874. image. If given a number, the value is used for all bands respectively.
  875. .. note::
  876. In torchscript mode single int/float value is not supported, please use a sequence
  877. of length 1: ``[value, ]``.
  878. Returns:
  879. PIL Image or Tensor: Rotated image.
  880. .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
  881. """
  882. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  883. _log_api_usage_once(rotate)
  884. if isinstance(interpolation, int):
  885. interpolation = _interpolation_modes_from_int(interpolation)
  886. elif not isinstance(interpolation, InterpolationMode):
  887. raise TypeError(
  888. "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
  889. )
  890. if not isinstance(angle, (int, float)):
  891. raise TypeError("Argument angle should be int or float")
  892. if center is not None and not isinstance(center, (list, tuple)):
  893. raise TypeError("Argument center should be a sequence")
  894. if not isinstance(img, torch.Tensor):
  895. pil_interpolation = pil_modes_mapping[interpolation]
  896. return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill)
  897. center_f = [0.0, 0.0]
  898. if center is not None:
  899. _, height, width = get_dimensions(img)
  900. # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
  901. center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
  902. # due to current incoherence of rotation angle direction between affine and rotate implementations
  903. # we need to set -angle.
  904. matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
  905. return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill)
  906. def affine(
  907. img: Tensor,
  908. angle: float,
  909. translate: List[int],
  910. scale: float,
  911. shear: List[float],
  912. interpolation: InterpolationMode = InterpolationMode.NEAREST,
  913. fill: Optional[List[float]] = None,
  914. center: Optional[List[int]] = None,
  915. ) -> Tensor:
  916. """Apply affine transformation on the image keeping image center invariant.
  917. If the image is torch Tensor, it is expected
  918. to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
  919. Args:
  920. img (PIL Image or Tensor): image to transform.
  921. angle (number): rotation angle in degrees between -180 and 180, clockwise direction.
  922. translate (sequence of integers): horizontal and vertical translations (post-rotation translation)
  923. scale (float): overall scale
  924. shear (float or sequence): shear angle value in degrees between -180 to 180, clockwise direction.
  925. If a sequence is specified, the first value corresponds to a shear parallel to the x-axis, while
  926. the second value corresponds to a shear parallel to the y-axis.
  927. interpolation (InterpolationMode): Desired interpolation enum defined by
  928. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
  929. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
  930. The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
  931. fill (sequence or number, optional): Pixel fill value for the area outside the transformed
  932. image. If given a number, the value is used for all bands respectively.
  933. .. note::
  934. In torchscript mode single int/float value is not supported, please use a sequence
  935. of length 1: ``[value, ]``.
  936. center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
  937. Default is the center of the image.
  938. Returns:
  939. PIL Image or Tensor: Transformed image.
  940. """
  941. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  942. _log_api_usage_once(affine)
  943. if isinstance(interpolation, int):
  944. interpolation = _interpolation_modes_from_int(interpolation)
  945. elif not isinstance(interpolation, InterpolationMode):
  946. raise TypeError(
  947. "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
  948. )
  949. if not isinstance(angle, (int, float)):
  950. raise TypeError("Argument angle should be int or float")
  951. if not isinstance(translate, (list, tuple)):
  952. raise TypeError("Argument translate should be a sequence")
  953. if len(translate) != 2:
  954. raise ValueError("Argument translate should be a sequence of length 2")
  955. if scale <= 0.0:
  956. raise ValueError("Argument scale should be positive")
  957. if not isinstance(shear, (numbers.Number, (list, tuple))):
  958. raise TypeError("Shear should be either a single value or a sequence of two values")
  959. if isinstance(angle, int):
  960. angle = float(angle)
  961. if isinstance(translate, tuple):
  962. translate = list(translate)
  963. if isinstance(shear, numbers.Number):
  964. shear = [shear, 0.0]
  965. if isinstance(shear, tuple):
  966. shear = list(shear)
  967. if len(shear) == 1:
  968. shear = [shear[0], shear[0]]
  969. if len(shear) != 2:
  970. raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")
  971. if center is not None and not isinstance(center, (list, tuple)):
  972. raise TypeError("Argument center should be a sequence")
  973. _, height, width = get_dimensions(img)
  974. if not isinstance(img, torch.Tensor):
  975. # center = (width * 0.5 + 0.5, height * 0.5 + 0.5)
  976. # it is visually better to estimate the center without 0.5 offset
  977. # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
  978. if center is None:
  979. center = [width * 0.5, height * 0.5]
  980. matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
  981. pil_interpolation = pil_modes_mapping[interpolation]
  982. return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill)
  983. center_f = [0.0, 0.0]
  984. if center is not None:
  985. _, height, width = get_dimensions(img)
  986. # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
  987. center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
  988. translate_f = [1.0 * t for t in translate]
  989. matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
  990. return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)
# NOTE: to_grayscale() appears to be a stand-alone functional that is never called
# from the transform classes. It is presumably kept for backward compatibility
# (not verified).
  994. @torch.jit.unused
  995. def to_grayscale(img, num_output_channels=1):
  996. """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
  997. This transform does not support torch Tensor.
  998. Args:
  999. img (PIL Image): PIL Image to be converted to grayscale.
  1000. num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default is 1.
  1001. Returns:
  1002. PIL Image: Grayscale version of the image.
  1003. - if num_output_channels = 1 : returned image is single channel
  1004. - if num_output_channels = 3 : returned image is 3 channel with r = g = b
  1005. """
  1006. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1007. _log_api_usage_once(to_grayscale)
  1008. if isinstance(img, Image.Image):
  1009. return F_pil.to_grayscale(img, num_output_channels)
  1010. raise TypeError("Input should be PIL Image")
  1011. def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
  1012. """Convert RGB image to grayscale version of image.
  1013. If the image is torch Tensor, it is expected
  1014. to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions
  1015. Note:
  1016. Please, note that this method supports only RGB images as input. For inputs in other color spaces,
  1017. please, consider using meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.
  1018. Args:
  1019. img (PIL Image or Tensor): RGB Image to be converted to grayscale.
  1020. num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.
  1021. Returns:
  1022. PIL Image or Tensor: Grayscale version of the image.
  1023. - if num_output_channels = 1 : returned image is single channel
  1024. - if num_output_channels = 3 : returned image is 3 channel with r = g = b
  1025. """
  1026. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1027. _log_api_usage_once(rgb_to_grayscale)
  1028. if not isinstance(img, torch.Tensor):
  1029. return F_pil.to_grayscale(img, num_output_channels)
  1030. return F_t.rgb_to_grayscale(img, num_output_channels)
  1031. def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
  1032. """Erase the input Tensor Image with given value.
  1033. This transform does not support PIL Image.
  1034. Args:
  1035. img (Tensor Image): Tensor image of size (C, H, W) to be erased
  1036. i (int): i in (i,j) i.e coordinates of the upper left corner.
  1037. j (int): j in (i,j) i.e coordinates of the upper left corner.
  1038. h (int): Height of the erased region.
  1039. w (int): Width of the erased region.
  1040. v: Erasing value.
  1041. inplace(bool, optional): For in-place operations. By default, is set False.
  1042. Returns:
  1043. Tensor Image: Erased image.
  1044. """
  1045. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1046. _log_api_usage_once(erase)
  1047. if not isinstance(img, torch.Tensor):
  1048. raise TypeError(f"img should be Tensor Image. Got {type(img)}")
  1049. return F_t.erase(img, i, j, h, w, v, inplace=inplace)
  1050. def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:
  1051. """Performs Gaussian blurring on the image by given kernel.
  1052. If the image is torch Tensor, it is expected
  1053. to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
  1054. Args:
  1055. img (PIL Image or Tensor): Image to be blurred
  1056. kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of integers
  1057. like ``(kx, ky)`` or a single integer for square kernels.
  1058. .. note::
  1059. In torchscript mode kernel_size as single int is not supported, use a sequence of
  1060. length 1: ``[ksize, ]``.
  1061. sigma (sequence of floats or float, optional): Gaussian kernel standard deviation. Can be a
  1062. sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the
  1063. same sigma in both X/Y directions. If None, then it is computed using
  1064. ``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``.
  1065. Default, None.
  1066. .. note::
  1067. In torchscript mode sigma as single float is
  1068. not supported, use a sequence of length 1: ``[sigma, ]``.
  1069. Returns:
  1070. PIL Image or Tensor: Gaussian Blurred version of the image.
  1071. """
  1072. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1073. _log_api_usage_once(gaussian_blur)
  1074. if not isinstance(kernel_size, (int, list, tuple)):
  1075. raise TypeError(f"kernel_size should be int or a sequence of integers. Got {type(kernel_size)}")
  1076. if isinstance(kernel_size, int):
  1077. kernel_size = [kernel_size, kernel_size]
  1078. if len(kernel_size) != 2:
  1079. raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
  1080. for ksize in kernel_size:
  1081. if ksize % 2 == 0 or ksize < 0:
  1082. raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")
  1083. if sigma is None:
  1084. sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
  1085. if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
  1086. raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
  1087. if isinstance(sigma, (int, float)):
  1088. sigma = [float(sigma), float(sigma)]
  1089. if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
  1090. sigma = [sigma[0], sigma[0]]
  1091. if len(sigma) != 2:
  1092. raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}")
  1093. for s in sigma:
  1094. if s <= 0.0:
  1095. raise ValueError(f"sigma should have positive values. Got {sigma}")
  1096. t_img = img
  1097. if not isinstance(img, torch.Tensor):
  1098. if not F_pil._is_pil_image(img):
  1099. raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")
  1100. t_img = pil_to_tensor(img)
  1101. output = F_t.gaussian_blur(t_img, kernel_size, sigma)
  1102. if not isinstance(img, torch.Tensor):
  1103. output = to_pil_image(output, mode=img.mode)
  1104. return output
  1105. def invert(img: Tensor) -> Tensor:
  1106. """Invert the colors of an RGB/grayscale image.
  1107. Args:
  1108. img (PIL Image or Tensor): Image to have its colors inverted.
  1109. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
  1110. where ... means it can have an arbitrary number of leading dimensions.
  1111. If img is PIL Image, it is expected to be in mode "L" or "RGB".
  1112. Returns:
  1113. PIL Image or Tensor: Color inverted image.
  1114. """
  1115. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1116. _log_api_usage_once(invert)
  1117. if not isinstance(img, torch.Tensor):
  1118. return F_pil.invert(img)
  1119. return F_t.invert(img)
  1120. def posterize(img: Tensor, bits: int) -> Tensor:
  1121. """Posterize an image by reducing the number of bits for each color channel.
  1122. Args:
  1123. img (PIL Image or Tensor): Image to have its colors posterized.
  1124. If img is torch Tensor, it should be of type torch.uint8, and
  1125. it is expected to be in [..., 1 or 3, H, W] format, where ... means
  1126. it can have an arbitrary number of leading dimensions.
  1127. If img is PIL Image, it is expected to be in mode "L" or "RGB".
  1128. bits (int): The number of bits to keep for each channel (0-8).
  1129. Returns:
  1130. PIL Image or Tensor: Posterized image.
  1131. """
  1132. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1133. _log_api_usage_once(posterize)
  1134. if not (0 <= bits <= 8):
  1135. raise ValueError(f"The number if bits should be between 0 and 8. Got {bits}")
  1136. if not isinstance(img, torch.Tensor):
  1137. return F_pil.posterize(img, bits)
  1138. return F_t.posterize(img, bits)
  1139. def solarize(img: Tensor, threshold: float) -> Tensor:
  1140. """Solarize an RGB/grayscale image by inverting all pixel values above a threshold.
  1141. Args:
  1142. img (PIL Image or Tensor): Image to have its colors inverted.
  1143. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
  1144. where ... means it can have an arbitrary number of leading dimensions.
  1145. If img is PIL Image, it is expected to be in mode "L" or "RGB".
  1146. threshold (float): All pixels equal or above this value are inverted.
  1147. Returns:
  1148. PIL Image or Tensor: Solarized image.
  1149. """
  1150. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1151. _log_api_usage_once(solarize)
  1152. if not isinstance(img, torch.Tensor):
  1153. return F_pil.solarize(img, threshold)
  1154. return F_t.solarize(img, threshold)
  1155. def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
  1156. """Adjust the sharpness of an image.
  1157. Args:
  1158. img (PIL Image or Tensor): Image to be adjusted.
  1159. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
  1160. where ... means it can have an arbitrary number of leading dimensions.
  1161. sharpness_factor (float): How much to adjust the sharpness. Can be
  1162. any non-negative number. 0 gives a blurred image, 1 gives the
  1163. original image while 2 increases the sharpness by a factor of 2.
  1164. Returns:
  1165. PIL Image or Tensor: Sharpness adjusted image.
  1166. """
  1167. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1168. _log_api_usage_once(adjust_sharpness)
  1169. if not isinstance(img, torch.Tensor):
  1170. return F_pil.adjust_sharpness(img, sharpness_factor)
  1171. return F_t.adjust_sharpness(img, sharpness_factor)
  1172. def autocontrast(img: Tensor) -> Tensor:
  1173. """Maximize contrast of an image by remapping its
  1174. pixels per channel so that the lowest becomes black and the lightest
  1175. becomes white.
  1176. Args:
  1177. img (PIL Image or Tensor): Image on which autocontrast is applied.
  1178. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
  1179. where ... means it can have an arbitrary number of leading dimensions.
  1180. If img is PIL Image, it is expected to be in mode "L" or "RGB".
  1181. Returns:
  1182. PIL Image or Tensor: An image that was autocontrasted.
  1183. """
  1184. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1185. _log_api_usage_once(autocontrast)
  1186. if not isinstance(img, torch.Tensor):
  1187. return F_pil.autocontrast(img)
  1188. return F_t.autocontrast(img)
  1189. def equalize(img: Tensor) -> Tensor:
  1190. """Equalize the histogram of an image by applying
  1191. a non-linear mapping to the input in order to create a uniform
  1192. distribution of grayscale values in the output.
  1193. Args:
  1194. img (PIL Image or Tensor): Image on which equalize is applied.
  1195. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
  1196. where ... means it can have an arbitrary number of leading dimensions.
  1197. The tensor dtype must be ``torch.uint8`` and values are expected to be in ``[0, 255]``.
  1198. If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
  1199. Returns:
  1200. PIL Image or Tensor: An image that was equalized.
  1201. """
  1202. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1203. _log_api_usage_once(equalize)
  1204. if not isinstance(img, torch.Tensor):
  1205. return F_pil.equalize(img)
  1206. return F_t.equalize(img)
  1207. def elastic_transform(
  1208. img: Tensor,
  1209. displacement: Tensor,
  1210. interpolation: InterpolationMode = InterpolationMode.BILINEAR,
  1211. fill: Optional[List[float]] = None,
  1212. ) -> Tensor:
  1213. """Transform a tensor image with elastic transformations.
  1214. Given alpha and sigma, it will generate displacement
  1215. vectors for all pixels based on random offsets. Alpha controls the strength
  1216. and sigma controls the smoothness of the displacements.
  1217. The displacements are added to an identity grid and the resulting grid is
  1218. used to grid_sample from the image.
  1219. Applications:
  1220. Randomly transforms the morphology of objects in images and produces a
  1221. see-through-water-like effect.
  1222. Args:
  1223. img (PIL Image or Tensor): Image on which elastic_transform is applied.
  1224. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
  1225. where ... means it can have an arbitrary number of leading dimensions.
  1226. If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
  1227. displacement (Tensor): The displacement field. Expected shape is [1, H, W, 2].
  1228. interpolation (InterpolationMode): Desired interpolation enum defined by
  1229. :class:`torchvision.transforms.InterpolationMode`.
  1230. Default is ``InterpolationMode.BILINEAR``.
  1231. The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
  1232. fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
  1233. If a tuple of length 3, it is used to fill R, G, B channels respectively.
  1234. This value is only used when the padding_mode is constant.
  1235. """
  1236. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1237. _log_api_usage_once(elastic_transform)
  1238. # Backward compatibility with integer value
  1239. if isinstance(interpolation, int):
  1240. warnings.warn(
  1241. "Argument interpolation should be of type InterpolationMode instead of int. "
  1242. "Please, use InterpolationMode enum."
  1243. )
  1244. interpolation = _interpolation_modes_from_int(interpolation)
  1245. if not isinstance(displacement, torch.Tensor):
  1246. raise TypeError("Argument displacement should be a Tensor")
  1247. t_img = img
  1248. if not isinstance(img, torch.Tensor):
  1249. if not F_pil._is_pil_image(img):
  1250. raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")
  1251. t_img = pil_to_tensor(img)
  1252. shape = t_img.shape
  1253. shape = (1,) + shape[-2:] + (2,)
  1254. if shape != displacement.shape:
  1255. raise ValueError(f"Argument displacement shape should be {shape}, but given {displacement.shape}")
  1256. # TODO: if image shape is [N1, N2, ..., C, H, W] and
  1257. # displacement is [1, H, W, 2] we need to reshape input image
  1258. # such grid_sampler takes internal code for 4D input
  1259. output = F_t.elastic_transform(
  1260. t_img,
  1261. displacement,
  1262. interpolation=interpolation.value,
  1263. fill=fill,
  1264. )
  1265. if not isinstance(img, torch.Tensor):
  1266. output = to_pil_image(output, mode=img.mode)
  1267. return output
  1268. # TODO in v0.17: remove this helper and change default of antialias to True everywhere
  1269. def _check_antialias(
  1270. img: Tensor, antialias: Optional[Union[str, bool]], interpolation: InterpolationMode
  1271. ) -> Optional[bool]:
  1272. if isinstance(antialias, str): # it should be "warn", but we don't bother checking against that
  1273. if isinstance(img, Tensor) and (
  1274. interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC
  1275. ):
  1276. warnings.warn(
  1277. "The default value of the antialias parameter of all the resizing transforms "
  1278. "(Resize(), RandomResizedCrop(), etc.) "
  1279. "will change from None to True in v0.17, "
  1280. "in order to be consistent across the PIL and Tensor backends. "
  1281. "To suppress this warning, directly pass "
  1282. "antialias=True (recommended, future default), antialias=None (current default, "
  1283. "which means False for Tensors and True for PIL), "
  1284. "or antialias=False (only works on Tensors - PIL will still use antialiasing). "
  1285. "This also applies if you are using the inference transforms from the models weights: "
  1286. "update the call to weights.transforms(antialias=True)."
  1287. )
  1288. antialias = None
  1289. return antialias