test_transforms_v2_refactored.py 116 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909
  1. import contextlib
  2. import decimal
  3. import inspect
  4. import math
  5. import pickle
  6. import re
  7. from pathlib import Path
  8. from unittest import mock
  9. import numpy as np
  10. import PIL.Image
  11. import pytest
  12. import torch
  13. import torchvision.transforms.v2 as transforms
  14. from common_utils import (
  15. assert_equal,
  16. assert_no_warnings,
  17. cache,
  18. cpu_and_cuda,
  19. freeze_rng_state,
  20. ignore_jit_no_profile_information_warning,
  21. make_bounding_boxes,
  22. make_detection_mask,
  23. make_image,
  24. make_image_pil,
  25. make_image_tensor,
  26. make_segmentation_mask,
  27. make_video,
  28. make_video_tensor,
  29. needs_cuda,
  30. set_rng_seed,
  31. )
  32. from torch import nn
  33. from torch.testing import assert_close
  34. from torch.utils._pytree import tree_map
  35. from torch.utils.data import DataLoader, default_collate
  36. from torchvision import tv_tensors
  37. from torchvision.transforms._functional_tensor import _max_value as get_max_value
  38. from torchvision.transforms.functional import pil_modes_mapping
  39. from torchvision.transforms.v2 import functional as F
  40. from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal
  41. @pytest.fixture(autouse=True)
  42. def fix_rng_seed():
  43. set_rng_seed(0)
  44. yield
  45. def _to_tolerances(maybe_tolerance_dict):
  46. if not isinstance(maybe_tolerance_dict, dict):
  47. return dict(rtol=None, atol=None)
  48. tolerances = dict(rtol=0, atol=0)
  49. tolerances.update(maybe_tolerance_dict)
  50. return tolerances
  51. def _check_kernel_cuda_vs_cpu(kernel, input, *args, rtol, atol, **kwargs):
  52. """Checks if the kernel produces closes results for inputs on GPU and CPU."""
  53. if input.device.type != "cuda":
  54. return
  55. input_cuda = input.as_subclass(torch.Tensor)
  56. input_cpu = input_cuda.to("cpu")
  57. with freeze_rng_state():
  58. actual = kernel(input_cuda, *args, **kwargs)
  59. with freeze_rng_state():
  60. expected = kernel(input_cpu, *args, **kwargs)
  61. assert_close(actual, expected, check_device=False, rtol=rtol, atol=atol)
  62. @cache
  63. def _script(obj):
  64. try:
  65. return torch.jit.script(obj)
  66. except Exception as error:
  67. name = getattr(obj, "__name__", obj.__class__.__name__)
  68. raise AssertionError(f"Trying to `torch.jit.script` '{name}' raised the error above.") from error
  69. def _check_kernel_scripted_vs_eager(kernel, input, *args, rtol, atol, **kwargs):
  70. """Checks if the kernel is scriptable and if the scripted output is close to the eager one."""
  71. if input.device.type != "cpu":
  72. return
  73. kernel_scripted = _script(kernel)
  74. input = input.as_subclass(torch.Tensor)
  75. with ignore_jit_no_profile_information_warning():
  76. actual = kernel_scripted(input, *args, **kwargs)
  77. expected = kernel(input, *args, **kwargs)
  78. assert_close(actual, expected, rtol=rtol, atol=atol)
  79. def _check_kernel_batched_vs_unbatched(kernel, input, *args, rtol, atol, **kwargs):
  80. """Checks if the kernel produces close results for batched and unbatched inputs."""
  81. unbatched_input = input.as_subclass(torch.Tensor)
  82. for batch_dims in [(2,), (2, 1)]:
  83. repeats = [*batch_dims, *[1] * input.ndim]
  84. actual = kernel(unbatched_input.repeat(repeats), *args, **kwargs)
  85. expected = kernel(unbatched_input, *args, **kwargs)
  86. # We can't directly call `.repeat()` on the output, since some kernel also return some additional metadata
  87. if isinstance(expected, torch.Tensor):
  88. expected = expected.repeat(repeats)
  89. else:
  90. tensor, *metadata = expected
  91. expected = (tensor.repeat(repeats), *metadata)
  92. assert_close(actual, expected, rtol=rtol, atol=atol)
  93. for degenerate_batch_dims in [(0,), (5, 0), (0, 5)]:
  94. degenerate_batched_input = torch.empty(
  95. degenerate_batch_dims + input.shape, dtype=input.dtype, device=input.device
  96. )
  97. output = kernel(degenerate_batched_input, *args, **kwargs)
  98. # Most kernels just return a tensor, but some also return some additional metadata
  99. if not isinstance(output, torch.Tensor):
  100. output, *_ = output
  101. assert output.shape[: -input.ndim] == degenerate_batch_dims
  102. def check_kernel(
  103. kernel,
  104. input,
  105. *args,
  106. check_cuda_vs_cpu=True,
  107. check_scripted_vs_eager=True,
  108. check_batched_vs_unbatched=True,
  109. expect_same_dtype=True,
  110. **kwargs,
  111. ):
  112. initial_input_version = input._version
  113. output = kernel(input.as_subclass(torch.Tensor), *args, **kwargs)
  114. # Most kernels just return a tensor, but some also return some additional metadata
  115. if not isinstance(output, torch.Tensor):
  116. output, *_ = output
  117. # check that no inplace operation happened
  118. assert input._version == initial_input_version
  119. if expect_same_dtype:
  120. assert output.dtype == input.dtype
  121. assert output.device == input.device
  122. if check_cuda_vs_cpu:
  123. _check_kernel_cuda_vs_cpu(kernel, input, *args, **kwargs, **_to_tolerances(check_cuda_vs_cpu))
  124. if check_scripted_vs_eager:
  125. _check_kernel_scripted_vs_eager(kernel, input, *args, **kwargs, **_to_tolerances(check_scripted_vs_eager))
  126. if check_batched_vs_unbatched:
  127. _check_kernel_batched_vs_unbatched(kernel, input, *args, **kwargs, **_to_tolerances(check_batched_vs_unbatched))
  128. def _check_functional_scripted_smoke(functional, input, *args, **kwargs):
  129. """Checks if the functional can be scripted and the scripted version can be called without error."""
  130. if not isinstance(input, tv_tensors.Image):
  131. return
  132. functional_scripted = _script(functional)
  133. with ignore_jit_no_profile_information_warning():
  134. functional_scripted(input.as_subclass(torch.Tensor), *args, **kwargs)
  135. def check_functional(functional, input, *args, check_scripted_smoke=True, **kwargs):
  136. unknown_input = object()
  137. with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
  138. functional(unknown_input, *args, **kwargs)
  139. with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy:
  140. output = functional(input, *args, **kwargs)
  141. spy.assert_any_call(f"{functional.__module__}.{functional.__name__}")
  142. assert isinstance(output, type(input))
  143. if isinstance(input, tv_tensors.BoundingBoxes):
  144. assert output.format == input.format
  145. if check_scripted_smoke:
  146. _check_functional_scripted_smoke(functional, input, *args, **kwargs)
  147. def check_functional_kernel_signature_match(functional, *, kernel, input_type):
  148. """Checks if the signature of the functional matches the kernel signature."""
  149. functional_params = list(inspect.signature(functional).parameters.values())[1:]
  150. kernel_params = list(inspect.signature(kernel).parameters.values())[1:]
  151. if issubclass(input_type, tv_tensors.TVTensor):
  152. # We filter out metadata that is implicitly passed to the functional through the input tv_tensor, but has to be
  153. # explicitly passed to the kernel.
  154. explicit_metadata = {
  155. tv_tensors.BoundingBoxes: {"format", "canvas_size"},
  156. }
  157. kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]
  158. functional_params = iter(functional_params)
  159. for functional_param, kernel_param in zip(functional_params, kernel_params):
  160. try:
  161. # In general, the functional parameters are a superset of the kernel parameters. Thus, we filter out
  162. # functional parameters that have no kernel equivalent while keeping the order intact.
  163. while functional_param.name != kernel_param.name:
  164. functional_param = next(functional_params)
  165. except StopIteration:
  166. raise AssertionError(
  167. f"Parameter `{kernel_param.name}` of kernel `{kernel.__name__}` "
  168. f"has no corresponding parameter on the functional `{functional.__name__}`."
  169. ) from None
  170. if issubclass(input_type, PIL.Image.Image):
  171. # PIL kernels often have more correct annotations, since they are not limited by JIT. Thus, we don't check
  172. # them in the first place.
  173. functional_param._annotation = kernel_param._annotation = inspect.Parameter.empty
  174. assert functional_param == kernel_param
  175. def _check_transform_v1_compatibility(transform, input, *, rtol, atol):
  176. """If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static
  177. ``get_params`` method that is the v1 equivalent, the output is close to v1, is scriptable, and the scripted version
  178. can be called without error."""
  179. if type(input) is not torch.Tensor or isinstance(input, PIL.Image.Image):
  180. return
  181. v1_transform_cls = transform._v1_transform_cls
  182. if v1_transform_cls is None:
  183. return
  184. if hasattr(v1_transform_cls, "get_params"):
  185. assert type(transform).get_params is v1_transform_cls.get_params
  186. v1_transform = v1_transform_cls(**transform._extract_params_for_v1_transform())
  187. with freeze_rng_state():
  188. output_v2 = transform(input)
  189. with freeze_rng_state():
  190. output_v1 = v1_transform(input)
  191. assert_close(output_v2, output_v1, rtol=rtol, atol=atol)
  192. if isinstance(input, PIL.Image.Image):
  193. return
  194. _script(v1_transform)(input)
  195. def check_transform(transform, input, check_v1_compatibility=True):
  196. pickle.loads(pickle.dumps(transform))
  197. output = transform(input)
  198. assert isinstance(output, type(input))
  199. if isinstance(input, tv_tensors.BoundingBoxes):
  200. assert output.format == input.format
  201. if check_v1_compatibility:
  202. _check_transform_v1_compatibility(transform, input, **_to_tolerances(check_v1_compatibility))
  203. def transform_cls_to_functional(transform_cls, **transform_specific_kwargs):
  204. def wrapper(input, *args, **kwargs):
  205. transform = transform_cls(*args, **transform_specific_kwargs, **kwargs)
  206. return transform(input)
  207. wrapper.__name__ = transform_cls.__name__
  208. return wrapper
  209. def param_value_parametrization(**kwargs):
  210. """Helper function to turn
  211. @pytest.mark.parametrize(
  212. ("param", "value"),
  213. ("a", 1),
  214. ("a", 2),
  215. ("a", 3),
  216. ("b", -1.0)
  217. ("b", 1.0)
  218. )
  219. into
  220. @param_value_parametrization(a=[1, 2, 3], b=[-1.0, 1.0])
  221. """
  222. return pytest.mark.parametrize(
  223. ("param", "value"),
  224. [(param, value) for param, values in kwargs.items() for value in values],
  225. )
  226. def adapt_fill(value, *, dtype):
  227. """Adapt fill values in the range [0.0, 1.0] to the value range of the dtype"""
  228. if value is None:
  229. return value
  230. max_value = get_max_value(dtype)
  231. value_type = float if dtype.is_floating_point else int
  232. if isinstance(value, (int, float)):
  233. return value_type(value * max_value)
  234. elif isinstance(value, (list, tuple)):
  235. return type(value)(value_type(v * max_value) for v in value)
  236. else:
  237. raise ValueError(f"fill should be an int or float, or a list or tuple of the former, but got '{value}'.")
  238. EXHAUSTIVE_TYPE_FILLS = [
  239. None,
  240. 1,
  241. 0.5,
  242. [1],
  243. [0.2],
  244. (0,),
  245. (0.7,),
  246. [1, 0, 1],
  247. [0.1, 0.2, 0.3],
  248. (0, 1, 0),
  249. (0.9, 0.234, 0.314),
  250. ]
  251. CORRECTNESS_FILLS = [
  252. v for v in EXHAUSTIVE_TYPE_FILLS if v is None or isinstance(v, float) or (isinstance(v, list) and len(v) > 1)
  253. ]
  254. # We cannot use `list(transforms.InterpolationMode)` here, since it includes some PIL-only ones as well
  255. INTERPOLATION_MODES = [
  256. transforms.InterpolationMode.NEAREST,
  257. transforms.InterpolationMode.NEAREST_EXACT,
  258. transforms.InterpolationMode.BILINEAR,
  259. transforms.InterpolationMode.BICUBIC,
  260. ]
  261. @contextlib.contextmanager
  262. def assert_warns_antialias_default_value():
  263. with pytest.warns(UserWarning, match="The default value of the antialias parameter of all the resizing transforms"):
  264. yield
def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True):
    """Reference implementation of an affine transform of bounding boxes, computed box by box.

    Each box is converted to XYXY and its four corners are mapped through ``affine_matrix`` with NumPy. The axis-aligned
    box enclosing the mapped corners is then converted back to the input format.

    Args:
        bounding_boxes: ``tv_tensors.BoundingBoxes`` whose ``format`` and ``canvas_size`` are read. It is processed as
            ``reshape(-1, 4)`` and the result is reshaped back to the input shape.
        affine_matrix: NumPy array applied as ``points @ affine_matrix.T`` to homogeneous ``[x, y, 1]`` corners, so it
            needs three columns; only the first two output columns are used. Presumably ``(2, 3)`` or ``(3, 3)`` --
            TODO confirm against callers.
        new_canvas_size: canvas size of the result; falls back to ``bounding_boxes.canvas_size`` when falsy.
        clamp: if true, clamp each box to the canvas and cast it back to the input dtype. Otherwise keep the
            floating-point dtype of the computed box.

    Returns:
        ``tv_tensors.BoundingBoxes`` in the input format and shape, on the input device, with the resolved canvas size.
    """
    format = bounding_boxes.format
    canvas_size = new_canvas_size or bounding_boxes.canvas_size

    def affine_bounding_boxes(bounding_boxes):
        """Map one box through ``affine_matrix`` and return the enclosing axis-aligned box in ``format``."""
        dtype = bounding_boxes.dtype
        device = bounding_boxes.device

        # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
        # `copy=True` guarantees a fresh tensor, so the in-place format conversion cannot modify the caller's boxes.
        input_xyxy = F.convert_bounding_box_format(
            bounding_boxes.to(dtype=torch.float64, device="cpu", copy=True),
            old_format=format,
            new_format=tv_tensors.BoundingBoxFormat.XYXY,
            inplace=True,
        )
        x1, y1, x2, y2 = input_xyxy.squeeze(0).tolist()

        # The four box corners in homogeneous coordinates.
        points = np.array(
            [
                [x1, y1, 1.0],
                [x2, y1, 1.0],
                [x1, y2, 1.0],
                [x2, y2, 1.0],
            ]
        )
        transformed_points = np.matmul(points, affine_matrix.astype(points.dtype).T)

        # Axis-aligned box enclosing the transformed corners.
        output_xyxy = torch.Tensor(
            [
                float(np.min(transformed_points[:, 0])),
                float(np.min(transformed_points[:, 1])),
                float(np.max(transformed_points[:, 0])),
                float(np.max(transformed_points[:, 1])),
            ]
        )

        output = F.convert_bounding_box_format(
            output_xyxy, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format
        )

        if clamp:
            # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
            output = F.clamp_bounding_boxes(
                output,
                format=format,
                canvas_size=canvas_size,
            )
        else:
            # We leave the bounding box in floating point so the caller gets the full precision to perform any
            # additional operation.
            # NOTE(review): `output_xyxy` is built with `torch.Tensor(...)`, which presumably yields torch's default
            # float dtype (float32 unless changed) rather than float64 -- confirm whether float64 was intended.
            dtype = output.dtype

        return output.to(dtype=dtype, device=device)

    # Transform every box individually, then restore the original (possibly batched) shape.
    return tv_tensors.BoundingBoxes(
        torch.cat([affine_bounding_boxes(b) for b in bounding_boxes.reshape(-1, 4).unbind()], dim=0).reshape(
            bounding_boxes.shape
        ),
        format=format,
        canvas_size=canvas_size,
    )
# Turn every warning raised by the tests of this module into an error, so unexpected warnings fail the test. Tests that
# expect a warning have to capture it explicitly (e.g. with `pytest.warns`).
pytestmark = pytest.mark.filterwarnings("error")
  320. class TestResize:
  321. INPUT_SIZE = (17, 11)
  322. OUTPUT_SIZES = [17, [17], (17,), [12, 13], (12, 13)]
  323. def _make_max_size_kwarg(self, *, use_max_size, size):
  324. if use_max_size:
  325. if not (isinstance(size, int) or len(size) == 1):
  326. # This would result in an `ValueError`
  327. return None
  328. max_size = (size if isinstance(size, int) else size[0]) + 1
  329. else:
  330. max_size = None
  331. return dict(max_size=max_size)
  332. def _compute_output_size(self, *, input_size, size, max_size):
  333. if not (isinstance(size, int) or len(size) == 1):
  334. return tuple(size)
  335. if not isinstance(size, int):
  336. size = size[0]
  337. old_height, old_width = input_size
  338. ratio = old_width / old_height
  339. if ratio > 1:
  340. new_height = size
  341. new_width = int(ratio * new_height)
  342. else:
  343. new_width = size
  344. new_height = int(new_width / ratio)
  345. if max_size is not None and max(new_height, new_width) > max_size:
  346. # Need to recompute the aspect ratio, since it might have changed due to rounding
  347. ratio = new_width / new_height
  348. if ratio > 1:
  349. new_width = max_size
  350. new_height = int(new_width / ratio)
  351. else:
  352. new_height = max_size
  353. new_width = int(new_height * ratio)
  354. return new_height, new_width
  355. @pytest.mark.parametrize("size", OUTPUT_SIZES)
  356. @pytest.mark.parametrize("interpolation", INTERPOLATION_MODES)
  357. @pytest.mark.parametrize("use_max_size", [True, False])
  358. @pytest.mark.parametrize("antialias", [True, False])
  359. @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
  360. @pytest.mark.parametrize("device", cpu_and_cuda())
  361. def test_kernel_image(self, size, interpolation, use_max_size, antialias, dtype, device):
  362. if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)):
  363. return
  364. # In contrast to CPU, there is no native `InterpolationMode.BICUBIC` implementation for uint8 images on CUDA.
  365. # Internally, it uses the float path. Thus, we need to test with an enormous tolerance here to account for that.
  366. atol = 30 if transforms.InterpolationMode.BICUBIC and dtype is torch.uint8 else 1
  367. check_cuda_vs_cpu_tolerances = dict(rtol=0, atol=atol / 255 if dtype.is_floating_point else atol)
  368. check_kernel(
  369. F.resize_image,
  370. make_image(self.INPUT_SIZE, dtype=dtype, device=device),
  371. size=size,
  372. interpolation=interpolation,
  373. **max_size_kwarg,
  374. antialias=antialias,
  375. check_cuda_vs_cpu=check_cuda_vs_cpu_tolerances,
  376. check_scripted_vs_eager=not isinstance(size, int),
  377. )
  378. @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
  379. @pytest.mark.parametrize("size", OUTPUT_SIZES)
  380. @pytest.mark.parametrize("use_max_size", [True, False])
  381. @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
  382. @pytest.mark.parametrize("device", cpu_and_cuda())
  383. def test_kernel_bounding_boxes(self, format, size, use_max_size, dtype, device):
  384. if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)):
  385. return
  386. bounding_boxes = make_bounding_boxes(
  387. format=format,
  388. canvas_size=self.INPUT_SIZE,
  389. dtype=dtype,
  390. device=device,
  391. )
  392. check_kernel(
  393. F.resize_bounding_boxes,
  394. bounding_boxes,
  395. canvas_size=bounding_boxes.canvas_size,
  396. size=size,
  397. **max_size_kwarg,
  398. check_scripted_vs_eager=not isinstance(size, int),
  399. )
  400. @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_mask])
  401. def test_kernel_mask(self, make_mask):
  402. check_kernel(F.resize_mask, make_mask(self.INPUT_SIZE), size=self.OUTPUT_SIZES[-1])
  403. def test_kernel_video(self):
  404. check_kernel(F.resize_video, make_video(self.INPUT_SIZE), size=self.OUTPUT_SIZES[-1], antialias=True)
  405. @pytest.mark.parametrize("size", OUTPUT_SIZES)
  406. @pytest.mark.parametrize(
  407. "make_input",
  408. [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
  409. )
  410. def test_functional(self, size, make_input):
  411. check_functional(
  412. F.resize,
  413. make_input(self.INPUT_SIZE),
  414. size=size,
  415. antialias=True,
  416. check_scripted_smoke=not isinstance(size, int),
  417. )
  418. @pytest.mark.parametrize(
  419. ("kernel", "input_type"),
  420. [
  421. (F.resize_image, torch.Tensor),
  422. (F._resize_image_pil, PIL.Image.Image),
  423. (F.resize_image, tv_tensors.Image),
  424. (F.resize_bounding_boxes, tv_tensors.BoundingBoxes),
  425. (F.resize_mask, tv_tensors.Mask),
  426. (F.resize_video, tv_tensors.Video),
  427. ],
  428. )
  429. def test_functional_signature(self, kernel, input_type):
  430. check_functional_kernel_signature_match(F.resize, kernel=kernel, input_type=input_type)
  431. @pytest.mark.parametrize("size", OUTPUT_SIZES)
  432. @pytest.mark.parametrize("device", cpu_and_cuda())
  433. @pytest.mark.parametrize(
  434. "make_input",
  435. [
  436. make_image_tensor,
  437. make_image_pil,
  438. make_image,
  439. make_bounding_boxes,
  440. make_segmentation_mask,
  441. make_detection_mask,
  442. make_video,
  443. ],
  444. )
  445. def test_transform(self, size, device, make_input):
  446. check_transform(
  447. transforms.Resize(size=size, antialias=True),
  448. make_input(self.INPUT_SIZE, device=device),
  449. # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
  450. check_v1_compatibility=dict(rtol=0, atol=1),
  451. )
  452. def _check_output_size(self, input, output, *, size, max_size):
  453. assert tuple(F.get_size(output)) == self._compute_output_size(
  454. input_size=F.get_size(input), size=size, max_size=max_size
  455. )
  456. @pytest.mark.parametrize("size", OUTPUT_SIZES)
  457. # `InterpolationMode.NEAREST` is modeled after the buggy `INTER_NEAREST` interpolation of CV2.
  458. # The PIL equivalent of `InterpolationMode.NEAREST` is `InterpolationMode.NEAREST_EXACT`
  459. @pytest.mark.parametrize("interpolation", set(INTERPOLATION_MODES) - {transforms.InterpolationMode.NEAREST})
  460. @pytest.mark.parametrize("use_max_size", [True, False])
  461. @pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)])
  462. def test_image_correctness(self, size, interpolation, use_max_size, fn):
  463. if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)):
  464. return
  465. image = make_image(self.INPUT_SIZE, dtype=torch.uint8)
  466. actual = fn(image, size=size, interpolation=interpolation, **max_size_kwarg, antialias=True)
  467. expected = F.to_image(F.resize(F.to_pil_image(image), size=size, interpolation=interpolation, **max_size_kwarg))
  468. self._check_output_size(image, actual, size=size, **max_size_kwarg)
  469. torch.testing.assert_close(actual, expected, atol=1, rtol=0)
  470. def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=None):
  471. old_height, old_width = bounding_boxes.canvas_size
  472. new_height, new_width = self._compute_output_size(
  473. input_size=bounding_boxes.canvas_size, size=size, max_size=max_size
  474. )
  475. if (old_height, old_width) == (new_height, new_width):
  476. return bounding_boxes
  477. affine_matrix = np.array(
  478. [
  479. [new_width / old_width, 0, 0],
  480. [0, new_height / old_height, 0],
  481. ],
  482. )
  483. return reference_affine_bounding_boxes_helper(
  484. bounding_boxes,
  485. affine_matrix=affine_matrix,
  486. new_canvas_size=(new_height, new_width),
  487. )
  488. @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
  489. @pytest.mark.parametrize("size", OUTPUT_SIZES)
  490. @pytest.mark.parametrize("use_max_size", [True, False])
  491. @pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)])
  492. def test_bounding_boxes_correctness(self, format, size, use_max_size, fn):
  493. if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)):
  494. return
  495. bounding_boxes = make_bounding_boxes(format=format, canvas_size=self.INPUT_SIZE)
  496. actual = fn(bounding_boxes, size=size, **max_size_kwarg)
  497. expected = self._reference_resize_bounding_boxes(bounding_boxes, size=size, **max_size_kwarg)
  498. self._check_output_size(bounding_boxes, actual, size=size, **max_size_kwarg)
  499. torch.testing.assert_close(actual, expected)
  500. @pytest.mark.parametrize("interpolation", set(transforms.InterpolationMode) - set(INTERPOLATION_MODES))
  501. @pytest.mark.parametrize(
  502. "make_input",
  503. [make_image_tensor, make_image_pil, make_image, make_video],
  504. )
  505. def test_pil_interpolation_compat_smoke(self, interpolation, make_input):
  506. input = make_input(self.INPUT_SIZE)
  507. with (
  508. contextlib.nullcontext()
  509. if isinstance(input, PIL.Image.Image)
  510. # This error is triggered in PyTorch core
  511. else pytest.raises(NotImplementedError, match=f"got {interpolation.value.lower()}")
  512. ):
  513. F.resize(
  514. input,
  515. size=self.OUTPUT_SIZES[0],
  516. interpolation=interpolation,
  517. )
  518. def test_functional_pil_antialias_warning(self):
  519. with pytest.warns(UserWarning, match="Anti-alias option is always applied for PIL Image input"):
  520. F.resize(make_image_pil(self.INPUT_SIZE), size=self.OUTPUT_SIZES[0], antialias=False)
  521. @pytest.mark.parametrize("size", OUTPUT_SIZES)
  522. @pytest.mark.parametrize(
  523. "make_input",
  524. [
  525. make_image_tensor,
  526. make_image_pil,
  527. make_image,
  528. make_bounding_boxes,
  529. make_segmentation_mask,
  530. make_detection_mask,
  531. make_video,
  532. ],
  533. )
  534. def test_max_size_error(self, size, make_input):
  535. if isinstance(size, int) or len(size) == 1:
  536. max_size = (size if isinstance(size, int) else size[0]) - 1
  537. match = "must be strictly greater than the requested size"
  538. else:
  539. # value can be anything other than None
  540. max_size = -1
  541. match = "size should be an int or a sequence of length 1"
  542. with pytest.raises(ValueError, match=match):
  543. F.resize(make_input(self.INPUT_SIZE), size=size, max_size=max_size, antialias=True)
  544. @pytest.mark.parametrize("interpolation", INTERPOLATION_MODES)
  545. @pytest.mark.parametrize(
  546. "make_input",
  547. [make_image_tensor, make_image, make_video],
  548. )
  549. def test_antialias_warning(self, interpolation, make_input):
  550. with (
  551. assert_warns_antialias_default_value()
  552. if interpolation in {transforms.InterpolationMode.BILINEAR, transforms.InterpolationMode.BICUBIC}
  553. else assert_no_warnings()
  554. ):
  555. F.resize(
  556. make_input(self.INPUT_SIZE),
  557. size=self.OUTPUT_SIZES[0],
  558. interpolation=interpolation,
  559. )
  560. @pytest.mark.parametrize("interpolation", INTERPOLATION_MODES)
  561. @pytest.mark.parametrize(
  562. "make_input",
  563. [make_image_tensor, make_image_pil, make_image, make_video],
  564. )
  565. def test_interpolation_int(self, interpolation, make_input):
  566. input = make_input(self.INPUT_SIZE)
  567. # `InterpolationMode.NEAREST_EXACT` has no proper corresponding integer equivalent. Internally, we map it to
  568. # `0` to be the same as `InterpolationMode.NEAREST` for PIL. However, for the tensor backend there is a
  569. # difference and thus we don't test it here.
  570. if isinstance(input, torch.Tensor) and interpolation is transforms.InterpolationMode.NEAREST_EXACT:
  571. return
  572. expected = F.resize(input, size=self.OUTPUT_SIZES[0], interpolation=interpolation, antialias=True)
  573. actual = F.resize(
  574. input, size=self.OUTPUT_SIZES[0], interpolation=pil_modes_mapping[interpolation], antialias=True
  575. )
  576. assert_equal(actual, expected)
  577. def test_transform_unknown_size_error(self):
  578. with pytest.raises(ValueError, match="size can either be an integer or a list or tuple of one or two integers"):
  579. transforms.Resize(size=object())
  580. @pytest.mark.parametrize(
  581. "size", [min(INPUT_SIZE), [min(INPUT_SIZE)], (min(INPUT_SIZE),), list(INPUT_SIZE), tuple(INPUT_SIZE)]
  582. )
  583. @pytest.mark.parametrize(
  584. "make_input",
  585. [
  586. make_image_tensor,
  587. make_image_pil,
  588. make_image,
  589. make_bounding_boxes,
  590. make_segmentation_mask,
  591. make_detection_mask,
  592. make_video,
  593. ],
  594. )
  595. def test_noop(self, size, make_input):
  596. input = make_input(self.INPUT_SIZE)
  597. output = F.resize(input, size=F.get_size(input), antialias=True)
  598. # This identity check is not a requirement. It is here to avoid breaking the behavior by accident. If there
  599. # is a good reason to break this, feel free to downgrade to an equality check.
  600. if isinstance(input, tv_tensors.TVTensor):
  601. # We can't test identity directly, since that checks for the identity of the Python object. Since all
  602. # tv_tensors unwrap before a kernel and wrap again afterwards, the Python object changes. Thus, we check
  603. # that the underlying storage is the same
  604. assert output.data_ptr() == input.data_ptr()
  605. else:
  606. assert output is input
  607. @pytest.mark.parametrize(
  608. "make_input",
  609. [
  610. make_image_tensor,
  611. make_image_pil,
  612. make_image,
  613. make_bounding_boxes,
  614. make_segmentation_mask,
  615. make_detection_mask,
  616. make_video,
  617. ],
  618. )
  619. def test_no_regression_5405(self, make_input):
  620. # Checks that `max_size` is not ignored if `size == small_edge_size`
  621. # See https://github.com/pytorch/vision/issues/5405
  622. input = make_input(self.INPUT_SIZE)
  623. size = min(F.get_size(input))
  624. max_size = size + 1
  625. output = F.resize(input, size=size, max_size=max_size, antialias=True)
  626. assert max(F.get_size(output)) == max_size
  627. def _make_image(self, *args, batch_dims=(), memory_format=torch.contiguous_format, **kwargs):
  628. # torch.channels_last memory_format is only available for 4D tensors, i.e. (B, C, H, W). However, images coming
  629. # from PIL or our own I/O functions do not have a batch dimensions and are thus 3D, i.e. (C, H, W). Still, the
  630. # layout of the data in memory is channels last. To emulate this when a 3D input is requested here, we create
  631. # the image as 4D and create a view with the right shape afterwards. With this the layout in memory is channels
  632. # last although PyTorch doesn't recognizes it as such.
  633. emulate_channels_last = memory_format is torch.channels_last and len(batch_dims) != 1
  634. image = make_image(
  635. *args,
  636. batch_dims=(math.prod(batch_dims),) if emulate_channels_last else batch_dims,
  637. memory_format=memory_format,
  638. **kwargs,
  639. )
  640. if emulate_channels_last:
  641. image = tv_tensors.wrap(image.view(*batch_dims, *image.shape[-3:]), like=image)
  642. return image
  643. def _check_stride(self, image, *, memory_format):
  644. C, H, W = F.get_dimensions(image)
  645. if memory_format is torch.contiguous_format:
  646. expected_stride = (H * W, W, 1)
  647. elif memory_format is torch.channels_last:
  648. expected_stride = (1, W * C, C)
  649. else:
  650. raise ValueError(f"Unknown memory_format: {memory_format}")
  651. assert image.stride() == expected_stride
  652. # TODO: We can remove this test and related torchvision workaround
  653. # once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
  654. @pytest.mark.parametrize("interpolation", INTERPOLATION_MODES)
  655. @pytest.mark.parametrize("antialias", [True, False])
  656. @pytest.mark.parametrize("memory_format", [torch.contiguous_format, torch.channels_last])
  657. @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
  658. @pytest.mark.parametrize("device", cpu_and_cuda())
  659. def test_kernel_image_memory_format_consistency(self, interpolation, antialias, memory_format, dtype, device):
  660. size = self.OUTPUT_SIZES[0]
  661. input = self._make_image(self.INPUT_SIZE, dtype=dtype, device=device, memory_format=memory_format)
  662. # Smoke test to make sure we aren't starting with wrong assumptions
  663. self._check_stride(input, memory_format=memory_format)
  664. output = F.resize_image(input, size=size, interpolation=interpolation, antialias=antialias)
  665. self._check_stride(output, memory_format=memory_format)
  666. def test_float16_no_rounding(self):
  667. # Make sure Resize() doesn't round float16 images
  668. # Non-regression test for https://github.com/pytorch/vision/issues/7667
  669. input = make_image_tensor(self.INPUT_SIZE, dtype=torch.float16)
  670. output = F.resize_image(input, size=self.OUTPUT_SIZES[0], antialias=True)
  671. assert output.dtype is torch.float16
  672. assert (output.round() - output).abs().sum() > 0
  673. class TestHorizontalFlip:
  674. @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
  675. @pytest.mark.parametrize("device", cpu_and_cuda())
  676. def test_kernel_image(self, dtype, device):
  677. check_kernel(F.horizontal_flip_image, make_image(dtype=dtype, device=device))
  678. @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
  679. @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
  680. @pytest.mark.parametrize("device", cpu_and_cuda())
  681. def test_kernel_bounding_boxes(self, format, dtype, device):
  682. bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
  683. check_kernel(
  684. F.horizontal_flip_bounding_boxes,
  685. bounding_boxes,
  686. format=format,
  687. canvas_size=bounding_boxes.canvas_size,
  688. )
  689. @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_mask])
  690. def test_kernel_mask(self, make_mask):
  691. check_kernel(F.horizontal_flip_mask, make_mask())
  692. def test_kernel_video(self):
  693. check_kernel(F.horizontal_flip_video, make_video())
  694. @pytest.mark.parametrize(
  695. "make_input",
  696. [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
  697. )
  698. def test_functional(self, make_input):
  699. check_functional(F.horizontal_flip, make_input())
  700. @pytest.mark.parametrize(
  701. ("kernel", "input_type"),
  702. [
  703. (F.horizontal_flip_image, torch.Tensor),
  704. (F._horizontal_flip_image_pil, PIL.Image.Image),
  705. (F.horizontal_flip_image, tv_tensors.Image),
  706. (F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes),
  707. (F.horizontal_flip_mask, tv_tensors.Mask),
  708. (F.horizontal_flip_video, tv_tensors.Video),
  709. ],
  710. )
  711. def test_functional_signature(self, kernel, input_type):
  712. check_functional_kernel_signature_match(F.horizontal_flip, kernel=kernel, input_type=input_type)
  713. @pytest.mark.parametrize(
  714. "make_input",
  715. [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
  716. )
  717. @pytest.mark.parametrize("device", cpu_and_cuda())
  718. def test_transform(self, make_input, device):
  719. check_transform(transforms.RandomHorizontalFlip(p=1), make_input(device=device))
  720. @pytest.mark.parametrize(
  721. "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
  722. )
  723. def test_image_correctness(self, fn):
  724. image = make_image(dtype=torch.uint8, device="cpu")
  725. actual = fn(image)
  726. expected = F.to_image(F.horizontal_flip(F.to_pil_image(image)))
  727. torch.testing.assert_close(actual, expected)
  728. def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes):
  729. affine_matrix = np.array(
  730. [
  731. [-1, 0, bounding_boxes.canvas_size[1]],
  732. [0, 1, 0],
  733. ],
  734. )
  735. return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
  736. @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
  737. @pytest.mark.parametrize(
  738. "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
  739. )
  740. def test_bounding_boxes_correctness(self, format, fn):
  741. bounding_boxes = make_bounding_boxes(format=format)
  742. actual = fn(bounding_boxes)
  743. expected = self._reference_horizontal_flip_bounding_boxes(bounding_boxes)
  744. torch.testing.assert_close(actual, expected)
  745. @pytest.mark.parametrize(
  746. "make_input",
  747. [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
  748. )
  749. @pytest.mark.parametrize("device", cpu_and_cuda())
  750. def test_transform_noop(self, make_input, device):
  751. input = make_input(device=device)
  752. transform = transforms.RandomHorizontalFlip(p=0)
  753. output = transform(input)
  754. assert_equal(output, input)
  755. class TestAffine:
  756. _EXHAUSTIVE_TYPE_AFFINE_KWARGS = dict(
  757. # float, int
  758. angle=[-10.9, 18],
  759. # two-list of float, two-list of int, two-tuple of float, two-tuple of int
  760. translate=[[6.3, -0.6], [1, -3], (16.6, -6.6), (-2, 4)],
  761. # float
  762. scale=[0.5],
  763. # float, int,
  764. # one-list of float, one-list of int, one-tuple of float, one-tuple of int
  765. # two-list of float, two-list of int, two-tuple of float, two-tuple of int
  766. shear=[35.6, 38, [-37.7], [-23], (5.3,), (-52,), [5.4, 21.8], [-47, 51], (-11.2, 36.7), (8, -53)],
  767. # None
  768. # two-list of float, two-list of int, two-tuple of float, two-tuple of int
  769. center=[None, [1.2, 4.9], [-3, 1], (2.5, -4.7), (3, 2)],
  770. )
  771. # The special case for shear makes sure we pick a value that is supported while JIT scripting
  772. _MINIMAL_AFFINE_KWARGS = {
  773. k: vs[0] if k != "shear" else next(v for v in vs if isinstance(v, list))
  774. for k, vs in _EXHAUSTIVE_TYPE_AFFINE_KWARGS.items()
  775. }
  776. _CORRECTNESS_AFFINE_KWARGS = {
  777. k: [v for v in vs if v is None or isinstance(v, float) or (isinstance(v, list) and len(v) > 1)]
  778. for k, vs in _EXHAUSTIVE_TYPE_AFFINE_KWARGS.items()
  779. }
  780. _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES = dict(
  781. degrees=[30, (-15, 20)],
  782. translate=[None, (0.5, 0.5)],
  783. scale=[None, (0.75, 1.25)],
  784. shear=[None, (12, 30, -17, 5), 10, (-5, 12)],
  785. )
  786. _CORRECTNESS_TRANSFORM_AFFINE_RANGES = {
  787. k: next(v for v in vs if v is not None) for k, vs in _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES.items()
  788. }
  789. def _check_kernel(self, kernel, input, *args, **kwargs):
  790. kwargs_ = self._MINIMAL_AFFINE_KWARGS.copy()
  791. kwargs_.update(kwargs)
  792. check_kernel(kernel, input, *args, **kwargs_)
  793. @param_value_parametrization(
  794. angle=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["angle"],
  795. translate=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["translate"],
  796. shear=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"],
  797. center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
  798. interpolation=[transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR],
  799. fill=EXHAUSTIVE_TYPE_FILLS,
  800. )
  801. @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
  802. @pytest.mark.parametrize("device", cpu_and_cuda())
  803. def test_kernel_image(self, param, value, dtype, device):
  804. if param == "fill":
  805. value = adapt_fill(value, dtype=dtype)
  806. self._check_kernel(
  807. F.affine_image,
  808. make_image(dtype=dtype, device=device),
  809. **{param: value},
  810. check_scripted_vs_eager=not (param in {"shear", "fill"} and isinstance(value, (int, float))),
  811. check_cuda_vs_cpu=dict(atol=1, rtol=0)
  812. if dtype is torch.uint8 and param == "interpolation" and value is transforms.InterpolationMode.BILINEAR
  813. else True,
  814. )
  815. @param_value_parametrization(
  816. angle=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["angle"],
  817. translate=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["translate"],
  818. shear=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"],
  819. center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
  820. )
  821. @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
  822. @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
  823. @pytest.mark.parametrize("device", cpu_and_cuda())
  824. def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
  825. bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
  826. self._check_kernel(
  827. F.affine_bounding_boxes,
  828. bounding_boxes,
  829. format=format,
  830. canvas_size=bounding_boxes.canvas_size,
  831. **{param: value},
  832. check_scripted_vs_eager=not (param == "shear" and isinstance(value, (int, float))),
  833. )
  834. @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_mask])
  835. def test_kernel_mask(self, make_mask):
  836. self._check_kernel(F.affine_mask, make_mask())
  837. def test_kernel_video(self):
  838. self._check_kernel(F.affine_video, make_video())
  839. @pytest.mark.parametrize(
  840. "make_input",
  841. [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
  842. )
  843. def test_functional(self, make_input):
  844. check_functional(F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS)
  845. @pytest.mark.parametrize(
  846. ("kernel", "input_type"),
  847. [
  848. (F.affine_image, torch.Tensor),
  849. (F._affine_image_pil, PIL.Image.Image),
  850. (F.affine_image, tv_tensors.Image),
  851. (F.affine_bounding_boxes, tv_tensors.BoundingBoxes),
  852. (F.affine_mask, tv_tensors.Mask),
  853. (F.affine_video, tv_tensors.Video),
  854. ],
  855. )
  856. def test_functional_signature(self, kernel, input_type):
  857. check_functional_kernel_signature_match(F.affine, kernel=kernel, input_type=input_type)
  858. @pytest.mark.parametrize(
  859. "make_input",
  860. [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
  861. )
  862. @pytest.mark.parametrize("device", cpu_and_cuda())
  863. def test_transform(self, make_input, device):
  864. input = make_input(device=device)
  865. check_transform(transforms.RandomAffine(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES), input)
  866. @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
  867. @pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"])
  868. @pytest.mark.parametrize("scale", _CORRECTNESS_AFFINE_KWARGS["scale"])
  869. @pytest.mark.parametrize("shear", _CORRECTNESS_AFFINE_KWARGS["shear"])
  870. @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
  871. @pytest.mark.parametrize(
  872. "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
  873. )
  874. @pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
  875. def test_functional_image_correctness(self, angle, translate, scale, shear, center, interpolation, fill):
  876. image = make_image(dtype=torch.uint8, device="cpu")
  877. fill = adapt_fill(fill, dtype=torch.uint8)
  878. actual = F.affine(
  879. image,
  880. angle=angle,
  881. translate=translate,
  882. scale=scale,
  883. shear=shear,
  884. center=center,
  885. interpolation=interpolation,
  886. fill=fill,
  887. )
  888. expected = F.to_image(
  889. F.affine(
  890. F.to_pil_image(image),
  891. angle=angle,
  892. translate=translate,
  893. scale=scale,
  894. shear=shear,
  895. center=center,
  896. interpolation=interpolation,
  897. fill=fill,
  898. )
  899. )
  900. mae = (actual.float() - expected.float()).abs().mean()
  901. assert mae < 2 if interpolation is transforms.InterpolationMode.NEAREST else 8
  902. @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
  903. @pytest.mark.parametrize(
  904. "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
  905. )
  906. @pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
  907. @pytest.mark.parametrize("seed", list(range(5)))
  908. def test_transform_image_correctness(self, center, interpolation, fill, seed):
  909. image = make_image(dtype=torch.uint8, device="cpu")
  910. fill = adapt_fill(fill, dtype=torch.uint8)
  911. transform = transforms.RandomAffine(
  912. **self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, center=center, interpolation=interpolation, fill=fill
  913. )
  914. torch.manual_seed(seed)
  915. actual = transform(image)
  916. torch.manual_seed(seed)
  917. expected = F.to_image(transform(F.to_pil_image(image)))
  918. mae = (actual.float() - expected.float()).abs().mean()
  919. assert mae < 2 if interpolation is transforms.InterpolationMode.NEAREST else 8
  920. def _compute_affine_matrix(self, *, angle, translate, scale, shear, center):
  921. rot = math.radians(angle)
  922. cx, cy = center
  923. tx, ty = translate
  924. sx, sy = [math.radians(s) for s in ([shear, 0.0] if isinstance(shear, (int, float)) else shear)]
  925. c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
  926. t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
  927. c_matrix_inv = np.linalg.inv(c_matrix)
  928. rs_matrix = np.array(
  929. [
  930. [scale * math.cos(rot), -scale * math.sin(rot), 0],
  931. [scale * math.sin(rot), scale * math.cos(rot), 0],
  932. [0, 0, 1],
  933. ]
  934. )
  935. shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]])
  936. shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])
  937. rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix))
  938. true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv)))
  939. return true_matrix[:2, :]
  940. def _reference_affine_bounding_boxes(self, bounding_boxes, *, angle, translate, scale, shear, center):
  941. if center is None:
  942. center = [s * 0.5 for s in bounding_boxes.canvas_size[::-1]]
  943. return reference_affine_bounding_boxes_helper(
  944. bounding_boxes,
  945. affine_matrix=self._compute_affine_matrix(
  946. angle=angle, translate=translate, scale=scale, shear=shear, center=center
  947. ),
  948. )
  949. @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
  950. @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
  951. @pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"])
  952. @pytest.mark.parametrize("scale", _CORRECTNESS_AFFINE_KWARGS["scale"])
  953. @pytest.mark.parametrize("shear", _CORRECTNESS_AFFINE_KWARGS["shear"])
  954. @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
  955. def test_functional_bounding_boxes_correctness(self, format, angle, translate, scale, shear, center):
  956. bounding_boxes = make_bounding_boxes(format=format)
  957. actual = F.affine(
  958. bounding_boxes,
  959. angle=angle,
  960. translate=translate,
  961. scale=scale,
  962. shear=shear,
  963. center=center,
  964. )
  965. expected = self._reference_affine_bounding_boxes(
  966. bounding_boxes,
  967. angle=angle,
  968. translate=translate,
  969. scale=scale,
  970. shear=shear,
  971. center=center,
  972. )
  973. torch.testing.assert_close(actual, expected)
  974. @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
  975. @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
  976. @pytest.mark.parametrize("seed", list(range(5)))
  977. def test_transform_bounding_boxes_correctness(self, format, center, seed):
  978. bounding_boxes = make_bounding_boxes(format=format)
  979. transform = transforms.RandomAffine(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, center=center)
  980. torch.manual_seed(seed)
  981. params = transform._get_params([bounding_boxes])
  982. torch.manual_seed(seed)
  983. actual = transform(bounding_boxes)
  984. expected = self._reference_affine_bounding_boxes(bounding_boxes, **params, center=center)
  985. torch.testing.assert_close(actual, expected)
  986. @pytest.mark.parametrize("degrees", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["degrees"])
  987. @pytest.mark.parametrize("translate", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["translate"])
  988. @pytest.mark.parametrize("scale", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["scale"])
  989. @pytest.mark.parametrize("shear", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["shear"])
  990. @pytest.mark.parametrize("seed", list(range(10)))
  991. def test_transform_get_params_bounds(self, degrees, translate, scale, shear, seed):
  992. image = make_image()
  993. height, width = F.get_size(image)
  994. transform = transforms.RandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear)
  995. torch.manual_seed(seed)
  996. params = transform._get_params([image])
  997. if isinstance(degrees, (int, float)):
  998. assert -degrees <= params["angle"] <= degrees
  999. else:
  1000. assert degrees[0] <= params["angle"] <= degrees[1]
  1001. if translate is not None:
  1002. width_max = int(round(translate[0] * width))
  1003. height_max = int(round(translate[1] * height))
  1004. assert -width_max <= params["translate"][0] <= width_max
  1005. assert -height_max <= params["translate"][1] <= height_max
  1006. else:
  1007. assert params["translate"] == (0, 0)
  1008. if scale is not None:
  1009. assert scale[0] <= params["scale"] <= scale[1]
  1010. else:
  1011. assert params["scale"] == 1.0
  1012. if shear is not None:
  1013. if isinstance(shear, (int, float)):
  1014. assert -shear <= params["shear"][0] <= shear
  1015. assert params["shear"][1] == 0.0
  1016. elif len(shear) == 2:
  1017. assert shear[0] <= params["shear"][0] <= shear[1]
  1018. assert params["shear"][1] == 0.0
  1019. elif len(shear) == 4:
  1020. assert shear[0] <= params["shear"][0] <= shear[1]
  1021. assert shear[2] <= params["shear"][1] <= shear[3]
  1022. else:
  1023. assert params["shear"] == (0, 0)
  1024. @pytest.mark.parametrize("param", ["degrees", "translate", "scale", "shear", "center"])
  1025. @pytest.mark.parametrize("value", [0, [0], [0, 0, 0]])
  1026. def test_transform_sequence_len_errors(self, param, value):
  1027. if param in {"degrees", "shear"} and not isinstance(value, list):
  1028. return
  1029. kwargs = {param: value}
  1030. if param != "degrees":
  1031. kwargs["degrees"] = 0
  1032. with pytest.raises(
  1033. ValueError if isinstance(value, list) else TypeError, match=f"{param} should be a sequence of length 2"
  1034. ):
  1035. transforms.RandomAffine(**kwargs)
  1036. def test_transform_negative_degrees_error(self):
  1037. with pytest.raises(ValueError, match="If degrees is a single number, it must be positive"):
  1038. transforms.RandomAffine(degrees=-1)
  1039. @pytest.mark.parametrize("translate", [[-1, 0], [2, 0], [-1, 2]])
  1040. def test_transform_translate_range_error(self, translate):
  1041. with pytest.raises(ValueError, match="translation values should be between 0 and 1"):
  1042. transforms.RandomAffine(degrees=0, translate=translate)
  1043. @pytest.mark.parametrize("scale", [[-1, 0], [0, -1], [-1, -1]])
  1044. def test_transform_scale_range_error(self, scale):
  1045. with pytest.raises(ValueError, match="scale values should be positive"):
  1046. transforms.RandomAffine(degrees=0, scale=scale)
  1047. def test_transform_negative_shear_error(self):
  1048. with pytest.raises(ValueError, match="If shear is a single number, it must be positive"):
  1049. transforms.RandomAffine(degrees=0, shear=-1)
  1050. def test_transform_unknown_fill_error(self):
  1051. with pytest.raises(TypeError, match="Got inappropriate fill arg"):
  1052. transforms.RandomAffine(degrees=0, fill="fill")
  1053. class TestVerticalFlip:
  1054. @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
  1055. @pytest.mark.parametrize("device", cpu_and_cuda())
  1056. def test_kernel_image(self, dtype, device):
  1057. check_kernel(F.vertical_flip_image, make_image(dtype=dtype, device=device))
  1058. @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
  1059. @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
  1060. @pytest.mark.parametrize("device", cpu_and_cuda())
  1061. def test_kernel_bounding_boxes(self, format, dtype, device):
  1062. bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
  1063. check_kernel(
  1064. F.vertical_flip_bounding_boxes,
  1065. bounding_boxes,
  1066. format=format,
  1067. canvas_size=bounding_boxes.canvas_size,
  1068. )
  1069. @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_mask])
  1070. def test_kernel_mask(self, make_mask):
  1071. check_kernel(F.vertical_flip_mask, make_mask())
  1072. def test_kernel_video(self):
  1073. check_kernel(F.vertical_flip_video, make_video())
  1074. @pytest.mark.parametrize(
  1075. "make_input",
  1076. [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
  1077. )
  1078. def test_functional(self, make_input):
  1079. check_functional(F.vertical_flip, make_input())
  1080. @pytest.mark.parametrize(
  1081. ("kernel", "input_type"),
  1082. [
  1083. (F.vertical_flip_image, torch.Tensor),
  1084. (F._vertical_flip_image_pil, PIL.Image.Image),
  1085. (F.vertical_flip_image, tv_tensors.Image),
  1086. (F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes),
  1087. (F.vertical_flip_mask, tv_tensors.Mask),
  1088. (F.vertical_flip_video, tv_tensors.Video),
  1089. ],
  1090. )
  1091. def test_functional_signature(self, kernel, input_type):
  1092. check_functional_kernel_signature_match(F.vertical_flip, kernel=kernel, input_type=input_type)
  1093. @pytest.mark.parametrize(
  1094. "make_input",
  1095. [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
  1096. )
  1097. @pytest.mark.parametrize("device", cpu_and_cuda())
  1098. def test_transform(self, make_input, device):
  1099. check_transform(transforms.RandomVerticalFlip(p=1), make_input(device=device))
  1100. @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
  1101. def test_image_correctness(self, fn):
  1102. image = make_image(dtype=torch.uint8, device="cpu")
  1103. actual = fn(image)
  1104. expected = F.to_image(F.vertical_flip(F.to_pil_image(image)))
  1105. torch.testing.assert_close(actual, expected)
  1106. def _reference_vertical_flip_bounding_boxes(self, bounding_boxes):
  1107. affine_matrix = np.array(
  1108. [
  1109. [1, 0, 0],
  1110. [0, -1, bounding_boxes.canvas_size[0]],
  1111. ],
  1112. )
  1113. return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
  1114. @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
  1115. @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
  1116. def test_bounding_boxes_correctness(self, format, fn):
  1117. bounding_boxes = make_bounding_boxes(format=format)
  1118. actual = fn(bounding_boxes)
  1119. expected = self._reference_vertical_flip_bounding_boxes(bounding_boxes)
  1120. torch.testing.assert_close(actual, expected)
  1121. @pytest.mark.parametrize(
  1122. "make_input",
  1123. [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
  1124. )
  1125. @pytest.mark.parametrize("device", cpu_and_cuda())
  1126. def test_transform_noop(self, make_input, device):
  1127. input = make_input(device=device)
  1128. transform = transforms.RandomVerticalFlip(p=0)
  1129. output = transform(input)
  1130. assert_equal(output, input)
  1131. class TestRotate:
  1132. _EXHAUSTIVE_TYPE_AFFINE_KWARGS = dict(
  1133. # float, int
  1134. angle=[-10.9, 18],
  1135. # None
  1136. # two-list of float, two-list of int, two-tuple of float, two-tuple of int
  1137. center=[None, [1.2, 4.9], [-3, 1], (2.5, -4.7), (3, 2)],
  1138. )
  1139. _MINIMAL_AFFINE_KWARGS = {k: vs[0] for k, vs in _EXHAUSTIVE_TYPE_AFFINE_KWARGS.items()}
  1140. _CORRECTNESS_AFFINE_KWARGS = {
  1141. k: [v for v in vs if v is None or isinstance(v, float) or isinstance(v, list)]
  1142. for k, vs in _EXHAUSTIVE_TYPE_AFFINE_KWARGS.items()
  1143. }
  1144. _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES = dict(
  1145. degrees=[30, (-15, 20)],
  1146. )
  1147. _CORRECTNESS_TRANSFORM_AFFINE_RANGES = {k: vs[0] for k, vs in _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES.items()}
  1148. @param_value_parametrization(
  1149. angle=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["angle"],
  1150. interpolation=[transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR],
  1151. expand=[False, True],
  1152. center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
  1153. fill=EXHAUSTIVE_TYPE_FILLS,
  1154. )
  1155. @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
  1156. @pytest.mark.parametrize("device", cpu_and_cuda())
  1157. def test_kernel_image(self, param, value, dtype, device):
  1158. kwargs = {param: value}
  1159. if param != "angle":
  1160. kwargs["angle"] = self._MINIMAL_AFFINE_KWARGS["angle"]
  1161. check_kernel(
  1162. F.rotate_image,
  1163. make_image(dtype=dtype, device=device),
  1164. **kwargs,
  1165. check_scripted_vs_eager=not (param == "fill" and isinstance(value, (int, float))),
  1166. )
  1167. @param_value_parametrization(
  1168. angle=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["angle"],
  1169. expand=[False, True],
  1170. center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
  1171. )
  1172. @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
  1173. @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
  1174. @pytest.mark.parametrize("device", cpu_and_cuda())
  1175. def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
  1176. kwargs = {param: value}
  1177. if param != "angle":
  1178. kwargs["angle"] = self._MINIMAL_AFFINE_KWARGS["angle"]
  1179. bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
  1180. check_kernel(
  1181. F.rotate_bounding_boxes,
  1182. bounding_boxes,
  1183. format=format,
  1184. canvas_size=bounding_boxes.canvas_size,
  1185. **kwargs,
  1186. )
  1187. @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_mask])
  1188. def test_kernel_mask(self, make_mask):
  1189. check_kernel(F.rotate_mask, make_mask(), **self._MINIMAL_AFFINE_KWARGS)
  1190. def test_kernel_video(self):
  1191. check_kernel(F.rotate_video, make_video(), **self._MINIMAL_AFFINE_KWARGS)
  1192. @pytest.mark.parametrize(
  1193. "make_input",
  1194. [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
  1195. )
  1196. def test_functional(self, make_input):
  1197. check_functional(F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS)
  1198. @pytest.mark.parametrize(
  1199. ("kernel", "input_type"),
  1200. [
  1201. (F.rotate_image, torch.Tensor),
  1202. (F._rotate_image_pil, PIL.Image.Image),
  1203. (F.rotate_image, tv_tensors.Image),
  1204. (F.rotate_bounding_boxes, tv_tensors.BoundingBoxes),
  1205. (F.rotate_mask, tv_tensors.Mask),
  1206. (F.rotate_video, tv_tensors.Video),
  1207. ],
  1208. )
  1209. def test_functional_signature(self, kernel, input_type):
  1210. check_functional_kernel_signature_match(F.rotate, kernel=kernel, input_type=input_type)
  1211. @pytest.mark.parametrize(
  1212. "make_input",
  1213. [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
  1214. )
  1215. @pytest.mark.parametrize("device", cpu_and_cuda())
  1216. def test_transform(self, make_input, device):
  1217. check_transform(
  1218. transforms.RandomRotation(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES), make_input(device=device)
  1219. )
  1220. @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
  1221. @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
  1222. @pytest.mark.parametrize(
  1223. "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
  1224. )
  1225. @pytest.mark.parametrize("expand", [False, True])
  1226. @pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
  1227. def test_functional_image_correctness(self, angle, center, interpolation, expand, fill):
  1228. image = make_image(dtype=torch.uint8, device="cpu")
  1229. fill = adapt_fill(fill, dtype=torch.uint8)
  1230. actual = F.rotate(image, angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill)
  1231. expected = F.to_image(
  1232. F.rotate(
  1233. F.to_pil_image(image), angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill
  1234. )
  1235. )
  1236. mae = (actual.float() - expected.float()).abs().mean()
  1237. assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6
  1238. @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
  1239. @pytest.mark.parametrize(
  1240. "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
  1241. )
  1242. @pytest.mark.parametrize("expand", [False, True])
  1243. @pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
  1244. @pytest.mark.parametrize("seed", list(range(5)))
  1245. def test_transform_image_correctness(self, center, interpolation, expand, fill, seed):
  1246. image = make_image(dtype=torch.uint8, device="cpu")
  1247. fill = adapt_fill(fill, dtype=torch.uint8)
  1248. transform = transforms.RandomRotation(
  1249. **self._CORRECTNESS_TRANSFORM_AFFINE_RANGES,
  1250. center=center,
  1251. interpolation=interpolation,
  1252. expand=expand,
  1253. fill=fill,
  1254. )
  1255. torch.manual_seed(seed)
  1256. actual = transform(image)
  1257. torch.manual_seed(seed)
  1258. expected = F.to_image(transform(F.to_pil_image(image)))
  1259. mae = (actual.float() - expected.float()).abs().mean()
  1260. assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6
  1261. def _compute_output_canvas_size(self, *, expand, canvas_size, affine_matrix):
  1262. if not expand:
  1263. return canvas_size, (0.0, 0.0)
  1264. input_height, input_width = canvas_size
  1265. input_image_frame = np.array(
  1266. [
  1267. [0.0, 0.0, 1.0],
  1268. [0.0, input_height, 1.0],
  1269. [input_width, input_height, 1.0],
  1270. [input_width, 0.0, 1.0],
  1271. ],
  1272. dtype=np.float64,
  1273. )
  1274. output_image_frame = np.matmul(input_image_frame, affine_matrix.astype(input_image_frame.dtype).T)
  1275. recenter_x = float(np.min(output_image_frame[:, 0]))
  1276. recenter_y = float(np.min(output_image_frame[:, 1]))
  1277. output_width = int(np.max(output_image_frame[:, 0]) - recenter_x)
  1278. output_height = int(np.max(output_image_frame[:, 1]) - recenter_y)
  1279. return (output_height, output_width), (recenter_x, recenter_y)
  1280. def _recenter_bounding_boxes_after_expand(self, bounding_boxes, *, recenter_xy):
  1281. x, y = recenter_xy
  1282. if bounding_boxes.format is tv_tensors.BoundingBoxFormat.XYXY:
  1283. translate = [x, y, x, y]
  1284. else:
  1285. translate = [x, y, 0.0, 0.0]
  1286. return tv_tensors.wrap(
  1287. (bounding_boxes.to(torch.float64) - torch.tensor(translate)).to(bounding_boxes.dtype), like=bounding_boxes
  1288. )
  1289. def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, center):
  1290. if center is None:
  1291. center = [s * 0.5 for s in bounding_boxes.canvas_size[::-1]]
  1292. cx, cy = center
  1293. a = np.cos(angle * np.pi / 180.0)
  1294. b = np.sin(angle * np.pi / 180.0)
  1295. affine_matrix = np.array(
  1296. [
  1297. [a, b, cx - cx * a - b * cy],
  1298. [-b, a, cy + cx * b - a * cy],
  1299. ],
  1300. )
  1301. new_canvas_size, recenter_xy = self._compute_output_canvas_size(
  1302. expand=expand, canvas_size=bounding_boxes.canvas_size, affine_matrix=affine_matrix
  1303. )
  1304. output = reference_affine_bounding_boxes_helper(
  1305. bounding_boxes,
  1306. affine_matrix=affine_matrix,
  1307. new_canvas_size=new_canvas_size,
  1308. clamp=False,
  1309. )
  1310. return F.clamp_bounding_boxes(self._recenter_bounding_boxes_after_expand(output, recenter_xy=recenter_xy)).to(
  1311. bounding_boxes
  1312. )
  1313. @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
  1314. @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
  1315. @pytest.mark.parametrize("expand", [False, True])
  1316. @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
  1317. def test_functional_bounding_boxes_correctness(self, format, angle, expand, center):
  1318. bounding_boxes = make_bounding_boxes(format=format)
  1319. actual = F.rotate(bounding_boxes, angle=angle, expand=expand, center=center)
  1320. expected = self._reference_rotate_bounding_boxes(bounding_boxes, angle=angle, expand=expand, center=center)
  1321. torch.testing.assert_close(actual, expected)
  1322. torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)
  1323. @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
  1324. @pytest.mark.parametrize("expand", [False, True])
  1325. @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
  1326. @pytest.mark.parametrize("seed", list(range(5)))
  1327. def test_transform_bounding_boxes_correctness(self, format, expand, center, seed):
  1328. bounding_boxes = make_bounding_boxes(format=format)
  1329. transform = transforms.RandomRotation(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, expand=expand, center=center)
  1330. torch.manual_seed(seed)
  1331. params = transform._get_params([bounding_boxes])
  1332. torch.manual_seed(seed)
  1333. actual = transform(bounding_boxes)
  1334. expected = self._reference_rotate_bounding_boxes(bounding_boxes, **params, expand=expand, center=center)
  1335. torch.testing.assert_close(actual, expected)
  1336. torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)
  1337. @pytest.mark.parametrize("degrees", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["degrees"])
  1338. @pytest.mark.parametrize("seed", list(range(10)))
  1339. def test_transform_get_params_bounds(self, degrees, seed):
  1340. transform = transforms.RandomRotation(degrees=degrees)
  1341. torch.manual_seed(seed)
  1342. params = transform._get_params([])
  1343. if isinstance(degrees, (int, float)):
  1344. assert -degrees <= params["angle"] <= degrees
  1345. else:
  1346. assert degrees[0] <= params["angle"] <= degrees[1]
  1347. @pytest.mark.parametrize("param", ["degrees", "center"])
  1348. @pytest.mark.parametrize("value", [0, [0], [0, 0, 0]])
  1349. def test_transform_sequence_len_errors(self, param, value):
  1350. if param == "degrees" and not isinstance(value, list):
  1351. return
  1352. kwargs = {param: value}
  1353. if param != "degrees":
  1354. kwargs["degrees"] = 0
  1355. with pytest.raises(
  1356. ValueError if isinstance(value, list) else TypeError, match=f"{param} should be a sequence of length 2"
  1357. ):
  1358. transforms.RandomRotation(**kwargs)
  1359. def test_transform_negative_degrees_error(self):
  1360. with pytest.raises(ValueError, match="If degrees is a single number, it must be positive"):
  1361. transforms.RandomAffine(degrees=-1)
  1362. def test_transform_unknown_fill_error(self):
  1363. with pytest.raises(TypeError, match="Got inappropriate fill arg"):
  1364. transforms.RandomAffine(degrees=0, fill="fill")
  1365. class TestCompose:
  1366. class BuiltinTransform(transforms.Transform):
  1367. def _transform(self, inpt, params):
  1368. return inpt
  1369. class PackedInputTransform(nn.Module):
  1370. def forward(self, sample):
  1371. assert len(sample) == 2
  1372. return sample
  1373. class UnpackedInputTransform(nn.Module):
  1374. def forward(self, image, label):
  1375. return image, label
  1376. @pytest.mark.parametrize(
  1377. "transform_clss",
  1378. [
  1379. [BuiltinTransform],
  1380. [PackedInputTransform],
  1381. [UnpackedInputTransform],
  1382. [BuiltinTransform, BuiltinTransform],
  1383. [PackedInputTransform, PackedInputTransform],
  1384. [UnpackedInputTransform, UnpackedInputTransform],
  1385. [BuiltinTransform, PackedInputTransform, BuiltinTransform],
  1386. [BuiltinTransform, UnpackedInputTransform, BuiltinTransform],
  1387. [PackedInputTransform, BuiltinTransform, PackedInputTransform],
  1388. [UnpackedInputTransform, BuiltinTransform, UnpackedInputTransform],
  1389. ],
  1390. )
  1391. @pytest.mark.parametrize("unpack", [True, False])
  1392. def test_packed_unpacked(self, transform_clss, unpack):
  1393. needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss)
  1394. needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss)
  1395. assert not (needs_packed_inputs and needs_unpacked_inputs)
  1396. transform = transforms.Compose([cls() for cls in transform_clss])
  1397. image = make_image()
  1398. label = 3
  1399. packed_input = (image, label)
  1400. def call_transform():
  1401. if unpack:
  1402. return transform(*packed_input)
  1403. else:
  1404. return transform(packed_input)
  1405. if needs_unpacked_inputs and not unpack:
  1406. with pytest.raises(TypeError, match="missing 1 required positional argument"):
  1407. call_transform()
  1408. elif needs_packed_inputs and unpack:
  1409. with pytest.raises(TypeError, match="takes 2 positional arguments but 3 were given"):
  1410. call_transform()
  1411. else:
  1412. output = call_transform()
  1413. assert isinstance(output, tuple) and len(output) == 2
  1414. assert output[0] is image
  1415. assert output[1] is label
  1416. class TestToDtype:
  1417. @pytest.mark.parametrize(
  1418. ("kernel", "make_input"),
  1419. [
  1420. (F.to_dtype_image, make_image_tensor),
  1421. (F.to_dtype_image, make_image),
  1422. (F.to_dtype_video, make_video),
  1423. ],
  1424. )
  1425. @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
  1426. @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
  1427. @pytest.mark.parametrize("device", cpu_and_cuda())
  1428. @pytest.mark.parametrize("scale", (True, False))
  1429. def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, scale):
  1430. check_kernel(
  1431. kernel,
  1432. make_input(dtype=input_dtype, device=device),
  1433. expect_same_dtype=input_dtype is output_dtype,
  1434. dtype=output_dtype,
  1435. scale=scale,
  1436. )
  1437. @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
  1438. @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
  1439. @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
  1440. @pytest.mark.parametrize("device", cpu_and_cuda())
  1441. @pytest.mark.parametrize("scale", (True, False))
  1442. def test_functional(self, make_input, input_dtype, output_dtype, device, scale):
  1443. check_functional(
  1444. F.to_dtype,
  1445. make_input(dtype=input_dtype, device=device),
  1446. dtype=output_dtype,
  1447. scale=scale,
  1448. )
  1449. @pytest.mark.parametrize(
  1450. "make_input",
  1451. [make_image_tensor, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
  1452. )
  1453. @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
  1454. @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
  1455. @pytest.mark.parametrize("device", cpu_and_cuda())
  1456. @pytest.mark.parametrize("scale", (True, False))
  1457. @pytest.mark.parametrize("as_dict", (True, False))
  1458. def test_transform(self, make_input, input_dtype, output_dtype, device, scale, as_dict):
  1459. input = make_input(dtype=input_dtype, device=device)
  1460. if as_dict:
  1461. output_dtype = {type(input): output_dtype}
  1462. check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), input)
  1463. def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False):
  1464. input_dtype = image.dtype
  1465. output_dtype = dtype
  1466. if not scale:
  1467. return image.to(dtype)
  1468. if output_dtype == input_dtype:
  1469. return image
  1470. def fn(value):
  1471. if input_dtype.is_floating_point:
  1472. if output_dtype.is_floating_point:
  1473. return value
  1474. else:
  1475. return round(decimal.Decimal(value) * torch.iinfo(output_dtype).max)
  1476. else:
  1477. input_max_value = torch.iinfo(input_dtype).max
  1478. if output_dtype.is_floating_point:
  1479. return float(decimal.Decimal(value) / input_max_value)
  1480. else:
  1481. output_max_value = torch.iinfo(output_dtype).max
  1482. if input_max_value > output_max_value:
  1483. factor = (input_max_value + 1) // (output_max_value + 1)
  1484. return value / factor
  1485. else:
  1486. factor = (output_max_value + 1) // (input_max_value + 1)
  1487. return value * factor
  1488. return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype, device=image.device)
  1489. @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
  1490. @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
  1491. @pytest.mark.parametrize("device", cpu_and_cuda())
  1492. @pytest.mark.parametrize("scale", (True, False))
  1493. def test_image_correctness(self, input_dtype, output_dtype, device, scale):
  1494. if input_dtype.is_floating_point and output_dtype == torch.int64:
  1495. pytest.xfail("float to int64 conversion is not supported")
  1496. input = make_image(dtype=input_dtype, device=device)
  1497. out = F.to_dtype(input, dtype=output_dtype, scale=scale)
  1498. expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale)
  1499. if input_dtype.is_floating_point and not output_dtype.is_floating_point and scale:
  1500. torch.testing.assert_close(out, expected, atol=1, rtol=0)
  1501. else:
  1502. torch.testing.assert_close(out, expected)
  1503. def was_scaled(self, inpt):
  1504. # this assumes the target dtype is float
  1505. return inpt.max() <= 1
  1506. def make_inpt_with_bbox_and_mask(self, make_input):
  1507. H, W = 10, 10
  1508. inpt_dtype = torch.uint8
  1509. bbox_dtype = torch.float32
  1510. mask_dtype = torch.bool
  1511. sample = {
  1512. "inpt": make_input(size=(H, W), dtype=inpt_dtype),
  1513. "bbox": make_bounding_boxes(canvas_size=(H, W), dtype=bbox_dtype),
  1514. "mask": make_detection_mask(size=(H, W), dtype=mask_dtype),
  1515. }
  1516. return sample, inpt_dtype, bbox_dtype, mask_dtype
  1517. @pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video))
  1518. @pytest.mark.parametrize("scale", (True, False))
  1519. def test_dtype_not_a_dict(self, make_input, scale):
  1520. # assert only inpt gets transformed when dtype isn't a dict
  1521. sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)
  1522. out = transforms.ToDtype(dtype=torch.float32, scale=scale)(sample)
  1523. assert out["inpt"].dtype != inpt_dtype
  1524. assert out["inpt"].dtype == torch.float32
  1525. if scale:
  1526. assert self.was_scaled(out["inpt"])
  1527. else:
  1528. assert not self.was_scaled(out["inpt"])
  1529. assert out["bbox"].dtype == bbox_dtype
  1530. assert out["mask"].dtype == mask_dtype
  1531. @pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video))
  1532. def test_others_catch_all_and_none(self, make_input):
  1533. # make sure "others" works as a catch-all and that None means no conversion
  1534. sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)
  1535. out = transforms.ToDtype(dtype={tv_tensors.Mask: torch.int64, "others": None})(sample)
  1536. assert out["inpt"].dtype == inpt_dtype
  1537. assert out["bbox"].dtype == bbox_dtype
  1538. assert out["mask"].dtype != mask_dtype
  1539. assert out["mask"].dtype == torch.int64
  1540. @pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video))
  1541. def test_typical_use_case(self, make_input):
  1542. # Typical use-case: want to convert dtype and scale for inpt and just dtype for masks.
  1543. # This just makes sure we now have a decent API for this
  1544. sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)
  1545. out = transforms.ToDtype(
  1546. dtype={type(sample["inpt"]): torch.float32, tv_tensors.Mask: torch.int64, "others": None}, scale=True
  1547. )(sample)
  1548. assert out["inpt"].dtype != inpt_dtype
  1549. assert out["inpt"].dtype == torch.float32
  1550. assert self.was_scaled(out["inpt"])
  1551. assert out["bbox"].dtype == bbox_dtype
  1552. assert out["mask"].dtype != mask_dtype
  1553. assert out["mask"].dtype == torch.int64
  1554. @pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video))
  1555. def test_errors_warnings(self, make_input):
  1556. sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)
  1557. with pytest.raises(ValueError, match="No dtype was specified for"):
  1558. out = transforms.ToDtype(dtype={tv_tensors.Mask: torch.float32})(sample)
  1559. with pytest.warns(UserWarning, match=re.escape("plain `torch.Tensor` will *not* be transformed")):
  1560. transforms.ToDtype(dtype={torch.Tensor: torch.float32, tv_tensors.Image: torch.float32})
  1561. with pytest.warns(UserWarning, match="no scaling will be done"):
  1562. out = transforms.ToDtype(dtype={"others": None}, scale=True)(sample)
  1563. assert out["inpt"].dtype == inpt_dtype
  1564. assert out["bbox"].dtype == bbox_dtype
  1565. assert out["mask"].dtype == mask_dtype
  1566. class TestAdjustBrightness:
  1567. _CORRECTNESS_BRIGHTNESS_FACTORS = [0.5, 0.0, 1.0, 5.0]
  1568. _DEFAULT_BRIGHTNESS_FACTOR = _CORRECTNESS_BRIGHTNESS_FACTORS[0]
  1569. @pytest.mark.parametrize(
  1570. ("kernel", "make_input"),
  1571. [
  1572. (F.adjust_brightness_image, make_image),
  1573. (F.adjust_brightness_video, make_video),
  1574. ],
  1575. )
  1576. @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
  1577. @pytest.mark.parametrize("device", cpu_and_cuda())
  1578. def test_kernel(self, kernel, make_input, dtype, device):
  1579. check_kernel(kernel, make_input(dtype=dtype, device=device), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)
  1580. @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
  1581. def test_functional(self, make_input):
  1582. check_functional(F.adjust_brightness, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)
  1583. @pytest.mark.parametrize(
  1584. ("kernel", "input_type"),
  1585. [
  1586. (F.adjust_brightness_image, torch.Tensor),
  1587. (F._adjust_brightness_image_pil, PIL.Image.Image),
  1588. (F.adjust_brightness_image, tv_tensors.Image),
  1589. (F.adjust_brightness_video, tv_tensors.Video),
  1590. ],
  1591. )
  1592. def test_functional_signature(self, kernel, input_type):
  1593. check_functional_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type)
  1594. @pytest.mark.parametrize("brightness_factor", _CORRECTNESS_BRIGHTNESS_FACTORS)
  1595. def test_image_correctness(self, brightness_factor):
  1596. image = make_image(dtype=torch.uint8, device="cpu")
  1597. actual = F.adjust_brightness(image, brightness_factor=brightness_factor)
  1598. expected = F.to_image(F.adjust_brightness(F.to_pil_image(image), brightness_factor=brightness_factor))
  1599. torch.testing.assert_close(actual, expected)
  1600. class TestCutMixMixUp:
  1601. class DummyDataset:
  1602. def __init__(self, size, num_classes):
  1603. self.size = size
  1604. self.num_classes = num_classes
  1605. assert size < num_classes
  1606. def __getitem__(self, idx):
  1607. img = torch.rand(3, 100, 100)
  1608. label = idx # This ensures all labels in a batch are unique and makes testing easier
  1609. return img, label
  1610. def __len__(self):
  1611. return self.size
  1612. @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
  1613. def test_supported_input_structure(self, T):
  1614. batch_size = 32
  1615. num_classes = 100
  1616. dataset = self.DummyDataset(size=batch_size, num_classes=num_classes)
  1617. cutmix_mixup = T(num_classes=num_classes)
  1618. dl = DataLoader(dataset, batch_size=batch_size)
  1619. # Input sanity checks
  1620. img, target = next(iter(dl))
  1621. input_img_size = img.shape[-3:]
  1622. assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor)
  1623. assert target.shape == (batch_size,)
  1624. def check_output(img, target):
  1625. assert img.shape == (batch_size, *input_img_size)
  1626. assert target.shape == (batch_size, num_classes)
  1627. torch.testing.assert_close(target.sum(axis=-1), torch.ones(batch_size))
  1628. num_non_zero_labels = (target != 0).sum(axis=-1)
  1629. assert (num_non_zero_labels == 2).all()
  1630. # After Dataloader, as unpacked input
  1631. img, target = next(iter(dl))
  1632. assert target.shape == (batch_size,)
  1633. img, target = cutmix_mixup(img, target)
  1634. check_output(img, target)
  1635. # After Dataloader, as packed input
  1636. packed_from_dl = next(iter(dl))
  1637. assert isinstance(packed_from_dl, list)
  1638. img, target = cutmix_mixup(packed_from_dl)
  1639. check_output(img, target)
  1640. # As collation function. We expect default_collate to be used by users.
  1641. def collate_fn_1(batch):
  1642. return cutmix_mixup(default_collate(batch))
  1643. def collate_fn_2(batch):
  1644. return cutmix_mixup(*default_collate(batch))
  1645. for collate_fn in (collate_fn_1, collate_fn_2):
  1646. dl = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
  1647. img, target = next(iter(dl))
  1648. check_output(img, target)
  1649. @needs_cuda
  1650. @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
  1651. def test_cpu_vs_gpu(self, T):
  1652. num_classes = 10
  1653. batch_size = 3
  1654. H, W = 12, 12
  1655. imgs = torch.rand(batch_size, 3, H, W)
  1656. labels = torch.randint(0, num_classes, (batch_size,))
  1657. cutmix_mixup = T(alpha=0.5, num_classes=num_classes)
  1658. _check_kernel_cuda_vs_cpu(cutmix_mixup, imgs, labels, rtol=None, atol=None)
  1659. @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
  1660. def test_error(self, T):
  1661. num_classes = 10
  1662. batch_size = 9
  1663. imgs = torch.rand(batch_size, 3, 12, 12)
  1664. cutmix_mixup = T(alpha=0.5, num_classes=num_classes)
  1665. for input_with_bad_type in (
  1666. F.to_pil_image(imgs[0]),
  1667. tv_tensors.Mask(torch.rand(12, 12)),
  1668. tv_tensors.BoundingBoxes(torch.rand(2, 4), format="XYXY", canvas_size=12),
  1669. ):
  1670. with pytest.raises(ValueError, match="does not support PIL images, "):
  1671. cutmix_mixup(input_with_bad_type)
  1672. with pytest.raises(ValueError, match="Could not infer where the labels are"):
  1673. cutmix_mixup({"img": imgs, "Nothing_else": 3})
  1674. with pytest.raises(ValueError, match="labels tensor should be of shape"):
  1675. # Note: the error message isn't ideal, but that's because the label heuristic found the img as the label
  1676. # It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently
  1677. cutmix_mixup(imgs)
  1678. with pytest.raises(ValueError, match="When using the default labels_getter"):
  1679. cutmix_mixup(imgs, "not_a_tensor")
  1680. with pytest.raises(ValueError, match="labels tensor should be of shape"):
  1681. cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3)))
  1682. with pytest.raises(ValueError, match="Expected a batched input with 4 dims"):
  1683. cutmix_mixup(imgs[None, None], torch.randint(0, num_classes, size=(batch_size,)))
  1684. with pytest.raises(ValueError, match="does not match the batch size of the labels"):
  1685. cutmix_mixup(imgs, torch.randint(0, num_classes, size=(batch_size + 1,)))
  1686. with pytest.raises(ValueError, match="labels tensor should be of shape"):
  1687. # The purpose of this check is more about documenting the current
  1688. # behaviour of what happens on a Compose(), rather than actually
  1689. # asserting the expected behaviour. We may support Compose() in the
  1690. # future, e.g. for 2 consecutive CutMix?
  1691. labels = torch.randint(0, num_classes, size=(batch_size,))
  1692. transforms.Compose([cutmix_mixup, cutmix_mixup])(imgs, labels)
  1693. @pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
  1694. @pytest.mark.parametrize("sample_type", (tuple, list, dict))
  1695. def test_labels_getter_default_heuristic(key, sample_type):
  1696. labels = torch.arange(10)
  1697. sample = {key: labels, "another_key": "whatever"}
  1698. if sample_type is not dict:
  1699. sample = sample_type((None, sample, "whatever_again"))
  1700. assert transforms._utils._find_labels_default_heuristic(sample) is labels
  1701. if key.lower() != "labels":
  1702. # If "labels" is in the dict (case-insensitive),
  1703. # it takes precedence over other keys which would otherwise be a match
  1704. d = {key: "something_else", "labels": labels}
  1705. assert transforms._utils._find_labels_default_heuristic(d) is labels
  1706. class TestShapeGetters:
  1707. @pytest.mark.parametrize(
  1708. ("kernel", "make_input"),
  1709. [
  1710. (F.get_dimensions_image, make_image_tensor),
  1711. (F._get_dimensions_image_pil, make_image_pil),
  1712. (F.get_dimensions_image, make_image),
  1713. (F.get_dimensions_video, make_video),
  1714. ],
  1715. )
  1716. def test_get_dimensions(self, kernel, make_input):
  1717. size = (10, 10)
  1718. color_space, num_channels = "RGB", 3
  1719. input = make_input(size, color_space=color_space)
  1720. assert kernel(input) == F.get_dimensions(input) == [num_channels, *size]
  1721. @pytest.mark.parametrize(
  1722. ("kernel", "make_input"),
  1723. [
  1724. (F.get_num_channels_image, make_image_tensor),
  1725. (F._get_num_channels_image_pil, make_image_pil),
  1726. (F.get_num_channels_image, make_image),
  1727. (F.get_num_channels_video, make_video),
  1728. ],
  1729. )
  1730. def test_get_num_channels(self, kernel, make_input):
  1731. color_space, num_channels = "RGB", 3
  1732. input = make_input(color_space=color_space)
  1733. assert kernel(input) == F.get_num_channels(input) == num_channels
  1734. @pytest.mark.parametrize(
  1735. ("kernel", "make_input"),
  1736. [
  1737. (F.get_size_image, make_image_tensor),
  1738. (F._get_size_image_pil, make_image_pil),
  1739. (F.get_size_image, make_image),
  1740. (F.get_size_bounding_boxes, make_bounding_boxes),
  1741. (F.get_size_mask, make_detection_mask),
  1742. (F.get_size_mask, make_segmentation_mask),
  1743. (F.get_size_video, make_video),
  1744. ],
  1745. )
  1746. def test_get_size(self, kernel, make_input):
  1747. size = (10, 10)
  1748. input = make_input(size)
  1749. assert kernel(input) == F.get_size(input) == list(size)
  1750. @pytest.mark.parametrize(
  1751. ("kernel", "make_input"),
  1752. [
  1753. (F.get_num_frames_video, make_video_tensor),
  1754. (F.get_num_frames_video, make_video),
  1755. ],
  1756. )
  1757. def test_get_num_frames(self, kernel, make_input):
  1758. num_frames = 4
  1759. input = make_input(num_frames=num_frames)
  1760. assert kernel(input) == F.get_num_frames(input) == num_frames
  1761. @pytest.mark.parametrize(
  1762. ("functional", "make_input"),
  1763. [
  1764. (F.get_dimensions, make_bounding_boxes),
  1765. (F.get_dimensions, make_detection_mask),
  1766. (F.get_dimensions, make_segmentation_mask),
  1767. (F.get_num_channels, make_bounding_boxes),
  1768. (F.get_num_channels, make_detection_mask),
  1769. (F.get_num_channels, make_segmentation_mask),
  1770. (F.get_num_frames, make_image_pil),
  1771. (F.get_num_frames, make_image),
  1772. (F.get_num_frames, make_bounding_boxes),
  1773. (F.get_num_frames, make_detection_mask),
  1774. (F.get_num_frames, make_segmentation_mask),
  1775. ],
  1776. )
  1777. def test_unsupported_types(self, functional, make_input):
  1778. input = make_input()
  1779. with pytest.raises(TypeError, match=re.escape(str(type(input)))):
  1780. functional(input)
  1781. class TestRegisterKernel:
  1782. @pytest.mark.parametrize("functional", (F.resize, "resize"))
  1783. def test_register_kernel(self, functional):
  1784. class CustomTVTensor(tv_tensors.TVTensor):
  1785. pass
  1786. kernel_was_called = False
  1787. @F.register_kernel(functional, CustomTVTensor)
  1788. def new_resize(dp, *args, **kwargs):
  1789. nonlocal kernel_was_called
  1790. kernel_was_called = True
  1791. return dp
  1792. t = transforms.Resize(size=(224, 224), antialias=True)
  1793. my_dp = CustomTVTensor(torch.rand(3, 10, 10))
  1794. out = t(my_dp)
  1795. assert out is my_dp
  1796. assert kernel_was_called
  1797. # Sanity check to make sure we didn't override the kernel of other types
  1798. t(torch.rand(3, 10, 10)).shape == (3, 224, 224)
  1799. t(tv_tensors.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224)
  1800. def test_errors(self):
  1801. with pytest.raises(ValueError, match="Could not find functional with name"):
  1802. F.register_kernel("bad_name", tv_tensors.Image)
  1803. with pytest.raises(ValueError, match="Kernels can only be registered on functionals"):
  1804. F.register_kernel(tv_tensors.Image, F.resize)
  1805. with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"):
  1806. F.register_kernel(F.resize, object)
  1807. with pytest.raises(ValueError, match="cannot be registered for the builtin tv_tensor classes"):
  1808. F.register_kernel(F.resize, tv_tensors.Image)(F.resize_image)
  1809. class CustomTVTensor(tv_tensors.TVTensor):
  1810. pass
  1811. def resize_custom_tv_tensor():
  1812. pass
  1813. F.register_kernel(F.resize, CustomTVTensor)(resize_custom_tv_tensor)
  1814. with pytest.raises(ValueError, match="already has a kernel registered for type"):
  1815. F.register_kernel(F.resize, CustomTVTensor)(resize_custom_tv_tensor)
  1816. class TestGetKernel:
  1817. # We are using F.resize as functional and the kernels below as proxy. Any other functional / kernels combination
  1818. # would also be fine
  1819. KERNELS = {
  1820. torch.Tensor: F.resize_image,
  1821. PIL.Image.Image: F._resize_image_pil,
  1822. tv_tensors.Image: F.resize_image,
  1823. tv_tensors.BoundingBoxes: F.resize_bounding_boxes,
  1824. tv_tensors.Mask: F.resize_mask,
  1825. tv_tensors.Video: F.resize_video,
  1826. }
  1827. @pytest.mark.parametrize("input_type", [str, int, object])
  1828. def test_unsupported_types(self, input_type):
  1829. with pytest.raises(TypeError, match="supports inputs of type"):
  1830. _get_kernel(F.resize, input_type)
  1831. def test_exact_match(self):
  1832. # We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the
  1833. # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize functional
  1834. # here, register the kernels without wrapper, and check the exact matching afterwards.
  1835. def resize_with_pure_kernels():
  1836. pass
  1837. for input_type, kernel in self.KERNELS.items():
  1838. _register_kernel_internal(resize_with_pure_kernels, input_type, tv_tensor_wrapper=False)(kernel)
  1839. assert _get_kernel(resize_with_pure_kernels, input_type) is kernel
  1840. def test_builtin_tv_tensor_subclass(self):
  1841. # We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the
  1842. # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize functional
  1843. # here, register the kernels without wrapper, and check if subclasses of our builtin tv_tensors get dispatched
  1844. # to the kernel of the corresponding superclass
  1845. def resize_with_pure_kernels():
  1846. pass
  1847. class MyImage(tv_tensors.Image):
  1848. pass
  1849. class MyBoundingBoxes(tv_tensors.BoundingBoxes):
  1850. pass
  1851. class MyMask(tv_tensors.Mask):
  1852. pass
  1853. class MyVideo(tv_tensors.Video):
  1854. pass
  1855. for custom_tv_tensor_subclass in [
  1856. MyImage,
  1857. MyBoundingBoxes,
  1858. MyMask,
  1859. MyVideo,
  1860. ]:
  1861. builtin_tv_tensor_class = custom_tv_tensor_subclass.__mro__[1]
  1862. builtin_tv_tensor_kernel = self.KERNELS[builtin_tv_tensor_class]
  1863. _register_kernel_internal(resize_with_pure_kernels, builtin_tv_tensor_class, tv_tensor_wrapper=False)(
  1864. builtin_tv_tensor_kernel
  1865. )
  1866. assert _get_kernel(resize_with_pure_kernels, custom_tv_tensor_subclass) is builtin_tv_tensor_kernel
  1867. def test_tv_tensor_subclass(self):
  1868. class MyTVTensor(tv_tensors.TVTensor):
  1869. pass
  1870. with pytest.raises(TypeError, match="supports inputs of type"):
  1871. _get_kernel(F.resize, MyTVTensor)
  1872. def resize_my_tv_tensor():
  1873. pass
  1874. _register_kernel_internal(F.resize, MyTVTensor, tv_tensor_wrapper=False)(resize_my_tv_tensor)
  1875. assert _get_kernel(F.resize, MyTVTensor) is resize_my_tv_tensor
  1876. def test_pil_image_subclass(self):
  1877. opened_image = PIL.Image.open(Path(__file__).parent / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg")
  1878. loaded_image = opened_image.convert("RGB")
  1879. # check the assumptions
  1880. assert isinstance(opened_image, PIL.Image.Image)
  1881. assert type(opened_image) is not PIL.Image.Image
  1882. assert type(loaded_image) is PIL.Image.Image
  1883. size = [17, 11]
  1884. for image in [opened_image, loaded_image]:
  1885. kernel = _get_kernel(F.resize, type(image))
  1886. output = kernel(image, size=size)
  1887. assert F.get_size(output) == size
  1888. class TestPermuteChannels:
  1889. _DEFAULT_PERMUTATION = [2, 0, 1]
  1890. @pytest.mark.parametrize(
  1891. ("kernel", "make_input"),
  1892. [
  1893. (F.permute_channels_image, make_image_tensor),
  1894. # FIXME
  1895. # check_kernel does not support PIL kernel, but it should
  1896. (F.permute_channels_image, make_image),
  1897. (F.permute_channels_video, make_video),
  1898. ],
  1899. )
  1900. @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
  1901. @pytest.mark.parametrize("device", cpu_and_cuda())
  1902. def test_kernel(self, kernel, make_input, dtype, device):
  1903. check_kernel(kernel, make_input(dtype=dtype, device=device), permutation=self._DEFAULT_PERMUTATION)
  1904. @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
  1905. def test_functional(self, make_input):
  1906. check_functional(F.permute_channels, make_input(), permutation=self._DEFAULT_PERMUTATION)
  1907. @pytest.mark.parametrize(
  1908. ("kernel", "input_type"),
  1909. [
  1910. (F.permute_channels_image, torch.Tensor),
  1911. (F._permute_channels_image_pil, PIL.Image.Image),
  1912. (F.permute_channels_image, tv_tensors.Image),
  1913. (F.permute_channels_video, tv_tensors.Video),
  1914. ],
  1915. )
  1916. def test_functional_signature(self, kernel, input_type):
  1917. check_functional_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type)
  1918. def reference_image_correctness(self, image, permutation):
  1919. channel_images = image.split(1, dim=-3)
  1920. permuted_channel_images = [channel_images[channel_idx] for channel_idx in permutation]
  1921. return tv_tensors.Image(torch.concat(permuted_channel_images, dim=-3))
  1922. @pytest.mark.parametrize("permutation", [[2, 0, 1], [1, 2, 0], [2, 0, 1], [0, 1, 2]])
  1923. @pytest.mark.parametrize("batch_dims", [(), (2,), (2, 1)])
  1924. def test_image_correctness(self, permutation, batch_dims):
  1925. image = make_image(batch_dims=batch_dims)
  1926. actual = F.permute_channels(image, permutation=permutation)
  1927. expected = self.reference_image_correctness(image, permutation=permutation)
  1928. torch.testing.assert_close(actual, expected)
  1929. class TestElastic:
  1930. def _make_displacement(self, inpt):
  1931. return torch.rand(
  1932. 1,
  1933. *F.get_size(inpt),
  1934. 2,
  1935. dtype=torch.float32,
  1936. device=inpt.device if isinstance(inpt, torch.Tensor) else "cpu",
  1937. )
  1938. @param_value_parametrization(
  1939. interpolation=[transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR],
  1940. fill=EXHAUSTIVE_TYPE_FILLS,
  1941. )
  1942. @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
  1943. @pytest.mark.parametrize("device", cpu_and_cuda())
  1944. def test_kernel_image(self, param, value, dtype, device):
  1945. image = make_image_tensor(dtype=dtype, device=device)
  1946. check_kernel(
  1947. F.elastic_image,
  1948. image,
  1949. displacement=self._make_displacement(image),
  1950. **{param: value},
  1951. check_scripted_vs_eager=not (param == "fill" and isinstance(value, (int, float))),
  1952. )
  1953. @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
  1954. @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
  1955. @pytest.mark.parametrize("device", cpu_and_cuda())
  1956. def test_kernel_bounding_boxes(self, format, dtype, device):
  1957. bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
  1958. check_kernel(
  1959. F.elastic_bounding_boxes,
  1960. bounding_boxes,
  1961. format=bounding_boxes.format,
  1962. canvas_size=bounding_boxes.canvas_size,
  1963. displacement=self._make_displacement(bounding_boxes),
  1964. )
  1965. @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_mask])
  1966. def test_kernel_mask(self, make_mask):
  1967. mask = make_mask()
  1968. check_kernel(F.elastic_mask, mask, displacement=self._make_displacement(mask))
  1969. def test_kernel_video(self):
  1970. video = make_video()
  1971. check_kernel(F.elastic_video, video, displacement=self._make_displacement(video))
  1972. @pytest.mark.parametrize(
  1973. "make_input",
  1974. [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
  1975. )
  1976. def test_functional(self, make_input):
  1977. input = make_input()
  1978. check_functional(F.elastic, input, displacement=self._make_displacement(input))
  1979. @pytest.mark.parametrize(
  1980. ("kernel", "input_type"),
  1981. [
  1982. (F.elastic_image, torch.Tensor),
  1983. (F._elastic_image_pil, PIL.Image.Image),
  1984. (F.elastic_image, tv_tensors.Image),
  1985. (F.elastic_bounding_boxes, tv_tensors.BoundingBoxes),
  1986. (F.elastic_mask, tv_tensors.Mask),
  1987. (F.elastic_video, tv_tensors.Video),
  1988. ],
  1989. )
  1990. def test_functional_signature(self, kernel, input_type):
  1991. check_functional_kernel_signature_match(F.elastic, kernel=kernel, input_type=input_type)
  1992. @pytest.mark.parametrize(
  1993. "make_input",
  1994. [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
  1995. )
  1996. def test_displacement_error(self, make_input):
  1997. input = make_input()
  1998. with pytest.raises(TypeError, match="displacement should be a Tensor"):
  1999. F.elastic(input, displacement=None)
  2000. with pytest.raises(ValueError, match="displacement shape should be"):
  2001. F.elastic(input, displacement=torch.rand(F.get_size(input)))
  2002. @pytest.mark.parametrize(
  2003. "make_input",
  2004. [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
  2005. )
  2006. # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
  2007. @pytest.mark.parametrize("size", [(163, 163), (72, 333), (313, 95)])
  2008. @pytest.mark.parametrize("device", cpu_and_cuda())
  2009. def test_transform(self, make_input, size, device):
  2010. check_transform(
  2011. transforms.ElasticTransform(),
  2012. make_input(size, device=device),
  2013. # We updated gaussian blur kernel generation with a faster and numerically more stable version
  2014. check_v1_compatibility=dict(rtol=0, atol=1),
  2015. )
  2016. class TestToPureTensor:
  2017. def test_correctness(self):
  2018. input = {
  2019. "img": make_image(),
  2020. "img_tensor": make_image_tensor(),
  2021. "img_pil": make_image_pil(),
  2022. "mask": make_detection_mask(),
  2023. "video": make_video(),
  2024. "bbox": make_bounding_boxes(),
  2025. "str": "str",
  2026. }
  2027. out = transforms.ToPureTensor()(input)
  2028. for input_value, out_value in zip(input.values(), out.values()):
  2029. if isinstance(input_value, tv_tensors.TVTensor):
  2030. assert isinstance(out_value, torch.Tensor) and not isinstance(out_value, tv_tensors.TVTensor)
  2031. else:
  2032. assert isinstance(out_value, type(input_value))
  2033. class TestCrop:
  2034. INPUT_SIZE = (21, 11)
  2035. CORRECTNESS_CROP_KWARGS = [
  2036. # center
  2037. dict(top=5, left=5, height=10, width=5),
  2038. # larger than input, i.e. pad
  2039. dict(top=-5, left=-5, height=30, width=20),
  2040. # sides: left, right, top, bottom
  2041. dict(top=-5, left=-5, height=30, width=10),
  2042. dict(top=-5, left=5, height=30, width=10),
  2043. dict(top=-5, left=-5, height=20, width=20),
  2044. dict(top=5, left=-5, height=20, width=20),
  2045. # corners: top-left, top-right, bottom-left, bottom-right
  2046. dict(top=-5, left=-5, height=20, width=10),
  2047. dict(top=-5, left=5, height=20, width=10),
  2048. dict(top=5, left=-5, height=20, width=10),
  2049. dict(top=5, left=5, height=20, width=10),
  2050. ]
  2051. MINIMAL_CROP_KWARGS = CORRECTNESS_CROP_KWARGS[0]
  2052. @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
  2053. @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
  2054. @pytest.mark.parametrize("device", cpu_and_cuda())
  2055. def test_kernel_image(self, kwargs, dtype, device):
  2056. check_kernel(F.crop_image, make_image(self.INPUT_SIZE, dtype=dtype, device=device), **kwargs)
  2057. @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
  2058. @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
  2059. @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
  2060. @pytest.mark.parametrize("device", cpu_and_cuda())
  2061. def test_kernel_bounding_box(self, kwargs, format, dtype, device):
  2062. bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, dtype=dtype, device=device)
  2063. check_kernel(F.crop_bounding_boxes, bounding_boxes, format=format, **kwargs)
  2064. @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_mask])
  2065. def test_kernel_mask(self, make_mask):
  2066. check_kernel(F.crop_mask, make_mask(self.INPUT_SIZE), **self.MINIMAL_CROP_KWARGS)
  2067. def test_kernel_video(self):
  2068. check_kernel(F.crop_video, make_video(self.INPUT_SIZE), **self.MINIMAL_CROP_KWARGS)
  2069. @pytest.mark.parametrize(
  2070. "make_input",
  2071. [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
  2072. )
  2073. def test_functional(self, make_input):
  2074. check_functional(F.crop, make_input(self.INPUT_SIZE), **self.MINIMAL_CROP_KWARGS)
  2075. @pytest.mark.parametrize(
  2076. ("kernel", "input_type"),
  2077. [
  2078. (F.crop_image, torch.Tensor),
  2079. (F._crop_image_pil, PIL.Image.Image),
  2080. (F.crop_image, tv_tensors.Image),
  2081. (F.crop_bounding_boxes, tv_tensors.BoundingBoxes),
  2082. (F.crop_mask, tv_tensors.Mask),
  2083. (F.crop_video, tv_tensors.Video),
  2084. ],
  2085. )
  2086. def test_functional_signature(self, kernel, input_type):
  2087. check_functional_kernel_signature_match(F.crop, kernel=kernel, input_type=input_type)
  2088. @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
  2089. def test_functional_image_correctness(self, kwargs):
  2090. image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
  2091. actual = F.crop(image, **kwargs)
  2092. expected = F.to_image(F.crop(F.to_pil_image(image), **kwargs))
  2093. assert_equal(actual, expected)
  2094. @param_value_parametrization(
  2095. size=[(10, 5), (25, 15), (25, 5), (10, 15)],
  2096. fill=EXHAUSTIVE_TYPE_FILLS,
  2097. )
  2098. @pytest.mark.parametrize(
  2099. "make_input",
  2100. [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
  2101. )
  2102. def test_transform(self, param, value, make_input):
  2103. input = make_input(self.INPUT_SIZE)
  2104. if param == "fill":
  2105. if isinstance(input, tv_tensors.Mask) and isinstance(value, (tuple, list)):
  2106. pytest.skip("F.pad_mask doesn't support non-scalar fill.")
  2107. kwargs = dict(
  2108. # 1. size is required
  2109. # 2. the fill parameter only has an affect if we need padding
  2110. size=[s + 4 for s in self.INPUT_SIZE],
  2111. fill=adapt_fill(value, dtype=input.dtype if isinstance(input, torch.Tensor) else torch.uint8),
  2112. )
  2113. else:
  2114. kwargs = {param: value}
  2115. check_transform(
  2116. transforms.RandomCrop(**kwargs, pad_if_needed=True),
  2117. input,
  2118. check_v1_compatibility=param != "fill" or isinstance(value, (int, float)),
  2119. )
  2120. @pytest.mark.parametrize("padding", [1, (1, 1), (1, 1, 1, 1)])
  2121. def test_transform_padding(self, padding):
  2122. inpt = make_image(self.INPUT_SIZE)
  2123. output_size = [s + 2 for s in F.get_size(inpt)]
  2124. transform = transforms.RandomCrop(output_size, padding=padding)
  2125. output = transform(inpt)
  2126. assert F.get_size(output) == output_size
  2127. @pytest.mark.parametrize("padding", [None, 1, (1, 1), (1, 1, 1, 1)])
  2128. def test_transform_insufficient_padding(self, padding):
  2129. inpt = make_image(self.INPUT_SIZE)
  2130. output_size = [s + 3 for s in F.get_size(inpt)]
  2131. transform = transforms.RandomCrop(output_size, padding=padding)
  2132. with pytest.raises(ValueError, match="larger than (padded )?input image size"):
  2133. transform(inpt)
  2134. def test_transform_pad_if_needed(self):
  2135. inpt = make_image(self.INPUT_SIZE)
  2136. output_size = [s * 2 for s in F.get_size(inpt)]
  2137. transform = transforms.RandomCrop(output_size, pad_if_needed=True)
  2138. output = transform(inpt)
  2139. assert F.get_size(output) == output_size
  2140. @param_value_parametrization(
  2141. size=[(10, 5), (25, 15), (25, 5), (10, 15)],
  2142. fill=CORRECTNESS_FILLS,
  2143. padding_mode=["constant", "edge", "reflect", "symmetric"],
  2144. )
  2145. @pytest.mark.parametrize("seed", list(range(5)))
  2146. def test_transform_image_correctness(self, param, value, seed):
  2147. kwargs = {param: value}
  2148. if param != "size":
  2149. # 1. size is required
  2150. # 2. the fill / padding_mode parameters only have an affect if we need padding
  2151. kwargs["size"] = [s + 4 for s in self.INPUT_SIZE]
  2152. if param == "fill":
  2153. kwargs["fill"] = adapt_fill(kwargs["fill"], dtype=torch.uint8)
  2154. transform = transforms.RandomCrop(pad_if_needed=True, **kwargs)
  2155. image = make_image(self.INPUT_SIZE)
  2156. with freeze_rng_state():
  2157. torch.manual_seed(seed)
  2158. actual = transform(image)
  2159. torch.manual_seed(seed)
  2160. expected = F.to_image(transform(F.to_pil_image(image)))
  2161. assert_equal(actual, expected)
  2162. def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width):
  2163. affine_matrix = np.array(
  2164. [
  2165. [1, 0, -left],
  2166. [0, 1, -top],
  2167. ],
  2168. )
  2169. return reference_affine_bounding_boxes_helper(
  2170. bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=(height, width)
  2171. )
  2172. @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
  2173. @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
  2174. @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
  2175. @pytest.mark.parametrize("device", cpu_and_cuda())
  2176. def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device):
  2177. bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, dtype=dtype, device=device)
  2178. actual = F.crop(bounding_boxes, **kwargs)
  2179. expected = self._reference_crop_bounding_boxes(bounding_boxes, **kwargs)
  2180. assert_equal(actual, expected, atol=1, rtol=0)
  2181. assert_equal(F.get_size(actual), F.get_size(expected))
  2182. @pytest.mark.parametrize("output_size", [(17, 11), (11, 17), (11, 11)])
  2183. @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
  2184. @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
  2185. @pytest.mark.parametrize("device", cpu_and_cuda())
  2186. @pytest.mark.parametrize("seed", list(range(5)))
  2187. def test_transform_bounding_boxes_correctness(self, output_size, format, dtype, device, seed):
  2188. input_size = [s * 2 for s in output_size]
  2189. bounding_boxes = make_bounding_boxes(input_size, format=format, dtype=dtype, device=device)
  2190. transform = transforms.RandomCrop(output_size)
  2191. with freeze_rng_state():
  2192. torch.manual_seed(seed)
  2193. params = transform._get_params([bounding_boxes])
  2194. assert not params.pop("needs_pad")
  2195. del params["padding"]
  2196. assert params.pop("needs_crop")
  2197. torch.manual_seed(seed)
  2198. actual = transform(bounding_boxes)
  2199. expected = self._reference_crop_bounding_boxes(bounding_boxes, **params)
  2200. assert_equal(actual, expected)
  2201. assert_equal(F.get_size(actual), F.get_size(expected))
  2202. def test_errors(self):
  2203. with pytest.raises(ValueError, match="Please provide only two dimensions"):
  2204. transforms.RandomCrop([10, 12, 14])
  2205. with pytest.raises(TypeError, match="Got inappropriate padding arg"):
  2206. transforms.RandomCrop([10, 12], padding="abc")
  2207. with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"):
  2208. transforms.RandomCrop([10, 12], padding=[-0.7, 0, 0.7])
  2209. with pytest.raises(TypeError, match="Got inappropriate fill arg"):
  2210. transforms.RandomCrop([10, 12], padding=1, fill="abc")
  2211. with pytest.raises(ValueError, match="Padding mode should be either"):
  2212. transforms.RandomCrop([10, 12], padding=1, padding_mode="abc")
  2213. class TestErase:
  2214. INPUT_SIZE = (17, 11)
  2215. FUNCTIONAL_KWARGS = dict(
  2216. zip("ijhwv", [2, 2, 10, 8, torch.tensor(0.0, dtype=torch.float32, device="cpu").reshape(-1, 1, 1)])
  2217. )
  2218. @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
  2219. @pytest.mark.parametrize("device", cpu_and_cuda())
  2220. def test_kernel_image(self, dtype, device):
  2221. check_kernel(F.erase_image, make_image(self.INPUT_SIZE, dtype=dtype, device=device), **self.FUNCTIONAL_KWARGS)
  2222. @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
  2223. @pytest.mark.parametrize("device", cpu_and_cuda())
  2224. def test_kernel_image_inplace(self, dtype, device):
  2225. input = make_image(self.INPUT_SIZE, dtype=dtype, device=device)
  2226. input_version = input._version
  2227. output_out_of_place = F.erase_image(input, **self.FUNCTIONAL_KWARGS)
  2228. assert output_out_of_place.data_ptr() != input.data_ptr()
  2229. assert output_out_of_place is not input
  2230. output_inplace = F.erase_image(input, **self.FUNCTIONAL_KWARGS, inplace=True)
  2231. assert output_inplace.data_ptr() == input.data_ptr()
  2232. assert output_inplace._version > input_version
  2233. assert output_inplace is input
  2234. assert_equal(output_inplace, output_out_of_place)
  2235. def test_kernel_video(self):
  2236. check_kernel(F.erase_video, make_video(self.INPUT_SIZE), **self.FUNCTIONAL_KWARGS)
  2237. @pytest.mark.parametrize(
  2238. "make_input",
  2239. [make_image_tensor, make_image_pil, make_image, make_video],
  2240. )
  2241. def test_functional(self, make_input):
  2242. check_functional(F.erase, make_input(), **self.FUNCTIONAL_KWARGS)
  2243. @pytest.mark.parametrize(
  2244. ("kernel", "input_type"),
  2245. [
  2246. (F.erase_image, torch.Tensor),
  2247. (F._erase_image_pil, PIL.Image.Image),
  2248. (F.erase_image, tv_tensors.Image),
  2249. (F.erase_video, tv_tensors.Video),
  2250. ],
  2251. )
  2252. def test_functional_signature(self, kernel, input_type):
  2253. check_functional_kernel_signature_match(F.erase, kernel=kernel, input_type=input_type)
  2254. @pytest.mark.parametrize(
  2255. "make_input",
  2256. [make_image_tensor, make_image_pil, make_image, make_video],
  2257. )
  2258. @pytest.mark.parametrize("device", cpu_and_cuda())
  2259. def test_transform(self, make_input, device):
  2260. check_transform(transforms.RandomErasing(p=1), make_input(device=device))
  2261. def _reference_erase_image(self, image, *, i, j, h, w, v):
  2262. mask = torch.zeros_like(image, dtype=torch.bool)
  2263. mask[..., i : i + h, j : j + w] = True
  2264. # The broadcasting and type casting logic is handled automagically in the kernel through indexing
  2265. value = torch.broadcast_to(v, (*image.shape[:-2], h, w)).to(image)
  2266. erased_image = torch.empty_like(image)
  2267. erased_image[mask] = value.flatten()
  2268. erased_image[~mask] = image[~mask]
  2269. return erased_image
  2270. @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
  2271. @pytest.mark.parametrize("device", cpu_and_cuda())
  2272. def test_functional_image_correctness(self, dtype, device):
  2273. image = make_image(dtype=dtype, device=device)
  2274. actual = F.erase(image, **self.FUNCTIONAL_KWARGS)
  2275. expected = self._reference_erase_image(image, **self.FUNCTIONAL_KWARGS)
  2276. assert_equal(actual, expected)
  2277. @param_value_parametrization(
  2278. scale=[(0.1, 0.2), [0.0, 1.0]],
  2279. ratio=[(0.3, 0.7), [0.1, 5.0]],
  2280. value=[0, 0.5, (0, 1, 0), [-0.2, 0.0, 1.3], "random"],
  2281. )
  2282. @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
  2283. @pytest.mark.parametrize("device", cpu_and_cuda())
  2284. @pytest.mark.parametrize("seed", list(range(5)))
  2285. def test_transform_image_correctness(self, param, value, dtype, device, seed):
  2286. transform = transforms.RandomErasing(**{param: value}, p=1)
  2287. image = make_image(dtype=dtype, device=device)
  2288. with freeze_rng_state():
  2289. torch.manual_seed(seed)
  2290. # This emulates the random apply check that happens before _get_params is called
  2291. torch.rand(1)
  2292. params = transform._get_params([image])
  2293. torch.manual_seed(seed)
  2294. actual = transform(image)
  2295. expected = self._reference_erase_image(image, **params)
  2296. assert_equal(actual, expected)
  2297. def test_transform_errors(self):
  2298. with pytest.raises(TypeError, match="Argument value should be either a number or str or a sequence"):
  2299. transforms.RandomErasing(value={})
  2300. with pytest.raises(ValueError, match="If value is str, it should be 'random'"):
  2301. transforms.RandomErasing(value="abc")
  2302. with pytest.raises(TypeError, match="Scale should be a sequence"):
  2303. transforms.RandomErasing(scale=123)
  2304. with pytest.raises(TypeError, match="Ratio should be a sequence"):
  2305. transforms.RandomErasing(ratio=123)
  2306. with pytest.raises(ValueError, match="Scale should be between 0 and 1"):
  2307. transforms.RandomErasing(scale=[-1, 2])
  2308. transform = transforms.RandomErasing(value=[1, 2, 3, 4])
  2309. with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"):
  2310. transform._get_params([make_image()])
  2311. @pytest.mark.parametrize("make_input", [make_bounding_boxes, make_detection_mask])
  2312. def test_transform_passthrough(self, make_input):
  2313. transform = transforms.RandomErasing(p=1)
  2314. input = make_input(self.INPUT_SIZE)
  2315. with pytest.warns(UserWarning, match="currently passing through inputs of type"):
  2316. # RandomErasing requires an image or video to be present
  2317. _, output = transform(make_image(self.INPUT_SIZE), input)
  2318. assert output is input
  2319. class TestGaussianBlur:
  2320. @pytest.mark.parametrize(
  2321. "make_input",
  2322. [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
  2323. )
  2324. @pytest.mark.parametrize("device", cpu_and_cuda())
  2325. @pytest.mark.parametrize("sigma", [5, (0.5, 2)])
  2326. def test_transform(self, make_input, device, sigma):
  2327. check_transform(transforms.GaussianBlur(kernel_size=3, sigma=sigma), make_input(device=device))
  2328. def test_assertions(self):
  2329. with pytest.raises(ValueError, match="Kernel size should be a tuple/list of two integers"):
  2330. transforms.GaussianBlur([10, 12, 14])
  2331. with pytest.raises(ValueError, match="Kernel size value should be an odd and positive number"):
  2332. transforms.GaussianBlur(4)
  2333. with pytest.raises(ValueError, match="If sigma is a sequence its length should be 1 or 2. Got 3"):
  2334. transforms.GaussianBlur(3, sigma=[1, 2, 3])
  2335. with pytest.raises(ValueError, match="sigma values should be positive and of the form"):
  2336. transforms.GaussianBlur(3, sigma=-1.0)
  2337. with pytest.raises(ValueError, match="sigma values should be positive and of the form"):
  2338. transforms.GaussianBlur(3, sigma=[2.0, 1.0])
  2339. with pytest.raises(TypeError, match="sigma should be a number or a sequence of numbers"):
  2340. transforms.GaussianBlur(3, sigma={})
  2341. @pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0], (10, 12.0), [10]])
  2342. def test__get_params(self, sigma):
  2343. transform = transforms.GaussianBlur(3, sigma=sigma)
  2344. params = transform._get_params([])
  2345. if isinstance(sigma, float):
  2346. assert params["sigma"][0] == params["sigma"][1] == sigma
  2347. elif isinstance(sigma, list) and len(sigma) == 1:
  2348. assert params["sigma"][0] == params["sigma"][1] == sigma[0]
  2349. else:
  2350. assert sigma[0] <= params["sigma"][0] <= sigma[1]
  2351. assert sigma[0] <= params["sigma"][1] <= sigma[1]