test_transforms_v2_functional.py

import inspect
import math
import os
import re

import numpy as np
import PIL.Image
import pytest
import torch

from common_utils import assert_close, cache, cpu_and_cuda, needs_cuda, set_rng_seed
from torch.utils._pytree import tree_map
from torchvision import tv_tensors
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import is_pure_tensor
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_bounding_box_format
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS
from transforms_v2_legacy_utils import (
    DEFAULT_SQUARE_SPATIAL_SIZE,
    make_multiple_bounding_boxes,
    parametrized_error_message,
)

KERNEL_INFOS_MAP = {info.kernel: info for info in KERNEL_INFOS}
DISPATCHER_INFOS_MAP = {info.dispatcher: info for info in DISPATCHER_INFOS}


@cache
def script(fn):
    try:
        return torch.jit.script(fn)
    except Exception as error:
        raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error
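

# For illustration, typical use of the helper above (treat this as a sketch; `F.resize_image` merely
# stands in for any scriptable kernel):
#
#   scripted = script(F.resize_image)  # compiled once, then cached via `@cache`
#   scripted(torch.rand(3, 32, 32), size=[16, 16])
#
# Wrapping the TorchScript failure in an `AssertionError` makes a non-scriptable kernel surface as a
# regular test failure with the original error chained.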


# Scripting a function often triggers a warning like
# `UserWarning: operator() profile_node %$INT1 : int[] = prim::profile_ivalue($INT2) does not have profile information`
# with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
# them.
ignore_jit_warning_no_profile = pytest.mark.filterwarnings(
    f"ignore:{re.escape('operator() profile_node %')}:UserWarning"
)


def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None):
    args_kwargs = list(args_kwargs_fn(info))
    if not args_kwargs:
        raise pytest.UsageError(
            f"Couldn't collect a single `ArgsKwargs` for `{info.id}`{f' in {test_id}' if test_id else ''}"
        )
    idx_field_len = len(str(len(args_kwargs)))
    return [
        pytest.param(
            info,
            args_kwargs_,
            marks=info.get_marks(test_id, args_kwargs_) if test_id else [],
            id=f"{info.id}-{idx:0{idx_field_len}}",
        )
        for idx, args_kwargs_ in enumerate(args_kwargs)
    ]
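

# A sketch of the test ids generated above, assuming an info with id "resize_image" and twelve sample
# configurations; the zero-padding via `idx_field_len` keeps the ids aligned and sortable:
#
#   resize_image-00, resize_image-01, ..., resize_image-11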


def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn):
    def decorator(test_fn):
        parts = test_fn.__qualname__.split(".")
        if len(parts) == 1:
            test_class_name = None
            test_function_name = parts[0]
        elif len(parts) == 2:
            test_class_name, test_function_name = parts
        else:
            raise pytest.UsageError("Unable to parse the test class name and test function name from the test function")
        test_id = (test_class_name, test_function_name)

        argnames = ("info", "args_kwargs")
        argvalues = []
        for info in infos:
            argvalues.extend(make_info_args_kwargs_params(info, args_kwargs_fn=args_kwargs_fn, test_id=test_id))

        return pytest.mark.parametrize(argnames, argvalues)(test_fn)

    return decorator


@pytest.fixture(autouse=True)
def fix_rng_seed():
    set_rng_seed(0)
    yield


@pytest.fixture()
def test_id(request):
    test_class_name = request.cls.__name__ if request.cls is not None else None
    test_function_name = request.node.originalname
    return test_class_name, test_function_name


class TestKernels:
    sample_inputs = make_info_args_kwargs_parametrization(
        KERNEL_INFOS,
        args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
    )
    reference_inputs = make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.reference_fn is not None],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.logs_usage],
        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        (input, *other_args), kwargs = args_kwargs.load(device)
        info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)

        spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
        kernel_eager = info.kernel
        kernel_scripted = script(kernel_eager)

        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        actual = kernel_scripted(input, *other_args, **kwargs)
        expected = kernel_eager(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

    def _unbatch(self, batch, *, data_dims):
        if isinstance(batch, torch.Tensor):
            batched_tensor = batch
            metadata = ()
        else:
            batched_tensor, *metadata = batch

        if batched_tensor.ndim == data_dims:
            return batch

        return [
            self._unbatch(unbatched, data_dims=data_dims)
            for unbatched in (
                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
            )
        ]
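
    # A sketch of how `_unbatch` recurses, assuming an image batch of shape (2, 5, 3, H, W) and
    # `data_dims=3`: both leading batch dimensions are unbound one level at a time, yielding a nested
    # list of shape [2][5] whose leaves are single (3, H, W) images. For inputs that carry metadata,
    # e.g. a `(boxes, canvas_size)` pair, the metadata is re-attached to every unbatched leaf.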

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_batched_vs_single(self, test_id, info, args_kwargs, device):
        (batched_input, *other_args), kwargs = args_kwargs.load(device)

        tv_tensor_type = tv_tensors.Image if is_pure_tensor(batched_input) else type(batched_input)
        # This dictionary contains the number of rightmost dimensions that contain the actual data.
        # Everything to the left is considered a batch dimension.
        data_dims = {
            tv_tensors.Image: 3,
            tv_tensors.BoundingBoxes: 1,
            # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection
            # masks it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both are grouped
            # under one type, all kernels should also work without differentiating between the two. Thus, we go with
            # 2 here as the common ground.
            tv_tensors.Mask: 2,
            tv_tensors.Video: 4,
        }.get(tv_tensor_type)
        if data_dims is None:
            raise pytest.UsageError(
                f"The number of data dimensions cannot be determined for input of type {tv_tensor_type.__name__}."
            ) from None
        elif batched_input.ndim <= data_dims:
            pytest.skip("Input is not batched.")
        elif not all(batched_input.shape[:-data_dims]):
            pytest.skip("Input has a degenerate batch shape.")

        batched_input = batched_input.as_subclass(torch.Tensor)
        batched_output = info.kernel(batched_input, *other_args, **kwargs)
        actual = self._unbatch(batched_output, data_dims=data_dims)

        single_inputs = self._unbatch(batched_input, data_dims=data_dims)
        expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
            msg=parametrized_error_message(batched_input, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_no_inplace(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        if input.numel() == 0:
            pytest.skip("The input has a degenerate shape.")

        input_version = input._version
        info.kernel(input, *other_args, **kwargs)

        assert input._version == input_version

    @sample_inputs
    @needs_cuda
    def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
        (input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
        input_cpu = input_cpu.as_subclass(torch.Tensor)
        input_cuda = input_cpu.to("cuda")

        output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
        output_cuda = info.kernel(input_cuda, *other_args, **kwargs)

        assert_close(
            output_cuda,
            output_cpu,
            check_device=False,
            **info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
            msg=parametrized_error_message(input_cpu, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_dtype_and_device_consistency(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        output = info.kernel(input, *other_args, **kwargs)
        # Most kernels just return a tensor, but some also return some additional metadata
        if not isinstance(output, torch.Tensor):
            output, *_ = output

        assert output.dtype == input.dtype
        assert output.device == input.device

    @reference_inputs
    def test_against_reference(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")

        actual = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
        # We intentionally don't unwrap the input of the reference function in order for it to have access to all
        # metadata regardless of whether the kernel takes it explicitly or not
        expected = info.reference_fn(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.float32_vs_uint8],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )
    def test_float32_vs_uint8(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")
        input = input.as_subclass(torch.Tensor)

        if input.dtype != torch.uint8:
            pytest.skip(f"Input dtype is {input.dtype}.")

        adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)

        actual = info.kernel(
            F.to_dtype_image(input, dtype=torch.float32, scale=True),
            *adapted_other_args,
            **adapted_kwargs,
        )
        expected = F.to_dtype_image(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )


@pytest.fixture
def spy_on(mocker):
    def make_spy(fn, *, module=None, name=None):
        # TODO: we can probably get rid of the non-default modules and names if we eliminate aliasing
        module = module or fn.__module__
        name = name or fn.__name__
        spy = mocker.patch(f"{module}.{name}", wraps=fn)
        return spy

    return make_spy


class TestDispatchers:
    image_sample_inputs = make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if tv_tensors.Image in info.kernels],
        args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.Image),
    )

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        args, kwargs = args_kwargs.load(device)
        info.dispatcher(*args, **kwargs)

        spy.assert_any_call(f"{info.dispatcher.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @image_sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_smoke(self, info, args_kwargs, device):
        dispatcher = script(info.dispatcher)

        (image_tv_tensor, *other_args), kwargs = args_kwargs.load(device)
        image_pure_tensor = torch.Tensor(image_tv_tensor)

        dispatcher(image_pure_tensor, *other_args, **kwargs)

    # TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke`
    # replaces this test for them.
    @ignore_jit_warning_no_profile
    @pytest.mark.parametrize(
        "dispatcher",
        [
            F.get_dimensions,
            F.get_image_num_channels,
            F.get_image_size,
            F.get_num_channels,
            F.get_num_frames,
            F.get_size,
            F.rgb_to_grayscale,
            F.uniform_temporal_subsample,
        ],
        ids=lambda dispatcher: dispatcher.__name__,
    )
    def test_scriptable(self, dispatcher):
        script(dispatcher)

    @image_sample_inputs
    def test_pure_tensor_output_type(self, info, args_kwargs):
        (image_tv_tensor, *other_args), kwargs = args_kwargs.load()
        image_pure_tensor = image_tv_tensor.as_subclass(torch.Tensor)

        output = info.dispatcher(image_pure_tensor, *other_args, **kwargs)

        # We cannot use `isinstance` here since all tv_tensors are instances of `torch.Tensor` as well
        assert type(output) is torch.Tensor
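
    # Minimal illustration of the comment above:
    #
    #   image = tv_tensors.Image(torch.rand(3, 2, 2))
    #   isinstance(image, torch.Tensor)  # True for every tv_tensor
    #   type(image) is torch.Tensor      # False, so only this proves the output was unwrapped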

    @make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
        args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.Image),
    )
    def test_pil_output_type(self, info, args_kwargs):
        (image_tv_tensor, *other_args), kwargs = args_kwargs.load()

        if image_tv_tensor.ndim > 3:
            pytest.skip("Input is batched")

        image_pil = F.to_pil_image(image_tv_tensor)

        output = info.dispatcher(image_pil, *other_args, **kwargs)

        assert isinstance(output, PIL.Image.Image)

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    def test_tv_tensor_output_type(self, info, args_kwargs):
        (tv_tensor, *other_args), kwargs = args_kwargs.load()

        output = info.dispatcher(tv_tensor, *other_args, **kwargs)

        assert isinstance(output, type(tv_tensor))

        if isinstance(tv_tensor, tv_tensors.BoundingBoxes) and info.dispatcher is not F.convert_bounding_box_format:
            assert output.format == tv_tensor.format

    @pytest.mark.parametrize(
        ("dispatcher_info", "tv_tensor_type", "kernel_info"),
        [
            pytest.param(
                dispatcher_info, tv_tensor_type, kernel_info, id=f"{dispatcher_info.id}-{tv_tensor_type.__name__}"
            )
            for dispatcher_info in DISPATCHER_INFOS
            for tv_tensor_type, kernel_info in dispatcher_info.kernel_infos.items()
        ],
    )
    def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, tv_tensor_type, kernel_info):
        dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
        dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

        kernel_signature = inspect.signature(kernel_info.kernel)
        kernel_params = list(kernel_signature.parameters.values())[1:]

        # We filter out metadata that is implicitly passed to the dispatcher through the input tv_tensor, but has to
        # be explicitly passed to the kernel.
        input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
        explicit_metadata = {
            tv_tensors.BoundingBoxes: {"format", "canvas_size"},
        }
        kernel_params = [
            param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())
        ]

        dispatcher_params = iter(dispatcher_params)
        for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
            try:
                # In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out
                # dispatcher parameters that have no kernel equivalent while keeping the order intact.
                while dispatcher_param.name != kernel_param.name:
                    dispatcher_param = next(dispatcher_params)
            except StopIteration:
                raise AssertionError(
                    f"Parameter `{kernel_param.name}` of kernel `{kernel_info.id}` "
                    f"has no corresponding parameter on the dispatcher `{dispatcher_info.id}`."
                ) from None

            assert dispatcher_param == kernel_param

    @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
    def test_unknown_type(self, info):
        unknown_input = object()
        (_, *other_args), kwargs = next(iter(info.sample_inputs())).load("cpu")

        with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
            info.dispatcher(unknown_input, *other_args, **kwargs)

    @make_info_args_kwargs_parametrization(
        [
            info
            for info in DISPATCHER_INFOS
            if tv_tensors.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_bounding_box_format
        ],
        args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.BoundingBoxes),
    )
    def test_bounding_boxes_format_consistency(self, info, args_kwargs):
        (bounding_boxes, *other_args), kwargs = args_kwargs.load()
        format = bounding_boxes.format

        output = info.dispatcher(bounding_boxes, *other_args, **kwargs)

        assert output.format == format


@pytest.mark.parametrize(
    ("alias", "target"),
    [
        pytest.param(alias, target, id=alias.__name__)
        for alias, target in [
            (F.hflip, F.horizontal_flip),
            (F.vflip, F.vertical_flip),
            (F.get_image_num_channels, F.get_num_channels),
            (F.to_pil_image, F.to_pil_image),
            (F.elastic_transform, F.elastic),
            (F.to_grayscale, F.rgb_to_grayscale),
        ]
    ],
)
def test_alias(alias, target):
    assert alias is target


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("num_channels", [1, 3])
def test_normalize_image_tensor_stats(device, num_channels):
    stats = pytest.importorskip("scipy.stats", reason="SciPy is not available")

    def assert_samples_from_standard_normal(t):
        p_value = stats.kstest(t.flatten().cpu(), cdf="norm", args=(0, 1)).pvalue
        assert p_value > 1e-4

    image = torch.rand(num_channels, DEFAULT_SQUARE_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE, device=device)
    mean = image.mean(dim=(1, 2)).tolist()
    std = image.std(dim=(1, 2)).tolist()

    assert_samples_from_standard_normal(F.normalize_image(image, mean, std))
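

# Context for the statistical check above: normalizing with the per-channel sample mean and std
# standardizes the pixels to zero mean and unit variance, so the Kolmogorov-Smirnov test only has to
# flag gross deviations from N(0, 1); the very loose 1e-4 threshold keeps the check from being flaky.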


class TestClampBoundingBoxes:
    @pytest.mark.parametrize(
        "metadata",
        [
            dict(),
            dict(format=tv_tensors.BoundingBoxFormat.XYXY),
            dict(canvas_size=(1, 1)),
        ],
    )
    def test_pure_tensor_insufficient_metadata(self, metadata):
        pure_tensor = next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` has to be passed")):
            F.clamp_bounding_boxes(pure_tensor, **metadata)

    @pytest.mark.parametrize(
        "metadata",
        [
            dict(format=tv_tensors.BoundingBoxFormat.XYXY),
            dict(canvas_size=(1, 1)),
            dict(format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=(1, 1)),
        ],
    )
    def test_tv_tensor_explicit_metadata(self, metadata):
        tv_tensor = next(make_multiple_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` must not be passed")):
            F.clamp_bounding_boxes(tv_tensor, **metadata)


class TestConvertFormatBoundingBoxes:
    @pytest.mark.parametrize(
        ("inpt", "old_format"),
        [
            (next(make_multiple_bounding_boxes()), None),
            (next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor), tv_tensors.BoundingBoxFormat.XYXY),
        ],
    )
    def test_missing_new_format(self, inpt, old_format):
        with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
            F.convert_bounding_box_format(inpt, old_format)

    def test_pure_tensor_insufficient_metadata(self):
        pure_tensor = next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
            F.convert_bounding_box_format(pure_tensor, new_format=tv_tensors.BoundingBoxFormat.CXCYWH)

    def test_tv_tensor_explicit_metadata(self):
        tv_tensor = next(make_multiple_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
            F.convert_bounding_box_format(
                tv_tensor, old_format=tv_tensor.format, new_format=tv_tensors.BoundingBoxFormat.CXCYWH
            )


# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
# `transforms_v2_kernel_infos.py`


def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
    rot = math.radians(angle_)
    cx, cy = center_
    tx, ty = translate_
    sx, sy = [math.radians(sh_) for sh_ in shear_]

    c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
    t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
    c_matrix_inv = np.linalg.inv(c_matrix)
    rs_matrix = np.array(
        [
            [scale_ * math.cos(rot), -scale_ * math.sin(rot), 0],
            [scale_ * math.sin(rot), scale_ * math.cos(rot), 0],
            [0, 0, 1],
        ]
    )
    shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]])
    shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])
    rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix))
    true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv)))
    return true_matrix
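

# In matrix form, the helper above composes (a sketch of the convention used; C translates to the
# rotation center, T applies the output translation):
#
#   M = T @ C @ (RS @ ShearY @ ShearX) @ C^{-1}
#
# i.e. shift the rotation center to the origin via C^{-1}, apply shear, then rotation-with-scale,
# shift back via C, and finally translate by T.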


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
    mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
    mask[:, 0, :] = 1

    out_mask = F.vertical_flip_mask(mask)

    expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
    expected_mask[:, -1, :] = 1
    torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "format",
    [tv_tensors.BoundingBoxFormat.XYXY, tv_tensors.BoundingBoxFormat.XYWH, tv_tensors.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
    "top, left, height, width, size",
    [
        [0, 0, 30, 30, (60, 60)],
        [-5, 5, 35, 45, (32, 34)],
    ],
)
def test_correctness_resized_crop_bounding_boxes(device, format, top, left, height, width, size):
    def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
        # bbox should be xyxy
        bbox[0] = (bbox[0] - left_) * size_[1] / width_
        bbox[1] = (bbox[1] - top_) * size_[0] / height_
        bbox[2] = (bbox[2] - left_) * size_[1] / width_
        bbox[3] = (bbox[3] - top_) * size_[0] / height_
        return bbox

    canvas_size = (100, 100)
    in_boxes = [
        [10.0, 10.0, 20.0, 20.0],
        [5.0, 10.0, 15.0, 20.0],
    ]
    expected_bboxes = []
    for in_box in in_boxes:
        expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size))
    expected_bboxes = torch.tensor(expected_bboxes, device=device)

    in_boxes = tv_tensors.BoundingBoxes(
        in_boxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size, device=device
    )
    if format != tv_tensors.BoundingBoxFormat.XYXY:
        in_boxes = convert_bounding_box_format(in_boxes, tv_tensors.BoundingBoxFormat.XYXY, format)

    output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)

    if format != tv_tensors.BoundingBoxFormat.XYXY:
        output_boxes = convert_bounding_box_format(output_boxes, format, tv_tensors.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes, expected_bboxes)
    torch.testing.assert_close(output_canvas_size, size)


def _parse_padding(padding):
    if isinstance(padding, int):
        return [padding] * 4
    if isinstance(padding, list):
        if len(padding) == 1:
            return padding * 4
        if len(padding) == 2:
            return padding * 2  # [left, up, right, down]
    return padding
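

# For reference, the expansion implemented above, matching the [left, up, right, down] convention
# used in the tests below:
#
#   _parse_padding(1)             -> [1, 1, 1, 1]
#   _parse_padding([1])           -> [1, 1, 1, 1]
#   _parse_padding([1, 2])        -> [1, 2, 1, 2]
#   _parse_padding([1, 2, 3, 4])  -> [1, 2, 3, 4]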


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]])
def test_correctness_pad_bounding_boxes(device, padding):
    def _compute_expected_bbox(bbox, format, padding_):
        pad_left, pad_up, _, _ = _parse_padding(padding_)

        dtype = bbox.dtype
        bbox = (
            bbox.clone()
            if format == tv_tensors.BoundingBoxFormat.XYXY
            else convert_bounding_box_format(bbox, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
        )

        bbox[0::2] += pad_left
        bbox[1::2] += pad_up

        bbox = convert_bounding_box_format(bbox, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format)
        if bbox.dtype != dtype:
            # Cast back to the original dtype, e.g. float32 -> int
            bbox = bbox.to(dtype)
        return bbox

    def _compute_expected_canvas_size(bbox, padding_):
        pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
        height, width = bbox.canvas_size
        return height + pad_up + pad_down, width + pad_left + pad_right

    for bboxes in make_multiple_bounding_boxes(extra_dims=((4,),)):
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_canvas_size = bboxes.canvas_size

        output_boxes, output_canvas_size = F.pad_bounding_boxes(
            bboxes, format=bboxes_format, canvas_size=bboxes_canvas_size, padding=padding
        )

        torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding))

        expected_bboxes = torch.stack(
            [_compute_expected_bbox(b, bboxes_format, padding) for b in bboxes.reshape(-1, 4).unbind()]
        ).reshape(bboxes.shape)

        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
    mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)

    out_mask = F.pad_mask(mask, padding=[1, 1, 1, 1])

    expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
    expected_mask[:, 1:-1, 1:-1] = 1
    torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "startpoints, endpoints",
    [
        [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
    ],
)
def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
    def _compute_expected_bbox(bbox, format_, canvas_size_, pcoeffs_):
        m1 = np.array(
            [
                [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
                [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
            ]
        )
        m2 = np.array(
            [
                [pcoeffs_[6], pcoeffs_[7], 1.0],
                [pcoeffs_[6], pcoeffs_[7], 1.0],
            ]
        )

        bbox_xyxy = convert_bounding_box_format(
            bbox, old_format=format_, new_format=tv_tensors.BoundingBoxFormat.XYXY
        )
        points = np.array(
            [
                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
            ]
        )
        numer = np.matmul(points, m1.T)
        denom = np.matmul(points, m2.T)
        transformed_points = numer / denom
        out_bbox = np.array(
            [
                np.min(transformed_points[:, 0]),
                np.min(transformed_points[:, 1]),
                np.max(transformed_points[:, 0]),
                np.max(transformed_points[:, 1]),
            ]
        )
        out_bbox = torch.from_numpy(out_bbox)
        out_bbox = convert_bounding_box_format(
            out_bbox, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format_
        )
        return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox)

    canvas_size = (32, 38)

    pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
    inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)

    for bboxes in make_multiple_bounding_boxes(spatial_size=canvas_size, extra_dims=((4,),)):
        bboxes = bboxes.to(device)

        output_bboxes = F.perspective_bounding_boxes(
            bboxes.as_subclass(torch.Tensor),
            format=bboxes.format,
            canvas_size=bboxes.canvas_size,
            startpoints=None,
            endpoints=None,
            coefficients=pcoeffs,
        )

        expected_bboxes = torch.stack(
            [
                _compute_expected_bbox(b, bboxes.format, bboxes.canvas_size, inv_pcoeffs)
                for b in bboxes.reshape(-1, 4).unbind()
            ]
        ).reshape(bboxes.shape)

        torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "output_size",
    [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
)
def test_correctness_center_crop_bounding_boxes(device, output_size):
    def _compute_expected_bbox(bbox, format_, canvas_size_, output_size_):
        dtype = bbox.dtype
        bbox = convert_bounding_box_format(bbox.float(), format_, tv_tensors.BoundingBoxFormat.XYWH)

        if len(output_size_) == 1:
            output_size_.append(output_size_[-1])

        cy = int(round((canvas_size_[0] - output_size_[0]) * 0.5))
        cx = int(round((canvas_size_[1] - output_size_[1]) * 0.5))
        out_bbox = [
            bbox[0].item() - cx,
            bbox[1].item() - cy,
            bbox[2].item(),
            bbox[3].item(),
        ]
        out_bbox = torch.tensor(out_bbox)
        out_bbox = convert_bounding_box_format(out_bbox, tv_tensors.BoundingBoxFormat.XYWH, format_)
        out_bbox = clamp_bounding_boxes(out_bbox, format=format_, canvas_size=output_size_)
        return out_bbox.to(dtype=dtype, device=bbox.device)

    for bboxes in make_multiple_bounding_boxes(extra_dims=((4,),)):
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_canvas_size = bboxes.canvas_size

        output_boxes, output_canvas_size = F.center_crop_bounding_boxes(
            bboxes, bboxes_format, bboxes_canvas_size, output_size
        )

        expected_bboxes = torch.stack(
            [
                _compute_expected_bbox(b, bboxes_format, bboxes_canvas_size, output_size)
                for b in bboxes.reshape(-1, 4).unbind()
            ]
        ).reshape(bboxes.shape)

        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
        torch.testing.assert_close(output_canvas_size, output_size)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]])
def test_correctness_center_crop_mask(device, output_size):
    def _compute_expected_mask(mask, output_size):
        crop_height, crop_width = output_size if len(output_size) > 1 else [output_size[0], output_size[0]]

        _, image_height, image_width = mask.shape
        if crop_width > image_width or crop_height > image_height:
            padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
            mask = F.pad_image(mask, padding, fill=0)

        left = round((image_width - crop_width) * 0.5)
        top = round((image_height - crop_height) * 0.5)
        return mask[:, top : top + crop_height, left : left + crop_width]

    mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)
    actual = F.center_crop_mask(mask, output_size)

    expected = _compute_expected_mask(mask, output_size)
    torch.testing.assert_close(expected, actual)


# Copied from test/test_functional_tensor.py
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("canvas_size", ("small", "large"))
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize, sigma):
    fn = F.gaussian_blur_image

    # true_cv2_results = {
    #     # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
    #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
    #     "3_3_0.8": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
    #     "3_3_0.5": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
    #     "3_5_0.8": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
    #     "3_5_0.5": ...
    #     # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
    #     # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
    #     "23_23_1.7": ...
    # }
    p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt")
    true_cv2_results = torch.load(p)

    if canvas_size == "small":
        tensor = (
            torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device)
        )
    else:
        tensor = torch.from_numpy(np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))).to(device)

    if dt == torch.float16 and device == "cpu":
        # skip float16 on CPU case
        return

    if dt is not None:
        tensor = tensor.to(dtype=dt)

    _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
    _sigma = sigma[0] if sigma is not None else None
    shape = tensor.shape
    gt_key = f"{shape[-2]}_{shape[-1]}_{shape[-3]}__{_ksize[0]}_{_ksize[1]}_{_sigma}"
    if gt_key not in true_cv2_results:
        return

    true_out = (
        torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
    )

    image = tv_tensors.Image(tensor)

    out = fn(image, kernel_size=ksize, sigma=sigma)
    torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")
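

# Note on the lookup above: `gt_key` encodes the input geometry as well as the blur parameters, so for
# the "small" 10x12x3 input with ksize=(3, 3) and sigma 0.8 the key is "10_12_3__3_3_0.8". The short
# keys in the commented-out snippet ("3_3_0.8", ...) only sketch how the OpenCV reference file was
# generated, not the exact key format.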


@pytest.mark.parametrize(
    "inpt",
    [
        127 * np.ones((32, 32, 3), dtype="uint8"),
        PIL.Image.new("RGB", (32, 32), 122),
    ],
)
def test_to_image(inpt):
    output = F.to_image(inpt)
    assert isinstance(output, torch.Tensor)
    assert output.shape == (3, 32, 32)

    assert np.asarray(inpt).sum() == output.sum().item()


@pytest.mark.parametrize(
    "inpt",
    [
        torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
        127 * np.ones((32, 32, 3), dtype="uint8"),
    ],
)
@pytest.mark.parametrize("mode", [None, "RGB"])
def test_to_pil_image(inpt, mode):
    output = F.to_pil_image(inpt, mode=mode)
    assert isinstance(output, PIL.Image.Image)

    assert np.asarray(inpt).sum() == np.asarray(output).sum()


def test_equalize_image_tensor_edge_cases():
    inpt = torch.zeros(3, 200, 200, dtype=torch.uint8)
    output = F.equalize_image(inpt)
    torch.testing.assert_close(inpt, output)

    inpt = torch.zeros(5, 3, 200, 200, dtype=torch.uint8)
    inpt[..., 100:, 100:] = 1
    output = F.equalize_image(inpt)
    assert output.unique().tolist() == [0, 255]


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_uniform_temporal_subsample(device):
    video = torch.arange(10, device=device)[:, None, None, None].expand(-1, 3, 8, 8)
    out_video = F.uniform_temporal_subsample(video, 5)
    assert out_video.unique().tolist() == [0, 2, 4, 6, 9]

    out_video = F.uniform_temporal_subsample(video, 8)
    assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]
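

# The expected frame indices above follow from evenly spaced sampling over the temporal dimension,
# roughly (a sketch, not necessarily the exact implementation):
#
#   indices = torch.linspace(0, num_frames - 1, steps=num_samples).long()
#   # num_frames=10, num_samples=5 -> tensor([0, 2, 4, 6, 9])
#   # num_frames=10, num_samples=8 -> tensor([0, 1, 2, 3, 5, 6, 7, 9])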