test_transforms_v2.py

import itertools
import pathlib
import pickle
import random
import warnings

import numpy as np
import PIL.Image
import pytest
import torch
import torchvision.transforms.v2 as transforms

from common_utils import assert_equal, cpu_and_cuda
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import tv_tensors
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import to_pil_image
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import check_type, is_pure_tensor, query_chw
from transforms_v2_legacy_utils import (
    make_bounding_boxes,
    make_detection_mask,
    make_image,
    make_images,
    make_multiple_bounding_boxes,
    make_segmentation_mask,
    make_video,
    make_videos,
)


def make_vanilla_tensor_images(*args, **kwargs):
    for image in make_images(*args, **kwargs):
        if image.ndim > 3:
            continue
        yield image.data


def make_pil_images(*args, **kwargs):
    for image in make_vanilla_tensor_images(*args, **kwargs):
        yield to_pil_image(image)


def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
    for bounding_boxes in make_multiple_bounding_boxes(*args, **kwargs):
        yield bounding_boxes.data


def parametrize(transforms_with_inputs):
    return pytest.mark.parametrize(
        ("transform", "input"),
        [
            pytest.param(
                transform,
                input,
                id=f"{type(transform).__name__}-{type(input).__module__}.{type(input).__name__}-{idx}",
            )
            for transform, inputs in transforms_with_inputs
            for idx, input in enumerate(inputs)
        ],
    )


def auto_augment_adapter(transform, input, device):
    adapted_input = {}
    image_or_video_found = False
    for key, value in input.items():
        if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
            # AA transforms don't support bounding boxes or masks
            continue
        elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)):
            if image_or_video_found:
                # AA transforms only support a single image or video
                continue
            image_or_video_found = True
        adapted_input[key] = value
    return adapted_input


def linear_transformation_adapter(transform, input, device):
    flat_inputs = list(input.values())
    c, h, w = query_chw(
        [
            item
            for item, needs_transform in zip(flat_inputs, transforms.Transform()._needs_transform_list(flat_inputs))
            if needs_transform
        ]
    )
    num_elements = c * h * w
    transform.transformation_matrix = torch.randn((num_elements, num_elements), device=device)
    transform.mean_vector = torch.randn((num_elements,), device=device)
    return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}


def normalize_adapter(transform, input, device):
    adapted_input = {}
    for key, value in input.items():
        if isinstance(value, PIL.Image.Image):
            # normalize doesn't support PIL images
            continue
        elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor)):
            # normalize doesn't support integer images
            value = F.to_dtype(value, torch.float32, scale=True)
        adapted_input[key] = value
    return adapted_input
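

# The adapters above are hooks for TestSmoke.test_common below: given a transform, the full input
# dict, and the target device, each one drops or converts the entries that the transform under
# test cannot handle. Inside test_common the call pattern is, roughly:
#
#     if adapter is not None:
#         input = adapter(transform, input, device)
#     output = transform(input)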
class TestSmoke:
    @pytest.mark.parametrize(
        ("transform", "adapter"),
        [
            (transforms.RandomErasing(p=1.0), None),
            (transforms.AugMix(), auto_augment_adapter),
            (transforms.AutoAugment(), auto_augment_adapter),
            (transforms.RandAugment(), auto_augment_adapter),
            (transforms.TrivialAugmentWide(), auto_augment_adapter),
            (transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
            (transforms.Grayscale(), None),
            (transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
            (transforms.RandomAutocontrast(p=1.0), None),
            (transforms.RandomEqualize(p=1.0), None),
            (transforms.RandomGrayscale(p=1.0), None),
            (transforms.RandomInvert(p=1.0), None),
            (transforms.RandomChannelPermutation(), None),
            (transforms.RandomPhotometricDistort(p=1.0), None),
            (transforms.RandomPosterize(bits=4, p=1.0), None),
            (transforms.RandomSolarize(threshold=0.5, p=1.0), None),
            (transforms.CenterCrop([16, 16]), None),
            (transforms.ElasticTransform(sigma=1.0), None),
            (transforms.Pad(4), None),
            (transforms.RandomAffine(degrees=30.0), None),
            (transforms.RandomCrop([16, 16], pad_if_needed=True), None),
            (transforms.RandomHorizontalFlip(p=1.0), None),
            (transforms.RandomPerspective(p=1.0), None),
            (transforms.RandomResize(min_size=10, max_size=20, antialias=True), None),
            (transforms.RandomResizedCrop([16, 16], antialias=True), None),
            (transforms.RandomRotation(degrees=30), None),
            (transforms.RandomShortestSize(min_size=10, antialias=True), None),
            (transforms.RandomVerticalFlip(p=1.0), None),
            (transforms.RandomZoomOut(p=1.0), None),
            (transforms.Resize([16, 16], antialias=True), None),
            (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None),
            (transforms.ClampBoundingBoxes(), None),
            (transforms.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.CXCYWH), None),
            (transforms.ConvertImageDtype(), None),
            (transforms.GaussianBlur(kernel_size=3), None),
            (
                transforms.LinearTransformation(
                    # These are just dummy values that will be filled in by the adapter. We can't define them
                    # upfront, because we know neither the spatial size nor the device at this point.
                    transformation_matrix=torch.empty((1, 1)),
                    mean_vector=torch.empty((1,)),
                ),
                linear_transformation_adapter,
            ),
            (transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), normalize_adapter),
            (transforms.ToDtype(torch.float64), None),
            (transforms.UniformTemporalSubsample(num_samples=2), None),
        ],
        ids=lambda transform: type(transform).__name__,
    )
    @pytest.mark.parametrize("container_type", [dict, list, tuple])
    @pytest.mark.parametrize(
        "image_or_video",
        [
            make_image(),
            make_video(),
            next(make_pil_images(color_spaces=["RGB"])),
            next(make_vanilla_tensor_images()),
        ],
    )
    @pytest.mark.parametrize("de_serialize", [lambda t: t, lambda t: pickle.loads(pickle.dumps(t))])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_common(self, transform, adapter, container_type, image_or_video, de_serialize, device):
        transform = de_serialize(transform)

        canvas_size = F.get_size(image_or_video)
        input = dict(
            image_or_video=image_or_video,
            image_tv_tensor=make_image(size=canvas_size),
            video_tv_tensor=make_video(size=canvas_size),
            image_pil=next(make_pil_images(sizes=[canvas_size], color_spaces=["RGB"])),
            bounding_boxes_xyxy=make_bounding_boxes(
                format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(3,)
            ),
            bounding_boxes_xywh=make_bounding_boxes(
                format=tv_tensors.BoundingBoxFormat.XYWH, canvas_size=canvas_size, batch_dims=(4,)
            ),
            bounding_boxes_cxcywh=make_bounding_boxes(
                format=tv_tensors.BoundingBoxFormat.CXCYWH, canvas_size=canvas_size, batch_dims=(5,)
            ),
            bounding_boxes_degenerate_xyxy=tv_tensors.BoundingBoxes(
                [
                    [0, 0, 0, 0],  # no height or width
                    [0, 0, 0, 1],  # no height
                    [0, 0, 1, 0],  # no width
                    [2, 0, 1, 1],  # x1 > x2, y1 < y2
                    [0, 2, 1, 1],  # x1 < x2, y1 > y2
                    [2, 2, 1, 1],  # x1 > x2, y1 > y2
                ],
                format=tv_tensors.BoundingBoxFormat.XYXY,
                canvas_size=canvas_size,
            ),
            bounding_boxes_degenerate_xywh=tv_tensors.BoundingBoxes(
                [
                    [0, 0, 0, 0],  # no height or width
                    [0, 0, 0, 1],  # no height
                    [0, 0, 1, 0],  # no width
                    [0, 0, 1, -1],  # negative height
                    [0, 0, -1, 1],  # negative width
                    [0, 0, -1, -1],  # negative height and width
                ],
                format=tv_tensors.BoundingBoxFormat.XYWH,
                canvas_size=canvas_size,
            ),
            bounding_boxes_degenerate_cxcywh=tv_tensors.BoundingBoxes(
                [
                    [0, 0, 0, 0],  # no height or width
                    [0, 0, 0, 1],  # no height
                    [0, 0, 1, 0],  # no width
                    [0, 0, 1, -1],  # negative height
                    [0, 0, -1, 1],  # negative width
                    [0, 0, -1, -1],  # negative height and width
                ],
                format=tv_tensors.BoundingBoxFormat.CXCYWH,
                canvas_size=canvas_size,
            ),
            detection_mask=make_detection_mask(size=canvas_size),
            segmentation_mask=make_segmentation_mask(size=canvas_size),
            int=0,
            float=0.0,
            bool=True,
            none=None,
            str="str",
            path=pathlib.Path.cwd(),
            object=object(),
            tensor=torch.empty(5),
            array=np.empty(5),
        )
        if adapter is not None:
            input = adapter(transform, input, device)

        if container_type in {tuple, list}:
            input = container_type(input.values())

        input_flat, input_spec = tree_flatten(input)
        input_flat = [item.to(device) if isinstance(item, torch.Tensor) else item for item in input_flat]
        input = tree_unflatten(input_flat, input_spec)

        torch.manual_seed(0)
        output = transform(input)
        output_flat, output_spec = tree_flatten(output)

        assert output_spec == input_spec

        for output_item, input_item, should_be_transformed in zip(
            output_flat, input_flat, transforms.Transform()._needs_transform_list(input_flat)
        ):
            if should_be_transformed:
                assert type(output_item) is type(input_item)
            else:
                assert output_item is input_item

            if isinstance(input_item, tv_tensors.BoundingBoxes) and not isinstance(
                transform, transforms.ConvertBoundingBoxFormat
            ):
                assert output_item.format == input_item.format

        # Enforce that the transform does not turn a degenerate box marked by RandomIoUCrop (or any other future
        # transform that does this) back into a valid one.
        # TODO: we should test that against all degenerate boxes above
        for format in list(tv_tensors.BoundingBoxFormat):
            sample = dict(
                boxes=tv_tensors.BoundingBoxes([[0, 0, 0, 0]], format=format, canvas_size=(224, 244)),
                labels=torch.tensor([3]),
            )
            assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)

    @parametrize(
        [
            (
                transform,
                itertools.chain.from_iterable(
                    fn(
                        color_spaces=[
                            "GRAY",
                            "RGB",
                        ],
                        dtypes=[torch.uint8],
                        extra_dims=[(), (4,)],
                        **(dict(num_frames=[3]) if fn is make_videos else dict()),
                    )
                    for fn in [
                        make_images,
                        make_vanilla_tensor_images,
                        make_pil_images,
                        make_videos,
                    ]
                ),
            )
            for transform in (
                transforms.RandAugment(),
                transforms.TrivialAugmentWide(),
                transforms.AutoAugment(),
                transforms.AugMix(),
            )
        ]
    )
    def test_auto_augment(self, transform, input):
        transform(input)

    @parametrize(
        [
            (
                transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
                itertools.chain.from_iterable(
                    fn(color_spaces=["RGB"], dtypes=[torch.float32])
                    for fn in [
                        make_images,
                        make_vanilla_tensor_images,
                        make_videos,
                    ]
                ),
            ),
        ]
    )
    def test_normalize(self, transform, input):
        transform(input)

    @parametrize(
        [
            (
                transforms.RandomResizedCrop([16, 16], antialias=True),
                itertools.chain(
                    make_images(extra_dims=[(4,)]),
                    make_vanilla_tensor_images(),
                    make_pil_images(),
                    make_videos(extra_dims=[()]),
                ),
            )
        ]
    )
    def test_random_resized_crop(self, transform, input):
        transform(input)


@pytest.mark.parametrize(
    "flat_inputs",
    itertools.permutations(
        [
            next(make_vanilla_tensor_images()),
            next(make_vanilla_tensor_images()),
            next(make_pil_images()),
            make_image(),
            next(make_videos()),
        ],
        3,
    ),
)
def test_pure_tensor_heuristic(flat_inputs):
    def split_on_pure_tensor(to_split):
        # This takes a sequence that is structurally aligned with `flat_inputs` and splits its items into three parts:
        # 1. The first pure tensor. If none is present, this will be `None`.
        # 2. A list of the remaining pure tensors.
        # 3. A list of all other items.
        pure_tensors = []
        others = []
        # Splitting always happens on the original `flat_inputs`, so that erroneous type changes made by the
        # transform cannot affect the splitting.
        for item, inpt in zip(to_split, flat_inputs):
            (pure_tensors if is_pure_tensor(inpt) else others).append(item)
        return pure_tensors[0] if pure_tensors else None, pure_tensors[1:], others

    class CopyCloneTransform(transforms.Transform):
        def _transform(self, inpt, params):
            return inpt.clone() if isinstance(inpt, torch.Tensor) else inpt.copy()

        @staticmethod
        def was_applied(output, inpt):
            identity = output is inpt
            if identity:
                return False

            # Make sure nothing fishy is going on
            assert_equal(output, inpt)
            return True

    first_pure_tensor_input, other_pure_tensor_inputs, other_inputs = split_on_pure_tensor(flat_inputs)

    transform = CopyCloneTransform()
    transformed_sample = transform(flat_inputs)

    first_pure_tensor_output, other_pure_tensor_outputs, other_outputs = split_on_pure_tensor(transformed_sample)
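
    # The assertions below encode the pure-tensor heuristic: the first pure tensor is treated as
    # an image (and thus transformed) only if no other image, video, or PIL image is present in
    # the sample; all remaining pure tensors are always passed through untouched.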
    if first_pure_tensor_input is not None:
        if other_inputs:
            assert not transform.was_applied(first_pure_tensor_output, first_pure_tensor_input)
        else:
            assert transform.was_applied(first_pure_tensor_output, first_pure_tensor_input)

    for output, inpt in zip(other_pure_tensor_outputs, other_pure_tensor_inputs):
        assert not transform.was_applied(output, inpt)

    for input, output in zip(other_inputs, other_outputs):
        assert transform.was_applied(output, input)


class TestPad:
    def test_assertions(self):
        with pytest.raises(TypeError, match="Got inappropriate padding arg"):
            transforms.Pad("abc")

        with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"):
            transforms.Pad([-0.7, 0, 0.7])

        with pytest.raises(TypeError, match="Got inappropriate fill arg"):
            transforms.Pad(12, fill="abc")

        with pytest.raises(ValueError, match="Padding mode should be either"):
            transforms.Pad(12, padding_mode="abc")


class TestRandomZoomOut:
    def test_assertions(self):
        with pytest.raises(TypeError, match="Got inappropriate fill arg"):
            transforms.RandomZoomOut(fill="abc")

        with pytest.raises(TypeError, match="should be a sequence of length"):
            transforms.RandomZoomOut(0, side_range=0)

        with pytest.raises(ValueError, match="Invalid canvas side range"):
            transforms.RandomZoomOut(0, side_range=[4.0, 1.0])

    @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
    @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]])
    def test__get_params(self, fill, side_range):
        transform = transforms.RandomZoomOut(fill=fill, side_range=side_range)

        h, w = size = (24, 32)
        image = make_image(size)

        params = transform._get_params([image])

        assert len(params["padding"]) == 4
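        # RandomZoomOut pads the image into a canvas of size (r * h, r * w) with r drawn from
        # side_range, so each individual padding value is at most (side_range[1] - 1) times the
        # corresponding image dimension.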
        assert 0 <= params["padding"][0] <= (side_range[1] - 1) * w
        assert 0 <= params["padding"][1] <= (side_range[1] - 1) * h
        assert 0 <= params["padding"][2] <= (side_range[1] - 1) * w
        assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h


class TestRandomPerspective:
    def test_assertions(self):
        with pytest.raises(ValueError, match="Argument distortion_scale value should be between 0 and 1"):
            transforms.RandomPerspective(distortion_scale=-1.0)

        with pytest.raises(TypeError, match="Got inappropriate fill arg"):
            transforms.RandomPerspective(0.5, fill="abc")

    def test__get_params(self):
        dscale = 0.5
        transform = transforms.RandomPerspective(dscale)

        image = make_image((24, 32))

        params = transform._get_params([image])

        assert "coefficients" in params
        assert len(params["coefficients"]) == 8


class TestElasticTransform:
    def test_assertions(self):
        with pytest.raises(TypeError, match="alpha should be a number or a sequence of numbers"):
            transforms.ElasticTransform({})

        with pytest.raises(ValueError, match="alpha is a sequence its length should be 1 or 2"):
            transforms.ElasticTransform([1.0, 2.0, 3.0])

        with pytest.raises(TypeError, match="sigma should be a number or a sequence of numbers"):
            transforms.ElasticTransform(1.0, {})

        with pytest.raises(ValueError, match="sigma is a sequence its length should be 1 or 2"):
            transforms.ElasticTransform(1.0, [1.0, 2.0, 3.0])

        with pytest.raises(TypeError, match="Got inappropriate fill arg"):
            transforms.ElasticTransform(1.0, 2.0, fill="abc")

    def test__get_params(self):
        alpha = 2.0
        sigma = 3.0
        transform = transforms.ElasticTransform(alpha, sigma)

        h, w = size = (24, 32)
        image = make_image(size)

        params = transform._get_params([image])

        displacement = params["displacement"]
        assert displacement.shape == (1, h, w, 2)
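        # The displacement field is stored in normalized coordinates: the x-component is bounded
        # by alpha / w and the y-component by alpha / h, as checked below.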
        assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all()
        assert (-alpha / h <= displacement[0, ..., 1]).all() and (displacement[0, ..., 1] <= alpha / h).all()


class TestTransform:
    @pytest.mark.parametrize(
        "inpt_type",
        [torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
    )
    def test_check_transformed_types(self, inpt_type, mocker):
        # This test ensures that we correctly handle which types to transform and which to bypass
        t = transforms.Transform()
        inpt = mocker.MagicMock(spec=inpt_type)

        if inpt_type in (np.ndarray, str, int):
            output = t(inpt)
            assert output is inpt
        else:
            with pytest.raises(NotImplementedError):
                t(inpt)


class TestToImage:
    @pytest.mark.parametrize(
        "inpt_type",
        [torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
    )
    def test__transform(self, inpt_type, mocker):
        fn = mocker.patch(
            "torchvision.transforms.v2.functional.to_image",
            return_value=torch.rand(1, 3, 8, 8),
        )

        inpt = mocker.MagicMock(spec=inpt_type)
        transform = transforms.ToImage()
        transform(inpt)
        if inpt_type in (tv_tensors.BoundingBoxes, tv_tensors.Image, str, int):
            assert fn.call_count == 0
        else:
            fn.assert_called_once_with(inpt)


class TestToPILImage:
    @pytest.mark.parametrize(
        "inpt_type",
        [torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
    )
    def test__transform(self, inpt_type, mocker):
        fn = mocker.patch("torchvision.transforms.v2.functional.to_pil_image")

        inpt = mocker.MagicMock(spec=inpt_type)
        transform = transforms.ToPILImage()
        transform(inpt)
        if inpt_type in (PIL.Image.Image, tv_tensors.BoundingBoxes, str, int):
            assert fn.call_count == 0
        else:
            fn.assert_called_once_with(inpt, mode=transform.mode)


class TestToTensor:
    @pytest.mark.parametrize(
        "inpt_type",
        [torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
    )
    def test__transform(self, inpt_type, mocker):
        fn = mocker.patch("torchvision.transforms.functional.to_tensor")

        inpt = mocker.MagicMock(spec=inpt_type)
        with pytest.warns(UserWarning, match="deprecated and will be removed"):
            transform = transforms.ToTensor()
        transform(inpt)
        if inpt_type in (tv_tensors.Image, torch.Tensor, tv_tensors.BoundingBoxes, str, int):
            assert fn.call_count == 0
        else:
            fn.assert_called_once_with(inpt)


class TestContainers:
    @pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
    def test_assertions(self, transform_cls):
        with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
            transform_cls(transforms.RandomCrop(28))

    @pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
    @pytest.mark.parametrize(
        "trfms",
        [
            [transforms.Pad(2), transforms.RandomCrop(28)],
            [lambda x: 2.0 * x, transforms.Pad(2), transforms.RandomCrop(28)],
            [transforms.Pad(2), lambda x: 2.0 * x, transforms.RandomCrop(28)],
        ],
    )
    def test_ctor(self, transform_cls, trfms):
        c = transform_cls(trfms)
        inpt = torch.rand(1, 3, 32, 32)
        output = c(inpt)
        assert isinstance(output, torch.Tensor)
        assert output.ndim == 4


class TestRandomChoice:
    def test_assertions(self):
        with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"):
            transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], p=[1])


class TestRandomIoUCrop:
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
    def test__get_params(self, device, options):
        orig_h, orig_w = size = (24, 32)
        image = make_image(size)
        bboxes = tv_tensors.BoundingBoxes(
            torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
            format="XYXY",
            canvas_size=size,
            device=device,
        )
        sample = [image, bboxes]

        transform = transforms.RandomIoUCrop(sampler_options=options)

        n_samples = 5
        for _ in range(n_samples):
            params = transform._get_params(sample)

            if options == [2.0]:
                assert len(params) == 0
                return

            assert len(params["is_within_crop_area"]) > 0
            assert params["is_within_crop_area"].dtype == torch.bool

            assert int(transform.min_scale * orig_h) <= params["height"] <= int(transform.max_scale * orig_h)
            assert int(transform.min_scale * orig_w) <= params["width"] <= int(transform.max_scale * orig_w)

            left, top = params["left"], params["top"]
            new_h, new_w = params["height"], params["width"]
            ious = box_iou(
                bboxes,
                torch.tensor([[left, top, left + new_w, top + new_h]], dtype=bboxes.dtype, device=bboxes.device),
            )
            assert ious.max() >= options[0] or ious.max() >= options[1], f"{ious} vs {options}"

    def test__transform_empty_params(self, mocker):
        transform = transforms.RandomIoUCrop(sampler_options=[2.0])
        image = tv_tensors.Image(torch.rand(1, 3, 4, 4))
        bboxes = tv_tensors.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4))
        label = torch.tensor([1])
        sample = [image, bboxes, label]
        # Let's mock transform._get_params to control the output:
        transform._get_params = mocker.MagicMock(return_value={})
        output = transform(sample)
        torch.testing.assert_close(output, sample)

    def test_forward_assertion(self):
        transform = transforms.RandomIoUCrop()
        with pytest.raises(
            TypeError,
            match="requires input sample to contain tensor or PIL images and bounding boxes",
        ):
            transform(torch.tensor(0))

    def test__transform(self, mocker):
        transform = transforms.RandomIoUCrop()

        size = (32, 24)
        image = make_image(size)
        bboxes = make_bounding_boxes(format="XYXY", canvas_size=size, batch_dims=(6,))
        masks = make_detection_mask(size, num_objects=6)

        sample = [image, bboxes, masks]

        is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool)

        params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area)
        transform._get_params = mocker.MagicMock(return_value=params)
        output = transform(sample)

        # check number of bboxes vs number of labels:
        output_bboxes = output[1]
        assert isinstance(output_bboxes, tv_tensors.BoundingBoxes)
        assert (output_bboxes[~is_within_crop_area] == 0).all()

        output_masks = output[2]
        assert isinstance(output_masks, tv_tensors.Mask)


class TestScaleJitter:
    def test__get_params(self):
        canvas_size = (24, 32)
        target_size = (16, 12)
        scale_range = (0.5, 1.5)

        transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range)
        sample = make_image(canvas_size)

        n_samples = 5
        for _ in range(n_samples):
            params = transform._get_params([sample])

            assert "size" in params
            size = params["size"]

            assert isinstance(size, tuple) and len(size) == 2
            height, width = size

            r_min = min(target_size[1] / canvas_size[0], target_size[0] / canvas_size[1]) * scale_range[0]
            r_max = min(target_size[1] / canvas_size[0], target_size[0] / canvas_size[1]) * scale_range[1]

            assert int(canvas_size[0] * r_min) <= height <= int(canvas_size[0] * r_max)
            assert int(canvas_size[1] * r_min) <= width <= int(canvas_size[1] * r_max)


class TestRandomShortestSize:
    @pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)])
    def test__get_params(self, min_size, max_size):
        canvas_size = (3, 10)

        transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size, antialias=True)

        sample = make_image(canvas_size)
        params = transform._get_params([sample])

        assert "size" in params
        size = params["size"]

        assert isinstance(size, tuple) and len(size) == 2

        longer = max(size)
        shorter = min(size)
        if max_size is not None:
            assert longer <= max_size
            assert shorter <= max_size
        else:
            assert shorter in min_size


class TestLinearTransformation:
    def test_assertions(self):
        with pytest.raises(ValueError, match="transformation_matrix should be square"):
            transforms.LinearTransformation(torch.rand(2, 3), torch.rand(5))

        with pytest.raises(ValueError, match="mean_vector should have the same length"):
            transforms.LinearTransformation(torch.rand(3, 3), torch.rand(5))

    @pytest.mark.parametrize(
        "inpt",
        [
            122 * torch.ones(1, 3, 8, 8),
            122.0 * torch.ones(1, 3, 8, 8),
            tv_tensors.Image(122 * torch.ones(1, 3, 8, 8)),
            PIL.Image.new("RGB", (8, 8), (122, 122, 122)),
        ],
    )
    def test__transform(self, inpt):
        v = 121 * torch.ones(3 * 8 * 8)
        m = torch.ones(3 * 8 * 8, 3 * 8 * 8)
        transform = transforms.LinearTransformation(m, v)

        if isinstance(inpt, PIL.Image.Image):
            with pytest.raises(TypeError, match="does not support PIL images"):
                transform(inpt)
        else:
            output = transform(inpt)
            assert isinstance(output, torch.Tensor)
            # Every input element is 122 and the mean vector is 121, so the centered, flattened
            # input is all ones; multiplying by the all-ones matrix makes every output element
            # equal to the number of elements, i.e. 3 * 8 * 8.
            assert output.unique() == 3 * 8 * 8
            assert output.dtype == inpt.dtype


class TestRandomResize:
    def test__get_params(self):
        min_size = 3
        max_size = 6

        transform = transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)

        for _ in range(10):
            params = transform._get_params([])

            assert isinstance(params["size"], list) and len(params["size"]) == 1
            size = params["size"][0]

            assert min_size <= size < max_size


class TestUniformTemporalSubsample:
    @pytest.mark.parametrize(
        "inpt",
        [
            torch.zeros(10, 3, 8, 8),
            torch.zeros(1, 10, 3, 8, 8),
            tv_tensors.Video(torch.zeros(1, 10, 3, 8, 8)),
        ],
    )
    def test__transform(self, inpt):
        num_samples = 5
        transform = transforms.UniformTemporalSubsample(num_samples)

        output = transform(inpt)
        assert type(output) is type(inpt)
        assert output.shape[-4] == num_samples
        assert output.dtype == inpt.dtype


# TODO: remove this test in 0.17 when the default of antialias changes to True
def test_antialias_warning():
    pil_img = PIL.Image.new("RGB", size=(10, 10), color=127)
    tensor_img = torch.randint(0, 256, size=(3, 10, 10), dtype=torch.uint8)
    tensor_video = torch.randint(0, 256, size=(2, 3, 10, 10), dtype=torch.uint8)

    match = "The default value of the antialias parameter"
    with pytest.warns(UserWarning, match=match):
        transforms.RandomResizedCrop((20, 20))(tensor_img)
    with pytest.warns(UserWarning, match=match):
        transforms.ScaleJitter((20, 20))(tensor_img)
    with pytest.warns(UserWarning, match=match):
        transforms.RandomShortestSize((20, 20))(tensor_img)
    with pytest.warns(UserWarning, match=match):
        transforms.RandomResize(10, 20)(tensor_img)

    with pytest.warns(UserWarning, match=match):
        F.resized_crop(tv_tensors.Image(tensor_img), 0, 0, 10, 10, (20, 20))

    with pytest.warns(UserWarning, match=match):
        F.resize(tv_tensors.Video(tensor_video), (20, 20))
    with pytest.warns(UserWarning, match=match):
        F.resized_crop(tv_tensors.Video(tensor_video), 0, 0, 10, 10, (20, 20))

    with warnings.catch_warnings():
        warnings.simplefilter("error")

        transforms.RandomResizedCrop((20, 20))(pil_img)
        transforms.ScaleJitter((20, 20))(pil_img)
        transforms.RandomShortestSize((20, 20))(pil_img)
        transforms.RandomResize(10, 20)(pil_img)

        transforms.RandomResizedCrop((20, 20), antialias=True)(tensor_img)
        transforms.ScaleJitter((20, 20), antialias=True)(tensor_img)
        transforms.RandomShortestSize((20, 20), antialias=True)(tensor_img)
        transforms.RandomResize(10, 20, antialias=True)(tensor_img)

        F.resized_crop(tv_tensors.Image(tensor_img), 0, 0, 10, 10, (20, 20), antialias=True)
        F.resized_crop(tv_tensors.Video(tensor_video), 0, 0, 10, 10, (20, 20), antialias=True)


@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, int))
@pytest.mark.parametrize("dataset_return_type", (dict, tuple))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
    image = tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
    if image_type is PIL.Image:
        image = to_pil_image(image[0])
    elif image_type is torch.Tensor:
        image = image.as_subclass(torch.Tensor)
        assert is_pure_tensor(image)

    label = 1 if label_type is int else torch.tensor([1])

    if dataset_return_type is dict:
        sample = {
            "image": image,
            "label": label,
        }
    else:
        sample = image, label

    if to_tensor is transforms.ToTensor:
        with pytest.warns(UserWarning, match="deprecated and will be removed"):
            to_tensor = to_tensor()
    else:
        to_tensor = to_tensor()

    t = transforms.Compose(
        [
            transforms.RandomResizedCrop((224, 224), antialias=True),
            transforms.RandomHorizontalFlip(p=1),
            transforms.RandAugment(),
            transforms.TrivialAugmentWide(),
            transforms.AugMix(),
            transforms.AutoAugment(),
            to_tensor,
            # TODO: ConvertImageDtype is a pass-through on PIL images. Is that intended? This
            # results in a failure if we convert to tensor after it, because the image would
            # still be uint8, which makes Normalize fail.
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
            transforms.RandomErasing(p=1),
        ]
    )

    out = t(sample)

    assert type(out) == type(sample)

    if dataset_return_type is tuple:
        out_image, out_label = out
    else:
        assert out.keys() == sample.keys()
        out_image, out_label = out.values()

    assert out_image.shape[-2:] == (224, 224)
    assert out_label == label


@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
@pytest.mark.parametrize("sanitize", (True, False))
def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
    torch.manual_seed(0)

    if to_tensor is transforms.ToTensor:
        with pytest.warns(UserWarning, match="deprecated and will be removed"):
            to_tensor = to_tensor()
    else:
        to_tensor = to_tensor()

    if data_augmentation == "hflip":
        t = [
            transforms.RandomHorizontalFlip(p=1),
            to_tensor,
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "lsj":
        t = [
            transforms.ScaleJitter(target_size=(1024, 1024), antialias=True),
            # Note: replaced FixedSizeCrop with RandomCrop, because we're
            # leaving FixedSizeCrop in prototype for now, and it expects Label
            # classes which we won't release yet.
            # transforms.FixedSizeCrop(
            #     size=(1024, 1024), fill=defaultdict(lambda: (123.0, 117.0, 104.0), {tv_tensors.Mask: 0})
            # ),
            transforms.RandomCrop((1024, 1024), pad_if_needed=True),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor,
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "multiscale":
        t = [
            transforms.RandomShortestSize(
                min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333, antialias=True
            ),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor,
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "ssd":
        t = [
            transforms.RandomPhotometricDistort(p=1),
            transforms.RandomZoomOut(fill={"others": (123.0, 117.0, 104.0), tv_tensors.Mask: 0}, p=1),
            transforms.RandomIoUCrop(),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor,
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "ssdlite":
        t = [
            transforms.RandomIoUCrop(),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor,
            transforms.ConvertImageDtype(torch.float),
        ]
    if sanitize:
        t += [transforms.SanitizeBoundingBoxes()]
    t = transforms.Compose(t)

    num_boxes = 5
    H = W = 250

    image = tv_tensors.Image(torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8))
    if image_type is PIL.Image:
        image = to_pil_image(image[0])
    elif image_type is torch.Tensor:
        image = image.as_subclass(torch.Tensor)
        assert is_pure_tensor(image)

    label = torch.randint(0, 10, size=(num_boxes,))

    boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
    boxes[:, 2:] += boxes[:, :2]
    boxes = boxes.clamp(min=0, max=min(H, W))
    boxes = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W))

    masks = tv_tensors.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))

    sample = {
        "image": image,
        "label": label,
        "boxes": boxes,
        "masks": masks,
    }

    out = t(sample)

    if isinstance(to_tensor, transforms.ToTensor) and image_type is not tv_tensors.Image:
        assert is_pure_tensor(out["image"])
    else:
        assert isinstance(out["image"], tv_tensors.Image)

    assert isinstance(out["label"], type(sample["label"]))

    num_boxes_expected = {
        # ssd and ssdlite contain RandomIoUCrop which may "remove" some bboxes. It
        # doesn't remove them strictly speaking; it just marks some boxes as degenerate,
        # and those boxes are later removed by SanitizeBoundingBoxes(), which we add to
        # the pipelines if the sanitize param is True.
        # Note that the values below are probably specific to the random seed
        # set above (which is fine).
        (True, "ssd"): 5,
        (True, "ssdlite"): 4,
    }.get((sanitize, data_augmentation), num_boxes)

    assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes_expected


@pytest.mark.parametrize("min_size", (1, 10))
@pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None))
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
    if sample_type is tuple and not isinstance(labels_getter, str):
        # The "lambda inputs: inputs["labels"]" labels_getter used in this test
        # doesn't work if the input is a tuple.
        return

    H, W = 256, 128

    boxes_and_validity = [
        ([0, 1, 10, 1], False),  # Y1 == Y2
        ([0, 1, 0, 20], False),  # X1 == X2
        ([0, 0, min_size - 1, 10], False),  # width < min_size
        ([0, 0, 10, min_size - 1], False),  # height < min_size
        ([0, 0, 10, H + 1], False),  # Y2 > H
        ([0, 0, W + 1, 10], False),  # X2 > W
        ([-1, 1, 10, 20], False),  # any < 0
        ([0, 0, -1, 20], False),  # any < 0
        ([0, 0, -10, -1], False),  # any < 0
        ([0, 0, min_size, 10], True),  # width == min_size
        ([0, 0, 10, min_size], True),  # height == min_size
        ([0, 0, W, H], True),  # TODO: Is that actually OK?? Should it be -1?
        ([1, 1, 30, 20], True),
        ([0, 0, 10, 10], True),
        ([1, 1, 30, 20], True),
    ]

    random.shuffle(boxes_and_validity)  # For test robustness: mix order of wrong and correct cases
    boxes, is_valid_mask = zip(*boxes_and_validity)
    valid_indices = [i for (i, is_valid) in enumerate(is_valid_mask) if is_valid]

    boxes = torch.tensor(boxes)
    labels = torch.arange(boxes.shape[0])

    boxes = tv_tensors.BoundingBoxes(
        boxes,
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=(H, W),
    )

    masks = tv_tensors.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
    whatever = torch.rand(10)
    input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8)
    sample = {
        "image": input_img,
        "labels": labels,
        "boxes": boxes,
        "whatever": whatever,
        "None": None,
        "masks": masks,
    }

    if sample_type is tuple:
        img = sample.pop("image")
        sample = (img, sample)

    out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)

    if sample_type is tuple:
        out_image = out[0]
        out_labels = out[1]["labels"]
        out_boxes = out[1]["boxes"]
        out_masks = out[1]["masks"]
        out_whatever = out[1]["whatever"]
    else:
        out_image = out["image"]
        out_labels = out["labels"]
        out_boxes = out["boxes"]
        out_masks = out["masks"]
        out_whatever = out["whatever"]

    assert out_image is input_img
    assert out_whatever is whatever

    assert isinstance(out_boxes, tv_tensors.BoundingBoxes)
    assert isinstance(out_masks, tv_tensors.Mask)

    if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
        assert out_labels is labels
    else:
        assert isinstance(out_labels, torch.Tensor)
        assert out_boxes.shape[0] == out_labels.shape[0] == out_masks.shape[0]
        # This works because we conveniently set labels to arange(num_boxes)
        assert out_labels.tolist() == valid_indices


def test_sanitize_bounding_boxes_no_label():
    # Non-regression test for https://github.com/pytorch/vision/issues/7878
    img = make_image()
    boxes = make_bounding_boxes()

    with pytest.raises(ValueError, match="or a two-tuple whose second item is a dict"):
        transforms.SanitizeBoundingBoxes()(img, boxes)

    out_img, out_boxes = transforms.SanitizeBoundingBoxes(labels_getter=None)(img, boxes)
    assert isinstance(out_img, tv_tensors.Image)
    assert isinstance(out_boxes, tv_tensors.BoundingBoxes)


def test_sanitize_bounding_boxes_errors():
    good_bbox = tv_tensors.BoundingBoxes(
        [[0, 0, 10, 10]],
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=(20, 20),
    )

    with pytest.raises(ValueError, match="min_size must be >= 1"):
        transforms.SanitizeBoundingBoxes(min_size=0)
    with pytest.raises(ValueError, match="labels_getter should either be 'default'"):
        transforms.SanitizeBoundingBoxes(labels_getter=12)

    with pytest.raises(ValueError, match="Could not infer where the labels are"):
        bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])}
        transforms.SanitizeBoundingBoxes()(bad_labels_key)

    with pytest.raises(ValueError, match="must be a tensor"):
        not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()}
        transforms.SanitizeBoundingBoxes()(not_a_tensor)

    with pytest.raises(ValueError, match="Number of boxes"):
        different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
        transforms.SanitizeBoundingBoxes()(different_sizes)


class TestLambda:
    inputs = pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])

    @inputs
    def test_default(self, input):
        was_applied = False

        def was_applied_fn(input):
            nonlocal was_applied
            was_applied = True
            return input

        transform = transforms.Lambda(was_applied_fn)
        transform(input)

        assert was_applied

    @inputs
    def test_with_types(self, input):
        was_applied = False

        def was_applied_fn(input):
            nonlocal was_applied
            was_applied = True
            return input

        types = (torch.Tensor, np.ndarray)
        transform = transforms.Lambda(was_applied_fn, *types)
        transform(input)

        assert was_applied is isinstance(input, types)