# _geometry.py

import math
import numbers
import warnings
from typing import Any, List, Optional, Sequence, Tuple, Union

import PIL.Image
import torch
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad

from torchvision import tv_tensors
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _pad_symmetric
from torchvision.transforms.functional import (
    _check_antialias,
    _compute_resized_output_size as __compute_resized_output_size,
    _get_perspective_coeffs,
    _interpolation_modes_from_int,
    InterpolationMode,
    pil_modes_mapping,
    pil_to_tensor,
    to_pil_image,
)
from torchvision.utils import _log_api_usage_once

from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format
from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal


def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise ValueError(
            f"Argument interpolation should be an `InterpolationMode` or a corresponding Pillow integer constant, "
            f"but got {interpolation}."
        )
    return interpolation


def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomHorizontalFlip` for details."""
    if torch.jit.is_scripting():
        return horizontal_flip_image(inpt)

    _log_api_usage_once(horizontal_flip)

    kernel = _get_kernel(horizontal_flip, type(inpt))
    return kernel(inpt)


@_register_kernel_internal(horizontal_flip, torch.Tensor)
@_register_kernel_internal(horizontal_flip, tv_tensors.Image)
def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor:
    return image.flip(-1)


@_register_kernel_internal(horizontal_flip, PIL.Image.Image)
def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    return _FP.hflip(image)


@_register_kernel_internal(horizontal_flip, tv_tensors.Mask)
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    return horizontal_flip_image(mask)


def horizontal_flip_bounding_boxes(
    bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
    shape = bounding_boxes.shape

    bounding_boxes = bounding_boxes.clone().reshape(-1, 4)

    if format == tv_tensors.BoundingBoxFormat.XYXY:
        bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(canvas_size[1]).neg_()
    elif format == tv_tensors.BoundingBoxFormat.XYWH:
        bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(canvas_size[1]).neg_()
    else:  # format == tv_tensors.BoundingBoxFormat.CXCYWH:
        bounding_boxes[:, 0].sub_(canvas_size[1]).neg_()

    return bounding_boxes.reshape(shape)
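

# Illustrative sketch (editor's addition, not part of the original module): for XYXY
# boxes on a canvas of (height=10, width=20), the flip maps x -> width - x, so the
# box [2, 3, 8, 9] becomes [12, 3, 18, 9]:
#
#   boxes = torch.tensor([[2.0, 3.0, 8.0, 9.0]])
#   horizontal_flip_bounding_boxes(
#       boxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=(10, 20)
#   )  # -> tensor([[12., 3., 18., 9.]])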
@_register_kernel_internal(horizontal_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _horizontal_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes:
    output = horizontal_flip_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
    )
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(horizontal_flip, tv_tensors.Video)
def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
    return horizontal_flip_image(video)


def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomVerticalFlip` for details."""
    if torch.jit.is_scripting():
        return vertical_flip_image(inpt)

    _log_api_usage_once(vertical_flip)

    kernel = _get_kernel(vertical_flip, type(inpt))
    return kernel(inpt)


@_register_kernel_internal(vertical_flip, torch.Tensor)
@_register_kernel_internal(vertical_flip, tv_tensors.Image)
def vertical_flip_image(image: torch.Tensor) -> torch.Tensor:
    return image.flip(-2)


@_register_kernel_internal(vertical_flip, PIL.Image.Image)
def _vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    return _FP.vflip(image)


@_register_kernel_internal(vertical_flip, tv_tensors.Mask)
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    return vertical_flip_image(mask)


def vertical_flip_bounding_boxes(
    bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
    shape = bounding_boxes.shape

    bounding_boxes = bounding_boxes.clone().reshape(-1, 4)

    if format == tv_tensors.BoundingBoxFormat.XYXY:
        bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(canvas_size[0]).neg_()
    elif format == tv_tensors.BoundingBoxFormat.XYWH:
        bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(canvas_size[0]).neg_()
    else:  # format == tv_tensors.BoundingBoxFormat.CXCYWH:
        bounding_boxes[:, 1].sub_(canvas_size[0]).neg_()

    return bounding_boxes.reshape(shape)


@_register_kernel_internal(vertical_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _vertical_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes:
    output = vertical_flip_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
    )
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(vertical_flip, tv_tensors.Video)
def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
    return vertical_flip_image(video)


# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
# prevalent and well understood. Thus, we just alias them without deprecating the old names.
hflip = horizontal_flip
vflip = vertical_flip
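

# Usage sketch (editor's addition, not part of the original file): the dispatchers
# above route on the input type, so the same call works for plain tensors,
# tv_tensors, and PIL images:
#
#   img = tv_tensors.Image(torch.rand(3, 32, 32))
#   flipped = horizontal_flip(img)                        # -> horizontal_flip_image
#   flipped_pil = hflip(PIL.Image.new("RGB", (32, 32)))   # -> the PIL kernel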
def _compute_resized_output_size(
    canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> List[int]:
    if isinstance(size, int):
        size = [size]
    elif max_size is not None and len(size) != 1:
        raise ValueError(
            "max_size should only be passed if size specifies the length of the smaller edge, "
            "i.e. size should be an int or a sequence of length 1 in torchscript mode."
        )
    return __compute_resized_output_size(canvas_size, size=size, max_size=max_size)


def resize(
    inpt: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.Resize` for details."""
    if torch.jit.is_scripting():
        return resize_image(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)

    _log_api_usage_once(resize)

    kernel = _get_kernel(resize, type(inpt))
    return kernel(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)


@_register_kernel_internal(resize, torch.Tensor)
@_register_kernel_internal(resize, tv_tensors.Image)
def resize_image(
    image: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    interpolation = _check_interpolation(interpolation)
    antialias = _check_antialias(img=image, antialias=antialias, interpolation=interpolation)
    assert not isinstance(antialias, str)
    antialias = False if antialias is None else antialias
    align_corners: Optional[bool] = None
    if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC:
        align_corners = False
    else:
        # The default of antialias should be True from 0.17, so we don't warn or
        # error if other interpolation modes are used. This is documented.
        antialias = False

    shape = image.shape
    numel = image.numel()
    num_channels, old_height, old_width = shape[-3:]
    new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        return image
    elif numel > 0:
        image = image.reshape(-1, num_channels, old_height, old_width)

        dtype = image.dtype
        acceptable_dtypes = [torch.float32, torch.float64]
        if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
            # uint8 dtype can be included for cpu and cuda input if nearest mode
            acceptable_dtypes.append(torch.uint8)
        elif image.device.type == "cpu":
            # uint8 dtype support for bilinear and bicubic is limited to cpu and
            # according to our benchmarks, non-AVX CPUs should still prefer u8->f32->interpolate->u8 path for bilinear
            if (interpolation == InterpolationMode.BILINEAR and "AVX2" in torch.backends.cpu.get_cpu_capability()) or (
                interpolation == InterpolationMode.BICUBIC
            ):
                acceptable_dtypes.append(torch.uint8)

        strides = image.stride()
        if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
            # There is a weird behaviour in torch core where the output tensor of `interpolate()` can be allocated as
            # contiguous even though the input is un-ambiguously channels_last
            # (https://github.com/pytorch/pytorch/issues/68430). In particular this happens for the typical
            # torchvision use-case of single CHW images where we fake the batch dim to become 1CHW. Below, we restride
            # those tensors to trick torch core into properly allocating the output as channels_last, thus preserving
            # the memory format of the input. This is not just for format consistency: for uint8 bilinear images, this
            # also avoids an extra copy (re-packing) of the output and saves time.
            # TODO: when https://github.com/pytorch/pytorch/issues/68430 is fixed (possibly by
            # https://github.com/pytorch/pytorch/pull/100373), we should be able to remove this hack.
            new_strides = list(strides)
            new_strides[0] = numel
            image = image.as_strided((1, num_channels, old_height, old_width), new_strides)

        need_cast = dtype not in acceptable_dtypes
        if need_cast:
            image = image.to(dtype=torch.float32)

        image = interpolate(
            image,
            size=[new_height, new_width],
            mode=interpolation.value,
            align_corners=align_corners,
            antialias=antialias,
        )

        if need_cast:
            if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
                # This path is hit on non-AVX archs, or on GPU.
                image = image.clamp_(min=0, max=255)
            if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
                image = image.round_()
            image = image.to(dtype=dtype)

    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
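

# Sketch (editor's addition, not part of the original file): a single-value `size`
# rescales the smaller edge and keeps the aspect ratio, so a 3x480x640 image resized
# with size=[240] comes out as 3x240x320:
#
#   out = resize_image(torch.rand(3, 480, 640), size=[240], antialias=True)
#   out.shape  # torch.Size([3, 240, 320])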
def _resize_image_pil(
    image: PIL.Image.Image,
    size: Union[Sequence[int], int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
) -> PIL.Image.Image:
    old_height, old_width = image.height, image.width
    new_height, new_width = _compute_resized_output_size(
        (old_height, old_width),
        size=size,  # type: ignore[arg-type]
        max_size=max_size,
    )

    interpolation = _check_interpolation(interpolation)

    if (new_height, new_width) == (old_height, old_width):
        return image

    return image.resize((new_width, new_height), resample=pil_modes_mapping[interpolation])


@_register_kernel_internal(resize, PIL.Image.Image)
def __resize_image_pil_dispatch(
    image: PIL.Image.Image,
    size: Union[Sequence[int], int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> PIL.Image.Image:
    if antialias is False:
        warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
    return _resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size)


def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = resize_image(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(resize, tv_tensors.Mask, tv_tensor_wrapper=False)
def _resize_mask_dispatch(
    inpt: tv_tensors.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> tv_tensors.Mask:
    output = resize_mask(inpt.as_subclass(torch.Tensor), size, max_size=max_size)
    return tv_tensors.wrap(output, like=inpt)


def resize_bounding_boxes(
    bounding_boxes: torch.Tensor, canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    old_height, old_width = canvas_size
    new_height, new_width = _compute_resized_output_size(canvas_size, size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        return bounding_boxes, canvas_size

    w_ratio = new_width / old_width
    h_ratio = new_height / old_height
    ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_boxes.device)
    return (
        bounding_boxes.mul(ratios).to(bounding_boxes.dtype),
        (new_height, new_width),
    )


@_register_kernel_internal(resize, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _resize_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = resize_bounding_boxes(
        inpt.as_subclass(torch.Tensor), inpt.canvas_size, size, max_size=max_size
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


@_register_kernel_internal(resize, tv_tensors.Video)
def resize_video(
    video: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    return resize_image(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
def affine(
    inpt: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomAffine` for details."""
    if torch.jit.is_scripting():
        return affine_image(
            inpt,
            angle=angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )

    _log_api_usage_once(affine)

    kernel = _get_kernel(affine, type(inpt))
    return kernel(
        inpt,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=interpolation,
        fill=fill,
        center=center,
    )


def _affine_parse_args(
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    center: Optional[List[float]] = None,
) -> Tuple[float, List[float], List[float], Optional[List[float]]]:
    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")

    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be a InterpolationMode")

    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    if isinstance(shear, numbers.Number):
        shear = [shear, 0.0]

    if isinstance(shear, tuple):
        shear = list(shear)

    if len(shear) == 1:
        shear = [shear[0], shear[0]]

    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    if center is not None:
        if not isinstance(center, (list, tuple)):
            raise TypeError("Argument center should be a sequence")
        else:
            center = [float(c) for c in center]

    return angle, translate, shear, center
def _get_inverse_affine_matrix(
    center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
) -> List[float]:
    # Helper method to compute inverse matrix for affine transformation
    # Pillow requires inverse affine transformation matrix:
    # Affine matrix is : M = T * C * RotateScaleShear * C^-1
    #
    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    # RotateScaleShear is rotation with scale and shear matrix
    #
    # RotateScaleShear(a, s, (sx, sy)) =
    # = R(a) * S(s) * SHy(sy) * SHx(sx)
    # = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
    #   [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
    #   [ 0                    , 0                                        , 1 ]
    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
    # SHx(s) = [1, -tan(s)]     and     SHy(s) = [1      , 0]
    #          [0, 1      ]                      [-tan(s), 1]
    #
    # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1

    rot = math.radians(angle)
    sx = math.radians(shear[0])
    sy = math.radians(shear[1])

    cx, cy = center
    tx, ty = translate

    # Cached results
    cos_sy = math.cos(sy)
    tan_sx = math.tan(sx)
    rot_minus_sy = rot - sy
    cx_plus_tx = cx + tx
    cy_plus_ty = cy + ty

    # Rotate Scale Shear (RSS) without scaling
    a = math.cos(rot_minus_sy) / cos_sy
    b = -(a * tan_sx + math.sin(rot))
    c = math.sin(rot_minus_sy) / cos_sy
    d = math.cos(rot) - c * tan_sx

    if inverted:
        # Inverted rotation matrix with scale and shear
        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
        matrix = [d / scale, -b / scale, 0.0, -c / scale, a / scale, 0.0]
        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
        # and then apply center translation: C * RSS^-1 * C^-1 * T^-1
        matrix[2] += cx - matrix[0] * cx_plus_tx - matrix[1] * cy_plus_ty
        matrix[5] += cy - matrix[3] * cx_plus_tx - matrix[4] * cy_plus_ty
    else:
        matrix = [a * scale, b * scale, 0.0, c * scale, d * scale, 0.0]
        # Apply inverse of center translation: RSS * C^-1
        # and then apply translation and center : T * C * RSS * C^-1
        matrix[2] += cx_plus_tx - matrix[0] * cx - matrix[1] * cy
        matrix[5] += cy_plus_ty - matrix[3] * cx - matrix[4] * cy

    return matrix
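

# Worked example (editor's addition, not part of the original file): for a pure
# 90-degree rotation around the origin (no translation, scale=1, no shear), the
# inverted matrix is approximately [cos 90, sin 90, 0, -sin 90, cos 90, 0], i.e. the
# inverse rotation in flat (a, b, tx, c, d, ty) form:
#
#   _get_inverse_affine_matrix([0.0, 0.0], 90.0, [0.0, 0.0], 1.0, [0.0, 0.0])
#   # ~ [0.0, 1.0, 0.0, -1.0, 0.0, 0.0]  (the first and fifth entries are ~6e-17)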
def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
    # Inspired by the PIL implementation:
    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054

    # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
    # Points are shifted due to affine matrix torch convention about
    # the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
    half_w = 0.5 * w
    half_h = 0.5 * h
    pts = torch.tensor(
        [
            [-half_w, -half_h, 1.0],
            [-half_w, half_h, 1.0],
            [half_w, half_h, 1.0],
            [half_w, -half_h, 1.0],
        ]
    )
    theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
    new_pts = torch.matmul(pts, theta.T)
    min_vals, max_vals = new_pts.aminmax(dim=0)

    # shift points to [0, w] and [0, h] interval to match PIL results
    halfs = torch.tensor((half_w, half_h))
    min_vals.add_(halfs)
    max_vals.add_(halfs)

    # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
    tol = 1e-4
    inv_tol = 1.0 / tol
    cmax = max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_()
    cmin = min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_()
    size = cmax.sub_(cmin)
    return int(size[0]), int(size[1])  # w, h
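

# Sketch (editor's addition, not part of the original file): plugging in the matrix of
# a 90-degree rotation about the center, [0., 1., 0., -1., 0., 0.], simply swaps the
# two sides, e.g. a 100x60 (w x h) canvas yields an expanded output of 60x100:
#
#   _compute_affine_output_size([0.0, 1.0, 0.0, -1.0, 0.0, 0.0], 100, 60)  # -> (60, 100)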
def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill: _FillTypeJIT) -> torch.Tensor:
    # We are using context knowledge that grid should have float dtype
    fp = img.dtype == grid.dtype
    float_img = img if fp else img.to(grid.dtype)

    shape = float_img.shape
    if shape[0] > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(shape[0], -1, -1, -1)

    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
    if fill is not None:
        mask = torch.ones((shape[0], 1, shape[2], shape[3]), dtype=float_img.dtype, device=float_img.device)
        float_img = torch.cat((float_img, mask), dim=1)

    float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    # Fill with required color
    if fill is not None:
        float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
        mask = mask.expand_as(float_img)
        fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]  # type: ignore[arg-type]
        fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
        if mode == "nearest":
            bool_mask = mask < 0.5
            float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
        else:  # 'bilinear'
            # The following is mathematically equivalent to:
            # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
            float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)

    img = float_img.round_().to(img.dtype) if not fp else float_img

    return img


def _assert_grid_transform_inputs(
    image: torch.Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: _FillTypeJIT,
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:
    if matrix is not None:
        if not isinstance(matrix, list):
            raise TypeError("Argument matrix should be a list")
        elif len(matrix) != 6:
            raise ValueError("Argument matrix should have 6 float values")

    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None:
        if isinstance(fill, (tuple, list)):
            length = len(fill)
            num_channels = image.shape[-3]
            if length > 1 and length != num_channels:
                raise ValueError(
                    "The number of elements in 'fill' cannot broadcast to match the number of "
                    f"channels of the image ({length} != {num_channels})"
                )
        elif not isinstance(fill, (int, float)):
            raise ValueError("Argument fill should be either int, float, tuple or list")

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")


def _affine_grid(
    theta: torch.Tensor,
    w: int,
    h: int,
    ow: int,
    oh: int,
) -> torch.Tensor:
    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
    # AffineGridGenerator.cpp#L18
    # Difference with AffineGridGenerator is that:
    # 1) we normalize grid values after applying theta
    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
    dtype = theta.dtype
    device = theta.device

    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace((1.0 - ow) * 0.5, (ow - 1.0) * 0.5, steps=ow, device=device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace((1.0 - oh) * 0.5, (oh - 1.0) * 0.5, steps=oh, device=device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta = theta.transpose(1, 2).div_(torch.tensor([0.5 * w, 0.5 * h], dtype=dtype, device=device))
    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
    return output_grid.view(1, oh, ow, 2)
@_register_kernel_internal(affine, torch.Tensor)
@_register_kernel_internal(affine, tv_tensors.Image)
def affine_image(
    image: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    shape = image.shape
    ndim = image.ndim

    if ndim > 4:
        image = image.reshape((-1,) + shape[-3:])
        needs_unsquash = True
    elif ndim == 3:
        image = image.unsqueeze(0)
        needs_unsquash = True
    else:
        needs_unsquash = False

    height, width = shape[-2:]
    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    center_f = [0.0, 0.0]
    if center is not None:
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]

    translate_f = [float(t) for t in translate]
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

    _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
    grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    if needs_unsquash:
        output = output.reshape(shape)

    return output


@_register_kernel_internal(affine, PIL.Image.Image)
def _affine_image_pil(
    image: PIL.Image.Image,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> PIL.Image.Image:
    interpolation = _check_interpolation(interpolation)
    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
    # it is visually better to estimate the center without 0.5 offset
    # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
    if center is None:
        height, width = _get_size_image_pil(image)
        center = [width * 0.5, height * 0.5]

    matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)

    return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)
def _affine_bounding_boxes_with_expand(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
    expand: bool = False,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    if bounding_boxes.numel() == 0:
        return bounding_boxes, canvas_size

    original_shape = bounding_boxes.shape
    original_dtype = bounding_boxes.dtype
    bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
    dtype = bounding_boxes.dtype
    device = bounding_boxes.device
    bounding_boxes = (
        convert_bounding_box_format(
            bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True
        )
    ).reshape(-1, 4)

    angle, translate, shear, center = _affine_parse_args(
        angle, translate, scale, shear, InterpolationMode.NEAREST, center
    )

    if center is None:
        height, width = canvas_size
        center = [width * 0.5, height * 0.5]

    affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
    transposed_affine_matrix = (
        torch.tensor(
            affine_vector,
            dtype=dtype,
            device=device,
        )
        .reshape(2, 3)
        .T
    )
    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
    # Single point structure is similar to
    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1)
    # 2) Now let's transform the points using affine matrix
    transformed_points = torch.matmul(points, transposed_affine_matrix)
    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
    # and compute bounding box from 4 transformed points:
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)

    if expand:
        # Compute minimum point for transformed image frame:
        # Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
        height, width = canvas_size
        points = torch.tensor(
            [
                [0.0, 0.0, 1.0],
                [0.0, float(height), 1.0],
                [float(width), float(height), 1.0],
                [float(width), 0.0, 1.0],
            ],
            dtype=dtype,
            device=device,
        )
        new_points = torch.matmul(points, transposed_affine_matrix)
        tr = torch.amin(new_points, dim=0, keepdim=True)
        # Translate bounding boxes
        out_bboxes.sub_(tr.repeat((1, 2)))
        # Estimate meta-data for image with inverted=True
        affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
        new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
        canvas_size = (new_height, new_width)

    out_bboxes = clamp_bounding_boxes(out_bboxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size)
    out_bboxes = convert_bounding_box_format(
        out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)

    out_bboxes = out_bboxes.to(original_dtype)
    return out_bboxes, canvas_size
def affine_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    out_box, _ = _affine_bounding_boxes_with_expand(
        bounding_boxes,
        format=format,
        canvas_size=canvas_size,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        center=center,
        expand=False,
    )
    return out_box


@_register_kernel_internal(affine, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _affine_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.BoundingBoxes:
    output = affine_bounding_boxes(
        inpt.as_subclass(torch.Tensor),
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        center=center,
    )
    return tv_tensors.wrap(output, like=inpt)


def affine_mask(
    mask: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = affine_image(
        mask,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(affine, tv_tensors.Mask, tv_tensor_wrapper=False)
def _affine_mask_dispatch(
    inpt: tv_tensors.Mask,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.Mask:
    output = affine_mask(
        inpt.as_subclass(torch.Tensor),
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        fill=fill,
        center=center,
    )
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(affine, tv_tensors.Video)
def affine_video(
    video: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    return affine_image(
        video,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=interpolation,
        fill=fill,
        center=center,
    )
def rotate(
    inpt: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomRotation` for details."""
    if torch.jit.is_scripting():
        return rotate_image(inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center)

    _log_api_usage_once(rotate)

    kernel = _get_kernel(rotate, type(inpt))
    return kernel(inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center)


@_register_kernel_internal(rotate, torch.Tensor)
@_register_kernel_internal(rotate, tv_tensors.Image)
def rotate_image(
    image: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    interpolation = _check_interpolation(interpolation)

    shape = image.shape
    num_channels, height, width = shape[-3:]

    center_f = [0.0, 0.0]
    if center is not None:
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]

    # due to current incoherence of rotation angle direction between affine and rotate implementations
    # we need to set -angle.
    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])

    if image.numel() > 0:
        image = image.reshape(-1, num_channels, height, width)

        _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

        ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
        dtype = image.dtype if torch.is_floating_point(image) else torch.float32
        theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
        grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
        output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

        new_height, new_width = output.shape[-2:]
    else:
        output = image
        new_width, new_height = _compute_affine_output_size(matrix, width, height) if expand else (width, height)

    return output.reshape(shape[:-3] + (num_channels, new_height, new_width))


@_register_kernel_internal(rotate, PIL.Image.Image)
def _rotate_image_pil(
    image: PIL.Image.Image,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> PIL.Image.Image:
    interpolation = _check_interpolation(interpolation)

    return _FP.rotate(
        image, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center
    )
def rotate_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    return _affine_bounding_boxes_with_expand(
        bounding_boxes,
        format=format,
        canvas_size=canvas_size,
        angle=-angle,
        translate=[0.0, 0.0],
        scale=1.0,
        shear=[0.0, 0.0],
        center=center,
        expand=expand,
    )


@_register_kernel_internal(rotate, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _rotate_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, angle: float, expand: bool = False, center: Optional[List[float]] = None, **kwargs
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = rotate_bounding_boxes(
        inpt.as_subclass(torch.Tensor),
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        angle=angle,
        expand=expand,
        center=center,
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


def rotate_mask(
    mask: torch.Tensor,
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = rotate_image(
        mask,
        angle=angle,
        expand=expand,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(rotate, tv_tensors.Mask, tv_tensor_wrapper=False)
def _rotate_mask_dispatch(
    inpt: tv_tensors.Mask,
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
    **kwargs,
) -> tv_tensors.Mask:
    output = rotate_mask(inpt.as_subclass(torch.Tensor), angle=angle, expand=expand, fill=fill, center=center)
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(rotate, tv_tensors.Video)
def rotate_video(
    video: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    return rotate_image(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
def pad(
    inpt: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.Pad` for details."""
    if torch.jit.is_scripting():
        return pad_image(inpt, padding=padding, fill=fill, padding_mode=padding_mode)

    _log_api_usage_once(pad)

    kernel = _get_kernel(pad, type(inpt))
    return kernel(inpt, padding=padding, fill=fill, padding_mode=padding_mode)


def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    if isinstance(padding, int):
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif isinstance(padding, (tuple, list)):
        if len(padding) == 1:
            pad_left = pad_right = pad_top = pad_bottom = padding[0]
        elif len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        elif len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]
        else:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )
    else:
        raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")

    return [pad_left, pad_right, pad_top, pad_bottom]
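

# Sketch (editor's addition, not part of the original file): the PIL-style order
# [left, top, right, bottom] is rearranged into the [left, right, top, bottom]
# order expected by `torch.nn.functional.pad`:
#
#   _parse_pad_padding([1, 2, 3, 4])  # -> [1, 3, 2, 4]
#   _parse_pad_padding(5)             # -> [5, 5, 5, 5]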
@_register_kernel_internal(pad, torch.Tensor)
@_register_kernel_internal(pad, tv_tensors.Image)
def pad_image(
    image: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    # Be aware that while `padding` has order `[left, top, right, bottom]`, `torch_padding` uses
    # `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use
    # `torch_pad` internally.
    torch_padding = _parse_pad_padding(padding)

    if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
        raise ValueError(
            f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
            f"but got `'{padding_mode}'`."
        )

    if fill is None:
        fill = 0

    if isinstance(fill, (int, float)):
        return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
    elif len(fill) == 1:
        return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode)
    else:
        return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)


def _pad_with_scalar_fill(
    image: torch.Tensor,
    torch_padding: List[int],
    fill: Union[int, float],
    padding_mode: str,
) -> torch.Tensor:
    shape = image.shape
    num_channels, height, width = shape[-3:]

    batch_size = 1
    for s in shape[:-3]:
        batch_size *= s

    image = image.reshape(batch_size, num_channels, height, width)

    if padding_mode == "edge":
        # Similar to the padding order, `torch_pad`'s and PIL's padding modes don't have the same names. Thus, we map
        # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
        # name.
        padding_mode = "replicate"

    if padding_mode == "constant":
        image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
    elif padding_mode in ("reflect", "replicate"):
        # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
        # TODO: See https://github.com/pytorch/pytorch/issues/40763
        dtype = image.dtype
        if not image.is_floating_point():
            needs_cast = True
            image = image.to(torch.float32)
        else:
            needs_cast = False

        image = torch_pad(image, torch_padding, mode=padding_mode)

        if needs_cast:
            image = image.to(dtype)
    else:  # padding_mode == "symmetric"
        image = _pad_symmetric(image, torch_padding)

    new_height, new_width = image.shape[-2:]

    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))


# TODO: This should be removed once torch_pad supports non-scalar padding values
def _pad_with_vector_fill(
    image: torch.Tensor,
    torch_padding: List[int],
    fill: List[float],
    padding_mode: str,
) -> torch.Tensor:
    if padding_mode != "constant":
        raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")

    output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    left, right, top, bottom = torch_padding

    # We are creating the tensor in the autodetected dtype first and convert to the right one after to avoid an
    # implicit float -> int conversion. That happens for example for the valid input of a uint8 image with floating
    # point fill value.
    fill = torch.tensor(fill, device=image.device).to(dtype=image.dtype).reshape(-1, 1, 1)

    if top > 0:
        output[..., :top, :] = fill
    if left > 0:
        output[..., :, :left] = fill
    if bottom > 0:
        output[..., -bottom:, :] = fill
    if right > 0:
        output[..., :, -right:] = fill
    return output
  1022. _pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad)
  1023. @_register_kernel_internal(pad, tv_tensors.Mask)
  1024. def pad_mask(
  1025. mask: torch.Tensor,
  1026. padding: List[int],
  1027. fill: Optional[Union[int, float, List[float]]] = None,
  1028. padding_mode: str = "constant",
  1029. ) -> torch.Tensor:
  1030. if fill is None:
  1031. fill = 0
  1032. if isinstance(fill, (tuple, list)):
  1033. raise ValueError("Non-scalar fill value is not supported")
  1034. if mask.ndim < 3:
  1035. mask = mask.unsqueeze(0)
  1036. needs_squeeze = True
  1037. else:
  1038. needs_squeeze = False
  1039. output = pad_image(mask, padding=padding, fill=fill, padding_mode=padding_mode)
  1040. if needs_squeeze:
  1041. output = output.squeeze(0)
  1042. return output

def pad_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    padding: List[int],
    padding_mode: str = "constant",
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    if padding_mode not in ["constant"]:
        # TODO: add support of other padding modes
        raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")

    left, right, top, bottom = _parse_pad_padding(padding)

    if format == tv_tensors.BoundingBoxFormat.XYXY:
        pad = [left, top, left, top]
    else:
        pad = [left, top, 0, 0]
    bounding_boxes = bounding_boxes + torch.tensor(pad, dtype=bounding_boxes.dtype, device=bounding_boxes.device)

    height, width = canvas_size
    height += top + bottom
    width += left + right
    canvas_size = (height, width)

    return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size
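
# Example (illustrative sketch; box and canvas values are assumed): padding shifts boxes by the
# (left, top) offsets and enlarges the canvas accordingly.
#
#     >>> boxes = torch.tensor([[10, 10, 20, 20]])
#     >>> out, canvas = pad_bounding_boxes(
#     ...     boxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=(50, 50), padding=[5]
#     ... )
#     >>> out.tolist(), canvas
#     ([[15, 15, 25, 25]], (60, 60))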

@_register_kernel_internal(pad, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _pad_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, padding: List[int], padding_mode: str = "constant", **kwargs
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = pad_bounding_boxes(
        inpt.as_subclass(torch.Tensor),
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        padding=padding,
        padding_mode=padding_mode,
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


@_register_kernel_internal(pad, tv_tensors.Video)
def pad_video(
    video: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    return pad_image(video, padding, fill=fill, padding_mode=padding_mode)


def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomCrop` for details."""
    if torch.jit.is_scripting():
        return crop_image(inpt, top=top, left=left, height=height, width=width)

    _log_api_usage_once(crop)

    kernel = _get_kernel(crop, type(inpt))
    return kernel(inpt, top=top, left=left, height=height, width=width)


@_register_kernel_internal(crop, torch.Tensor)
@_register_kernel_internal(crop, tv_tensors.Image)
def crop_image(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    h, w = image.shape[-2:]

    right = left + width
    bottom = top + height

    if left < 0 or top < 0 or right > w or bottom > h:
        image = image[..., max(top, 0) : bottom, max(left, 0) : right]
        torch_padding = [
            max(min(right, 0) - left, 0),
            max(right - max(w, left), 0),
            max(min(bottom, 0) - top, 0),
            max(bottom - max(h, top), 0),
        ]
        return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    return image[..., top:bottom, left:right]
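
# Example (illustrative sketch; the negative anchor is an assumed value): a crop window that extends
# beyond the image is implicitly zero-padded, so the output always has the requested size.
#
#     >>> img = torch.rand(3, 10, 10)
#     >>> crop_image(img, top=-2, left=-2, height=6, width=6).shape
#     torch.Size([3, 6, 6])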

_crop_image_pil = _FP.crop
_register_kernel_internal(crop, PIL.Image.Image)(_crop_image_pil)


def crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Tuple[torch.Tensor, Tuple[int, int]]:

    # Crop or implicit pad if left and/or top have negative values:
    if format == tv_tensors.BoundingBoxFormat.XYXY:
        sub = [left, top, left, top]
    else:
        sub = [left, top, 0, 0]

    bounding_boxes = bounding_boxes - torch.tensor(sub, dtype=bounding_boxes.dtype, device=bounding_boxes.device)
    canvas_size = (height, width)

    return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size
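
# Example (illustrative sketch; values assumed): cropping translates boxes by (-left, -top) and clamps
# them to the new canvas.
#
#     >>> boxes = torch.tensor([[10, 10, 20, 20]])
#     >>> out, canvas = crop_bounding_boxes(
#     ...     boxes, tv_tensors.BoundingBoxFormat.XYXY, top=5, left=5, height=10, width=10
#     ... )
#     >>> out.tolist(), canvas
#     ([[5, 5, 10, 10]], (10, 10))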

@_register_kernel_internal(crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _crop_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = crop_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


@_register_kernel_internal(crop, tv_tensors.Mask)
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = crop_image(mask, top, left, height, width)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(crop, tv_tensors.Video)
def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    return crop_image(video, top, left, height, width)


def perspective(
    inpt: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomPerspective` for details."""
    if torch.jit.is_scripting():
        return perspective_image(
            inpt,
            startpoints=startpoints,
            endpoints=endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )

    _log_api_usage_once(perspective)

    kernel = _get_kernel(perspective, type(inpt))
    return kernel(
        inpt,
        startpoints=startpoints,
        endpoints=endpoints,
        interpolation=interpolation,
        fill=fill,
        coefficients=coefficients,
    )


def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
    # src/libImaging/Geometry.c#L394
    #
    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #
    theta1 = torch.tensor(
        [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
    )
    theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)

    d = 0.5
    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device))
    shape = (1, oh * ow, 3)
    output_grid1 = base_grid.view(shape).bmm(rescaled_theta1)
    output_grid2 = base_grid.view(shape).bmm(theta2.transpose(1, 2))

    output_grid = output_grid1.div_(output_grid2).sub_(1.0)
    return output_grid.view(1, oh, ow, 2)
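
# Note (illustrative, assuming the Pillow coefficient convention referenced above): with identity
# coefficients [1, 0, 0, 0, 1, 0, 0, 0] the returned grid reduces to 2 * (x + 0.5) / ow - 1 and
# 2 * (y + 0.5) / oh - 1, i.e. the usual normalized pixel-center sampling grid in [-1, 1].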

def _perspective_coefficients(
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]],
) -> List[float]:
    if coefficients is not None:
        if startpoints is not None and endpoints is not None:
            raise ValueError("The startpoints/endpoints and the coefficients shouldn't be defined concurrently.")
        elif len(coefficients) != 8:
            raise ValueError("Argument coefficients should have 8 float values")
        return coefficients
    elif startpoints is not None and endpoints is not None:
        return _get_perspective_coeffs(startpoints, endpoints)
    else:
        raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.")
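
# Example (illustrative sketch; the corner lists are assumed toy values): the helper accepts either eight
# precomputed coefficients or a startpoints/endpoints pair, but not both.
#
#     >>> _perspective_coefficients(None, None, [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
#     [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
#     >>> _perspective_coefficients([[0, 0], [9, 0], [9, 9], [0, 9]], [[1, 0], [8, 1], [9, 8], [0, 9]], [1.0] * 8)
#     Traceback (most recent call last):
#         ...
#     ValueError: The startpoints/endpoints and the coefficients shouldn't be defined concurrently.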

@_register_kernel_internal(perspective, torch.Tensor)
@_register_kernel_internal(perspective, tv_tensors.Image)
def perspective_image(
    image: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    shape = image.shape
    ndim = image.ndim

    if ndim > 4:
        image = image.reshape((-1,) + shape[-3:])
        needs_unsquash = True
    elif ndim == 3:
        image = image.unsqueeze(0)
        needs_unsquash = True
    else:
        needs_unsquash = False

    _assert_grid_transform_inputs(
        image,
        matrix=None,
        interpolation=interpolation.value,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=perspective_coeffs,
    )

    oh, ow = shape[-2:]
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    if needs_unsquash:
        output = output.reshape(shape)

    return output
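
# Example (illustrative sketch; the corner points are assumed values): the kernel keeps the input shape
# and only resamples the content.
#
#     >>> img = torch.rand(3, 10, 10)
#     >>> startpoints = [[0, 0], [9, 0], [9, 9], [0, 9]]
#     >>> endpoints = [[1, 0], [8, 1], [9, 8], [0, 9]]
#     >>> perspective_image(img, startpoints, endpoints).shape
#     torch.Size([3, 10, 10])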

@_register_kernel_internal(perspective, PIL.Image.Image)
def _perspective_image_pil(
    image: PIL.Image.Image,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BICUBIC,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> PIL.Image.Image:
    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    interpolation = _check_interpolation(interpolation)
    return _FP.perspective(image, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)


def perspective_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    if bounding_boxes.numel() == 0:
        return bounding_boxes

    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)

    original_shape = bounding_boxes.shape
    # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
    bounding_boxes = (
        convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32
    device = bounding_boxes.device

    # perspective_coeffs are computed as endpoint -> start point
    # We have to invert perspective_coeffs for bboxes:
    # (x, y) - end point and (x_out, y_out) - start point
    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # and we would like to get:
    # x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2])
    #     / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5])
    #     / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # and compute inv_coeffs in terms of coeffs
    denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3]
    if denom == 0:
        raise RuntimeError(
            f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform bounding boxes. "
            f"Denominator is zero, denom={denom}"
        )

    inv_coeffs = [
        (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom,
        (perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom,
        (-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom,
        (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom,
        (-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
    ]

    theta1 = torch.tensor(
        [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
        dtype=dtype,
        device=device,
    )

    theta2 = torch.tensor(
        [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device
    )

    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-right, bottom-left corners).
    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
    # Single point structure is similar to
    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
    # 2) Now let's transform the points using perspective matrices
    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    numer_points = torch.matmul(points, theta1.T)
    denom_points = torch.matmul(points, theta2.T)
    transformed_points = numer_points.div_(denom_points)

    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
    # and compute bounding box from 4 transformed points:
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)

    out_bboxes = clamp_bounding_boxes(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=canvas_size,
    )

    # out_bboxes should be of shape [N boxes, 4]

    return convert_bounding_box_format(
        out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)
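
# Note (illustrative, not exhaustive): with identity coefficients [1, 0, 0, 0, 1, 0, 0, 0] the inverse
# coefficients are again the identity, so boxes are returned unchanged apart from clamping to the canvas.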

@_register_kernel_internal(perspective, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _perspective_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.BoundingBoxes:
    output = perspective_bounding_boxes(
        inpt.as_subclass(torch.Tensor),
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        startpoints=startpoints,
        endpoints=endpoints,
        coefficients=coefficients,
    )
    return tv_tensors.wrap(output, like=inpt)


def perspective_mask(
    mask: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = perspective_image(
        mask, startpoints, endpoints, interpolation=InterpolationMode.NEAREST, fill=fill, coefficients=coefficients
    )

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(perspective, tv_tensors.Mask, tv_tensor_wrapper=False)
def _perspective_mask_dispatch(
    inpt: tv_tensors.Mask,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.Mask:
    output = perspective_mask(
        inpt.as_subclass(torch.Tensor),
        startpoints=startpoints,
        endpoints=endpoints,
        fill=fill,
        coefficients=coefficients,
    )
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(perspective, tv_tensors.Video)
def perspective_video(
    video: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    return perspective_image(
        video, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
    )


def elastic(
    inpt: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.ElasticTransform` for details."""
    if torch.jit.is_scripting():
        return elastic_image(inpt, displacement=displacement, interpolation=interpolation, fill=fill)

    _log_api_usage_once(elastic)

    kernel = _get_kernel(elastic, type(inpt))
    return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill)


elastic_transform = elastic
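
# Example (illustrative sketch; the all-zero displacement is an assumed toy value): the displacement
# field must have shape (1, H, W, 2). With a zero displacement the sampling grid is the identity grid,
# so the output should match the input up to interpolation rounding.
#
#     >>> img = torch.rand(3, 8, 8)
#     >>> displacement = torch.zeros(1, 8, 8, 2)
#     >>> elastic(img, displacement).shape
#     torch.Size([3, 8, 8])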

@_register_kernel_internal(elastic, torch.Tensor)
@_register_kernel_internal(elastic, tv_tensors.Image)
def elastic_image(
    image: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")

    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    shape = image.shape
    ndim = image.ndim

    device = image.device
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32

    # Patch: elastic transform should support (cpu,f16) input
    is_cpu_half = device.type == "cpu" and dtype == torch.float16
    if is_cpu_half:
        image = image.to(torch.float32)
        dtype = torch.float32

    # We are aware that if the input image dtype is uint8 and displacement is float64 then
    # displacement will be cast to float32 and all computations will be done with float32.
    # We can fix this later if needed.

    expected_shape = (1,) + shape[-2:] + (2,)
    if expected_shape != displacement.shape:
        raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

    if ndim > 4:
        image = image.reshape((-1,) + shape[-3:])
        needs_unsquash = True
    elif ndim == 3:
        image = image.unsqueeze(0)
        needs_unsquash = True
    else:
        needs_unsquash = False

    if displacement.dtype != dtype or displacement.device != device:
        displacement = displacement.to(dtype=dtype, device=device)

    image_height, image_width = shape[-2:]
    grid = _create_identity_grid((image_height, image_width), device=device, dtype=dtype).add_(displacement)
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    if needs_unsquash:
        output = output.reshape(shape)

    if is_cpu_half:
        output = output.to(torch.float16)

    return output

@_register_kernel_internal(elastic, PIL.Image.Image)
def _elastic_image_pil(
    image: PIL.Image.Image,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> PIL.Image.Image:
    t_img = pil_to_tensor(image)
    output = elastic_image(t_img, displacement, interpolation=interpolation, fill=fill)
    return to_pil_image(output, mode=image.mode)


def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    sy, sx = size
    base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype)
    x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype)
    base_grid[..., 0].copy_(x_grid)

    y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)

    return base_grid
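
# Note (illustrative): the grid holds normalized pixel-center coordinates in [-1, 1]; e.g. for
# size (2, 2) the x coordinates are [-0.5, 0.5].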

def elastic_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    displacement: torch.Tensor,
) -> torch.Tensor:
    expected_shape = (1, canvas_size[0], canvas_size[1], 2)
    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")
    elif displacement.shape != expected_shape:
        raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

    if bounding_boxes.numel() == 0:
        return bounding_boxes

    # TODO: add in docstring about approximation we are doing for grid inversion
    device = bounding_boxes.device
    dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32

    if displacement.dtype != dtype or displacement.device != device:
        displacement = displacement.to(dtype=dtype, device=device)

    original_shape = bounding_boxes.shape
    # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
    bounding_boxes = (
        convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
    # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
    # This is not an exact inverse of the grid
    inv_grid = id_grid.sub_(displacement)

    # Get points from bboxes
    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    if points.is_floating_point():
        points = points.ceil_()
    index_xy = points.to(dtype=torch.long)
    index_x, index_y = index_xy[:, 0], index_xy[:, 1]

    # Transform points:
    t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype)
    transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = clamp_bounding_boxes(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=canvas_size,
    )

    return convert_bounding_box_format(
        out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)


@_register_kernel_internal(elastic, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _elastic_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, displacement: torch.Tensor, **kwargs
) -> tv_tensors.BoundingBoxes:
    output = elastic_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, displacement=displacement
    )
    return tv_tensors.wrap(output, like=inpt)


def elastic_mask(
    mask: torch.Tensor,
    displacement: torch.Tensor,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = elastic_image(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(elastic, tv_tensors.Mask, tv_tensor_wrapper=False)
def _elastic_mask_dispatch(
    inpt: tv_tensors.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs
) -> tv_tensors.Mask:
    output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill)
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(elastic, tv_tensors.Video)
def elastic_video(
    video: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    return elastic_image(video, displacement, interpolation=interpolation, fill=fill)

def center_crop(inpt: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.CenterCrop` for details."""
    if torch.jit.is_scripting():
        return center_crop_image(inpt, output_size=output_size)

    _log_api_usage_once(center_crop)

    kernel = _get_kernel(center_crop, type(inpt))
    return kernel(inpt, output_size=output_size)

def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
    if isinstance(output_size, numbers.Number):
        s = int(output_size)
        return [s, s]
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        return [output_size[0], output_size[0]]
    else:
        return list(output_size)


def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]:
    return [
        (crop_width - image_width) // 2 if crop_width > image_width else 0,
        (crop_height - image_height) // 2 if crop_height > image_height else 0,
        (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
        (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
    ]


def _center_crop_compute_crop_anchor(
    crop_height: int, crop_width: int, image_height: int, image_width: int
) -> Tuple[int, int]:
    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop_top, crop_left


@_register_kernel_internal(center_crop, torch.Tensor)
@_register_kernel_internal(center_crop, tv_tensors.Image)
def center_crop_image(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    shape = image.shape
    if image.numel() == 0:
        return image.reshape(shape[:-2] + (crop_height, crop_width))
    image_height, image_width = shape[-2:]

    if crop_height > image_height or crop_width > image_width:
        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
        image = torch_pad(image, _parse_pad_padding(padding_ltrb), value=0.0)

        image_height, image_width = image.shape[-2:]
        if crop_width == image_width and crop_height == image_height:
            return image

    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)]
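
# Example (illustrative sketch; shapes assumed): requested sizes larger than the input are handled by
# zero-padding before the crop, so the output always matches `output_size`.
#
#     >>> center_crop_image(torch.rand(3, 10, 10), [4]).shape
#     torch.Size([3, 4, 4])
#     >>> center_crop_image(torch.rand(3, 10, 10), [12]).shape
#     torch.Size([3, 12, 12])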

@_register_kernel_internal(center_crop, PIL.Image.Image)
def _center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    image_height, image_width = _get_size_image_pil(image)

    if crop_height > image_height or crop_width > image_width:
        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
        image = _pad_image_pil(image, padding_ltrb, fill=0)

        image_height, image_width = _get_size_image_pil(image)
        if crop_width == image_width and crop_height == image_height:
            return image

    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    return _crop_image_pil(image, crop_top, crop_left, crop_height, crop_width)


def center_crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    output_size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size)
    return crop_bounding_boxes(
        bounding_boxes, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width
    )


@_register_kernel_internal(center_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _center_crop_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, output_size: List[int]
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = center_crop_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


@_register_kernel_internal(center_crop, tv_tensors.Mask)
def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = center_crop_image(image=mask, output_size=output_size)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(center_crop, tv_tensors.Video)
def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    return center_crop_image(video, output_size)


def resized_crop(
    inpt: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomResizedCrop` for details."""
    if torch.jit.is_scripting():
        return resized_crop_image(
            inpt,
            top=top,
            left=left,
            height=height,
            width=width,
            size=size,
            interpolation=interpolation,
            antialias=antialias,
        )

    _log_api_usage_once(resized_crop)

    kernel = _get_kernel(resized_crop, type(inpt))
    return kernel(
        inpt,
        top=top,
        left=left,
        height=height,
        width=width,
        size=size,
        interpolation=interpolation,
        antialias=antialias,
    )


@_register_kernel_internal(resized_crop, torch.Tensor)
@_register_kernel_internal(resized_crop, tv_tensors.Image)
def resized_crop_image(
    image: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    image = crop_image(image, top, left, height, width)
    return resize_image(image, size, interpolation=interpolation, antialias=antialias)
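
# Example (illustrative sketch; values assumed): the kernel is just crop followed by resize, so the
# result has the requested `size` regardless of the crop window.
#
#     >>> img = torch.rand(3, 20, 20)
#     >>> resized_crop_image(img, top=0, left=0, height=10, width=10, size=[5, 5], antialias=True).shape
#     torch.Size([3, 5, 5])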

def _resized_crop_image_pil(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> PIL.Image.Image:
    image = _crop_image_pil(image, top, left, height, width)
    return _resize_image_pil(image, size, interpolation=interpolation)


@_register_kernel_internal(resized_crop, PIL.Image.Image)
def _resized_crop_image_pil_dispatch(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> PIL.Image.Image:
    if antialias is False:
        warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
    return _resized_crop_image_pil(
        image,
        top=top,
        left=left,
        height=height,
        width=width,
        size=size,
        interpolation=interpolation,
    )


def resized_crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    bounding_boxes, canvas_size = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
    return resize_bounding_boxes(bounding_boxes, canvas_size=canvas_size, size=size)


@_register_kernel_internal(resized_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _resized_crop_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = resized_crop_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


def resized_crop_mask(
    mask: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> torch.Tensor:
    mask = crop_mask(mask, top, left, height, width)
    return resize_mask(mask, size)


@_register_kernel_internal(resized_crop, tv_tensors.Mask, tv_tensor_wrapper=False)
def _resized_crop_mask_dispatch(
    inpt: tv_tensors.Mask, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> tv_tensors.Mask:
    output = resized_crop_mask(
        inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size
    )
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(resized_crop, tv_tensors.Video)
def resized_crop_video(
    video: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    return resized_crop_image(
        video, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
    )


def five_crop(
    inpt: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """[BETA] See :class:`~torchvision.transforms.v2.FiveCrop` for details."""
    if torch.jit.is_scripting():
        return five_crop_image(inpt, size=size)

    _log_api_usage_once(five_crop)

    kernel = _get_kernel(five_crop, type(inpt))
    return kernel(inpt, size=size)


def _parse_five_crop_size(size: List[int]) -> List[int]:
    if isinstance(size, numbers.Number):
        s = int(size)
        size = [s, s]
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        s = size[0]
        size = [s, s]

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    return size


@_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor)
@_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Image)
def five_crop_image(
    image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = image.shape[-2:]

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    tl = crop_image(image, 0, 0, crop_height, crop_width)
    tr = crop_image(image, 0, image_width - crop_width, crop_height, crop_width)
    bl = crop_image(image, image_height - crop_height, 0, crop_height, crop_width)
    br = crop_image(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
    center = center_crop_image(image, [crop_height, crop_width])

    return tl, tr, bl, br, center
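
# Example (illustrative sketch; shapes assumed): the five crops are top-left, top-right, bottom-left,
# bottom-right and center, each with the requested size.
#
#     >>> crops = five_crop_image(torch.rand(3, 10, 10), [4])
#     >>> len(crops), crops[0].shape
#     (5, torch.Size([3, 4, 4]))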

@_register_five_ten_crop_kernel_internal(five_crop, PIL.Image.Image)
def _five_crop_image_pil(
    image: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = _get_size_image_pil(image)

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    tl = _crop_image_pil(image, 0, 0, crop_height, crop_width)
    tr = _crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width)
    bl = _crop_image_pil(image, image_height - crop_height, 0, crop_height, crop_width)
    br = _crop_image_pil(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
    center = _center_crop_image_pil(image, [crop_height, crop_width])

    return tl, tr, bl, br, center


@_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Video)
def five_crop_video(
    video: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    return five_crop_image(video, size)


def ten_crop(
    inpt: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """[BETA] See :class:`~torchvision.transforms.v2.TenCrop` for details."""
    if torch.jit.is_scripting():
        return ten_crop_image(inpt, size=size, vertical_flip=vertical_flip)

    _log_api_usage_once(ten_crop)

    kernel = _get_kernel(ten_crop, type(inpt))
    return kernel(inpt, size=size, vertical_flip=vertical_flip)


@_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor)
@_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Image)
def ten_crop_image(
    image: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    non_flipped = five_crop_image(image, size)

    if vertical_flip:
        image = vertical_flip_image(image)
    else:
        image = horizontal_flip_image(image)

    flipped = five_crop_image(image, size)

    return non_flipped + flipped
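
# Example (illustrative sketch; shapes assumed): the first five crops come from the original image and
# the last five from its horizontally flipped copy (or vertically flipped when vertical_flip=True).
#
#     >>> crops = ten_crop_image(torch.rand(3, 10, 10), [4])
#     >>> len(crops)
#     10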

@_register_five_ten_crop_kernel_internal(ten_crop, PIL.Image.Image)
def _ten_crop_image_pil(
    image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
) -> Tuple[
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
]:
    non_flipped = _five_crop_image_pil(image, size)

    if vertical_flip:
        image = _vertical_flip_image_pil(image)
    else:
        image = _horizontal_flip_image_pil(image)

    flipped = _five_crop_image_pil(image, size)

    return non_flipped + flipped


@_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Video)
def ten_crop_video(
    video: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    return ten_crop_image(video, size, vertical_flip=vertical_flip)