# test_transforms_v2_consistency.py
  1. import importlib.machinery
  2. import importlib.util
  3. import inspect
  4. import random
  5. import re
  6. from pathlib import Path
  7. import numpy as np
  8. import PIL.Image
  9. import pytest
  10. import torch
  11. import torchvision.transforms.v2 as v2_transforms
  12. from common_utils import assert_close, assert_equal, set_rng_seed
  13. from torch import nn
  14. from torchvision import transforms as legacy_transforms, tv_tensors
  15. from torchvision._utils import sequence_to_str
  16. from torchvision.transforms import functional as legacy_F
  17. from torchvision.transforms.v2 import functional as prototype_F
  18. from torchvision.transforms.v2._utils import _get_fill, query_size
  19. from torchvision.transforms.v2.functional import to_pil_image
  20. from transforms_v2_legacy_utils import (
  21. ArgsKwargs,
  22. make_bounding_boxes,
  23. make_detection_mask,
  24. make_image,
  25. make_images,
  26. make_segmentation_mask,
  27. )
  28. DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])
@pytest.fixture(autouse=True)
def fix_rng_seed():
    """Autouse fixture that calls ``set_rng_seed(0)`` before handing control to each test."""
    set_rng_seed(0)
    yield
  33. class NotScriptableArgsKwargs(ArgsKwargs):
  34. """
  35. This class is used to mark parameters that render the transform non-scriptable. They still work in eager mode and
  36. thus will be tested there, but will be skipped by the JIT tests.
  37. """
  38. pass
  39. class ConsistencyConfig:
  40. def __init__(
  41. self,
  42. prototype_cls,
  43. legacy_cls,
  44. # If no args_kwargs is passed, only the signature will be checked
  45. args_kwargs=(),
  46. make_images_kwargs=None,
  47. supports_pil=True,
  48. removed_params=(),
  49. closeness_kwargs=None,
  50. ):
  51. self.prototype_cls = prototype_cls
  52. self.legacy_cls = legacy_cls
  53. self.args_kwargs = args_kwargs
  54. self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS
  55. self.supports_pil = supports_pil
  56. self.removed_params = removed_params
  57. self.closeness_kwargs = closeness_kwargs or dict(rtol=0, atol=0)
  58. # These are here since both the prototype and legacy transform need to be constructed with the same random parameters
  59. LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
  60. LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)
# Table of (v2 transform, legacy transform) pairs driving the parametrized tests in this file. Each entry lists the
# constructor arguments to instantiate both transforms with, plus optional overrides for the generated images, PIL
# support and comparison tolerances. Entries wrapped in `NotScriptableArgsKwargs` are excluded from the JIT tests.
CONSISTENCY_CONFIGS = [
    ConsistencyConfig(
        v2_transforms.Normalize,
        legacy_transforms.Normalize,
        [
            ArgsKwargs(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ],
        supports_pil=False,
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
    ),
    ConsistencyConfig(
        v2_transforms.CenterCrop,
        legacy_transforms.CenterCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
        ],
    ),
    ConsistencyConfig(
        v2_transforms.FiveCrop,
        legacy_transforms.FiveCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
    ),
    ConsistencyConfig(
        v2_transforms.TenCrop,
        legacy_transforms.TenCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
            ArgsKwargs(18, vertical_flip=True),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
    ),
    ConsistencyConfig(
        v2_transforms.Pad,
        legacy_transforms.Pad,
        [
            NotScriptableArgsKwargs(3),
            ArgsKwargs([3]),
            ArgsKwargs([2, 3]),
            ArgsKwargs([3, 2, 1, 4]),
            NotScriptableArgsKwargs(5, fill=1, padding_mode="constant"),
            ArgsKwargs([5], fill=1, padding_mode="constant"),
            NotScriptableArgsKwargs(5, padding_mode="edge"),
            NotScriptableArgsKwargs(5, padding_mode="reflect"),
            NotScriptableArgsKwargs(5, padding_mode="symmetric"),
        ],
    ),
    # One `LinearTransformation` config per (matrix dtype, image dtype) combination listed below.
    *[
        ConsistencyConfig(
            v2_transforms.LinearTransformation,
            legacy_transforms.LinearTransformation,
            [
                ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)),
            ],
            # Make sure that the product of the height, width and number of channels matches the number of elements in
            # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
            make_images_kwargs=dict(
                DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"], dtypes=[image_dtype]
            ),
            supports_pil=False,
        )
        for matrix_dtype, image_dtype in [
            (torch.float32, torch.float32),
            (torch.float64, torch.float64),
            (torch.float32, torch.uint8),
            (torch.float64, torch.float32),
            (torch.float32, torch.float64),
        ]
    ],
    ConsistencyConfig(
        v2_transforms.Grayscale,
        legacy_transforms.Grayscale,
        [
            ArgsKwargs(num_output_channels=1),
            ArgsKwargs(num_output_channels=3),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
        # Use default tolerances of `torch.testing.assert_close`
        closeness_kwargs=dict(rtol=None, atol=None),
    ),
    ConsistencyConfig(
        v2_transforms.ToPILImage,
        legacy_transforms.ToPILImage,
        [NotScriptableArgsKwargs()],
        make_images_kwargs=dict(
            color_spaces=[
                "GRAY",
                "GRAY_ALPHA",
                "RGB",
                "RGBA",
            ],
            extra_dims=[()],
        ),
        supports_pil=False,
    ),
    ConsistencyConfig(
        v2_transforms.Lambda,
        legacy_transforms.Lambda,
        [
            NotScriptableArgsKwargs(lambda image: image / 2),
        ],
        # Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
        # images given that the transform does nothing but call it anyway.
        supports_pil=False,
    ),
    ConsistencyConfig(
        v2_transforms.RandomEqualize,
        legacy_transforms.RandomEqualize,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
    ),
    ConsistencyConfig(
        v2_transforms.RandomInvert,
        legacy_transforms.RandomInvert,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
    ),
    ConsistencyConfig(
        v2_transforms.RandomPosterize,
        legacy_transforms.RandomPosterize,
        [
            ArgsKwargs(p=0, bits=5),
            ArgsKwargs(p=1, bits=1),
            ArgsKwargs(p=1, bits=3),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
    ),
    ConsistencyConfig(
        v2_transforms.RandomSolarize,
        legacy_transforms.RandomSolarize,
        [
            ArgsKwargs(p=0, threshold=0.5),
            ArgsKwargs(p=1, threshold=0.3),
            ArgsKwargs(p=1, threshold=0.99),
        ],
    ),
    # One `RandomAutocontrast` config per image dtype, each with its own comparison tolerances.
    *[
        ConsistencyConfig(
            v2_transforms.RandomAutocontrast,
            legacy_transforms.RandomAutocontrast,
            [
                ArgsKwargs(p=0),
                ArgsKwargs(p=1),
            ],
            make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[dt]),
            closeness_kwargs=ckw,
        )
        for dt, ckw in [(torch.uint8, dict(atol=1, rtol=0)), (torch.float32, dict(rtol=None, atol=None))]
    ],
    ConsistencyConfig(
        v2_transforms.RandomAdjustSharpness,
        legacy_transforms.RandomAdjustSharpness,
        [
            ArgsKwargs(p=0, sharpness_factor=0.5),
            ArgsKwargs(p=1, sharpness_factor=0.2),
            ArgsKwargs(p=1, sharpness_factor=0.99),
        ],
        closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
    ),
    ConsistencyConfig(
        v2_transforms.RandomGrayscale,
        legacy_transforms.RandomGrayscale,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
        # Use default tolerances of `torch.testing.assert_close`
        closeness_kwargs=dict(rtol=None, atol=None),
    ),
    # `RandomResizedCrop` is split over two configs because the bicubic cases below use a much larger tolerance.
    ConsistencyConfig(
        v2_transforms.RandomResizedCrop,
        legacy_transforms.RandomResizedCrop,
        [
            ArgsKwargs(16),
            ArgsKwargs(17, scale=(0.3, 0.7)),
            ArgsKwargs(25, ratio=(0.5, 1.5)),
            ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
            ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST),
            ArgsKwargs((29, 32), antialias=False),
            ArgsKwargs((28, 31), antialias=True),
        ],
        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        closeness_kwargs=dict(rtol=0, atol=1),
    ),
    ConsistencyConfig(
        v2_transforms.RandomResizedCrop,
        legacy_transforms.RandomResizedCrop,
        [
            ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC, antialias=True),
            ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC, antialias=True),
        ],
        closeness_kwargs=dict(rtol=0, atol=21),
    ),
    ConsistencyConfig(
        v2_transforms.ColorJitter,
        legacy_transforms.ColorJitter,
        [
            ArgsKwargs(),
            ArgsKwargs(brightness=0.1),
            ArgsKwargs(brightness=(0.2, 0.3)),
            ArgsKwargs(contrast=0.4),
            ArgsKwargs(contrast=(0.5, 0.6)),
            ArgsKwargs(saturation=0.7),
            ArgsKwargs(saturation=(0.8, 0.9)),
            ArgsKwargs(hue=0.3),
            ArgsKwargs(hue=(-0.1, 0.2)),
            ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.3),
        ],
        closeness_kwargs={"atol": 1e-5, "rtol": 1e-5},
    ),
    ConsistencyConfig(
        v2_transforms.GaussianBlur,
        legacy_transforms.GaussianBlur,
        [
            ArgsKwargs(kernel_size=3),
            ArgsKwargs(kernel_size=(1, 5)),
            ArgsKwargs(kernel_size=3, sigma=0.7),
            ArgsKwargs(kernel_size=5, sigma=(0.3, 1.4)),
        ],
        closeness_kwargs={"rtol": 1e-5, "atol": 1e-5},
    ),
    ConsistencyConfig(
        v2_transforms.RandomPerspective,
        legacy_transforms.RandomPerspective,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
            ArgsKwargs(p=1, distortion_scale=0.3),
            ArgsKwargs(p=1, distortion_scale=0.2, interpolation=v2_transforms.InterpolationMode.NEAREST),
            ArgsKwargs(p=1, distortion_scale=0.2, interpolation=PIL.Image.NEAREST),
            ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
            ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
        ],
        closeness_kwargs={"atol": None, "rtol": None},
    ),
    # The remaining configs pass no `args_kwargs`, so the parametrizations that enumerate `config.args_kwargs`
    # generate no call or JIT consistency cases for them.
    ConsistencyConfig(
        v2_transforms.PILToTensor,
        legacy_transforms.PILToTensor,
    ),
    ConsistencyConfig(
        v2_transforms.ToTensor,
        legacy_transforms.ToTensor,
    ),
    ConsistencyConfig(
        v2_transforms.Compose,
        legacy_transforms.Compose,
    ),
    ConsistencyConfig(
        v2_transforms.RandomApply,
        legacy_transforms.RandomApply,
    ),
    ConsistencyConfig(
        v2_transforms.RandomChoice,
        legacy_transforms.RandomChoice,
    ),
    ConsistencyConfig(
        v2_transforms.RandomOrder,
        legacy_transforms.RandomOrder,
    ),
    ConsistencyConfig(
        v2_transforms.AugMix,
        legacy_transforms.AugMix,
    ),
    ConsistencyConfig(
        v2_transforms.AutoAugment,
        legacy_transforms.AutoAugment,
    ),
    ConsistencyConfig(
        v2_transforms.RandAugment,
        legacy_transforms.RandAugment,
    ),
    ConsistencyConfig(
        v2_transforms.TrivialAugmentWide,
        legacy_transforms.TrivialAugmentWide,
    ),
]
  348. @pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
  349. def test_signature_consistency(config):
  350. legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
  351. prototype_params = dict(inspect.signature(config.prototype_cls).parameters)
  352. for param in config.removed_params:
  353. legacy_params.pop(param, None)
  354. missing = legacy_params.keys() - prototype_params.keys()
  355. if missing:
  356. raise AssertionError(
  357. f"The prototype transform does not support the parameters "
  358. f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
  359. f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
  360. f"the `ConsistencyConfig`."
  361. )
  362. extra = prototype_params.keys() - legacy_params.keys()
  363. extra_without_default = {
  364. param
  365. for param in extra
  366. if prototype_params[param].default is inspect.Parameter.empty
  367. and prototype_params[param].kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
  368. }
  369. if extra_without_default:
  370. raise AssertionError(
  371. f"The prototype transform requires the parameters "
  372. f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
  373. f"not. Please add a default value."
  374. )
  375. legacy_signature = list(legacy_params.keys())
  376. # Since we made sure that we don't have any extra parameters without default above, we clamp the prototype signature
  377. # to the same number of parameters as the legacy one
  378. prototype_signature = list(prototype_params.keys())[: len(legacy_signature)]
  379. assert prototype_signature == legacy_signature
  380. def check_call_consistency(
  381. prototype_transform, legacy_transform, images=None, supports_pil=True, closeness_kwargs=None
  382. ):
  383. if images is None:
  384. images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)
  385. closeness_kwargs = closeness_kwargs or dict()
  386. for image in images:
  387. image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"
  388. image_tensor = torch.Tensor(image)
  389. try:
  390. torch.manual_seed(0)
  391. output_legacy_tensor = legacy_transform(image_tensor)
  392. except Exception as exc:
  393. raise pytest.UsageError(
  394. f"Transforming a tensor image {image_repr} failed in the legacy transform with the "
  395. f"error above. This means that you need to specify the parameters passed to `make_images` through the "
  396. "`make_images_kwargs` of the `ConsistencyConfig`."
  397. ) from exc
  398. try:
  399. torch.manual_seed(0)
  400. output_prototype_tensor = prototype_transform(image_tensor)
  401. except Exception as exc:
  402. raise AssertionError(
  403. f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
  404. f"the error above. This means there is a consistency bug either in `_get_params` or in the "
  405. f"`is_pure_tensor` path in `_transform`."
  406. ) from exc
  407. assert_close(
  408. output_prototype_tensor,
  409. output_legacy_tensor,
  410. msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
  411. **closeness_kwargs,
  412. )
  413. try:
  414. torch.manual_seed(0)
  415. output_prototype_image = prototype_transform(image)
  416. except Exception as exc:
  417. raise AssertionError(
  418. f"Transforming a image tv_tensor with shape {image_repr} failed in the prototype transform with "
  419. f"the error above. This means there is a consistency bug either in `_get_params` or in the "
  420. f"`tv_tensors.Image` path in `_transform`."
  421. ) from exc
  422. assert_close(
  423. output_prototype_image,
  424. output_prototype_tensor,
  425. msg=lambda msg: f"Output for tv_tensor and tensor images is not equal: \n\n{msg}",
  426. **closeness_kwargs,
  427. )
  428. if image.ndim == 3 and supports_pil:
  429. image_pil = to_pil_image(image)
  430. try:
  431. torch.manual_seed(0)
  432. output_legacy_pil = legacy_transform(image_pil)
  433. except Exception as exc:
  434. raise pytest.UsageError(
  435. f"Transforming a PIL image with shape {image_repr} failed in the legacy transform with the "
  436. f"error above. If this transform does not support PIL images, set `supports_pil=False` on the "
  437. "`ConsistencyConfig`. "
  438. ) from exc
  439. try:
  440. torch.manual_seed(0)
  441. output_prototype_pil = prototype_transform(image_pil)
  442. except Exception as exc:
  443. raise AssertionError(
  444. f"Transforming a PIL image with shape {image_repr} failed in the prototype transform with "
  445. f"the error above. This means there is a consistency bug either in `_get_params` or in the "
  446. f"`PIL.Image.Image` path in `_transform`."
  447. ) from exc
  448. assert_close(
  449. output_prototype_pil,
  450. output_legacy_pil,
  451. msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
  452. **closeness_kwargs,
  453. )
  454. @pytest.mark.parametrize(
  455. ("config", "args_kwargs"),
  456. [
  457. pytest.param(
  458. config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
  459. )
  460. for config in CONSISTENCY_CONFIGS
  461. for idx, args_kwargs in enumerate(config.args_kwargs)
  462. ],
  463. )
  464. @pytest.mark.filterwarnings("ignore")
  465. def test_call_consistency(config, args_kwargs):
  466. args, kwargs = args_kwargs
  467. try:
  468. legacy_transform = config.legacy_cls(*args, **kwargs)
  469. except Exception as exc:
  470. raise pytest.UsageError(
  471. f"Initializing the legacy transform failed with the error above. "
  472. f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
  473. ) from exc
  474. try:
  475. prototype_transform = config.prototype_cls(*args, **kwargs)
  476. except Exception as exc:
  477. raise AssertionError(
  478. "Initializing the prototype transform failed with the error above. "
  479. "This means there is a consistency bug in the constructor."
  480. ) from exc
  481. check_call_consistency(
  482. prototype_transform,
  483. legacy_transform,
  484. images=make_images(**config.make_images_kwargs),
  485. supports_pil=config.supports_pil,
  486. closeness_kwargs=config.closeness_kwargs,
  487. )
# Shared parametrization for the `get_params` tests below. Every case pairs the first entry of `CONSISTENCY_CONFIGS`
# whose `prototype_cls` is the given v2 transform class with the `ArgsKwargs` that `get_params` is called with. The
# lookup runs at import time, so a transform class without a matching config makes the module fail to import.
get_params_parametrization = pytest.mark.parametrize(
    ("config", "get_params_args_kwargs"),
    [
        pytest.param(
            next(config for config in CONSISTENCY_CONFIGS if config.prototype_cls is transform_cls),
            get_params_args_kwargs,
            id=transform_cls.__name__,
        )
        for transform_cls, get_params_args_kwargs in [
            (v2_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
            (v2_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
            (v2_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
            (v2_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
            (v2_transforms.AutoAugment, ArgsKwargs(5)),
        ]
    ],
)
  505. @get_params_parametrization
  506. def test_get_params_alias(config, get_params_args_kwargs):
  507. assert config.prototype_cls.get_params is config.legacy_cls.get_params
  508. if not config.args_kwargs:
  509. return
  510. args, kwargs = config.args_kwargs[0]
  511. legacy_transform = config.legacy_cls(*args, **kwargs)
  512. prototype_transform = config.prototype_cls(*args, **kwargs)
  513. assert prototype_transform.get_params is legacy_transform.get_params
  514. @get_params_parametrization
  515. def test_get_params_jit(config, get_params_args_kwargs):
  516. get_params_args, get_params_kwargs = get_params_args_kwargs
  517. torch.jit.script(config.prototype_cls.get_params)(*get_params_args, **get_params_kwargs)
  518. if not config.args_kwargs:
  519. return
  520. args, kwargs = config.args_kwargs[0]
  521. transform = config.prototype_cls(*args, **kwargs)
  522. torch.jit.script(transform.get_params)(*get_params_args, **get_params_kwargs)
  523. @pytest.mark.parametrize(
  524. ("config", "args_kwargs"),
  525. [
  526. pytest.param(
  527. config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
  528. )
  529. for config in CONSISTENCY_CONFIGS
  530. for idx, args_kwargs in enumerate(config.args_kwargs)
  531. if not isinstance(args_kwargs, NotScriptableArgsKwargs)
  532. ],
  533. )
  534. def test_jit_consistency(config, args_kwargs):
  535. args, kwargs = args_kwargs
  536. prototype_transform_eager = config.prototype_cls(*args, **kwargs)
  537. legacy_transform_eager = config.legacy_cls(*args, **kwargs)
  538. legacy_transform_scripted = torch.jit.script(legacy_transform_eager)
  539. prototype_transform_scripted = torch.jit.script(prototype_transform_eager)
  540. for image in make_images(**config.make_images_kwargs):
  541. image = image.as_subclass(torch.Tensor)
  542. torch.manual_seed(0)
  543. output_legacy_scripted = legacy_transform_scripted(image)
  544. torch.manual_seed(0)
  545. output_prototype_scripted = prototype_transform_scripted(image)
  546. assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs)
  547. class TestContainerTransforms:
  548. """
  549. Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
  550. consistency automatically tests the wrapped transforms consistency.
  551. Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
  552. that were already tested for consistency above.
  553. """
  554. def test_compose(self):
  555. prototype_transform = v2_transforms.Compose(
  556. [
  557. v2_transforms.Resize(256),
  558. v2_transforms.CenterCrop(224),
  559. ]
  560. )
  561. legacy_transform = legacy_transforms.Compose(
  562. [
  563. legacy_transforms.Resize(256),
  564. legacy_transforms.CenterCrop(224),
  565. ]
  566. )
  567. # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
  568. check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
  569. @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
  570. @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
  571. def test_random_apply(self, p, sequence_type):
  572. prototype_transform = v2_transforms.RandomApply(
  573. sequence_type(
  574. [
  575. v2_transforms.Resize(256),
  576. v2_transforms.CenterCrop(224),
  577. ]
  578. ),
  579. p=p,
  580. )
  581. legacy_transform = legacy_transforms.RandomApply(
  582. sequence_type(
  583. [
  584. legacy_transforms.Resize(256),
  585. legacy_transforms.CenterCrop(224),
  586. ]
  587. ),
  588. p=p,
  589. )
  590. # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
  591. check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
  592. if sequence_type is nn.ModuleList:
  593. # quick and dirty test that it is jit-scriptable
  594. scripted = torch.jit.script(prototype_transform)
  595. scripted(torch.rand(1, 3, 300, 300))
  596. # We can't test other values for `p` since the random parameter generation is different
  597. @pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
  598. def test_random_choice(self, probabilities):
  599. prototype_transform = v2_transforms.RandomChoice(
  600. [
  601. v2_transforms.Resize(256),
  602. legacy_transforms.CenterCrop(224),
  603. ],
  604. p=probabilities,
  605. )
  606. legacy_transform = legacy_transforms.RandomChoice(
  607. [
  608. legacy_transforms.Resize(256),
  609. legacy_transforms.CenterCrop(224),
  610. ],
  611. p=probabilities,
  612. )
  613. # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
  614. check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
  615. class TestToTensorTransforms:
  616. def test_pil_to_tensor(self):
  617. prototype_transform = v2_transforms.PILToTensor()
  618. legacy_transform = legacy_transforms.PILToTensor()
  619. for image in make_images(extra_dims=[()]):
  620. image_pil = to_pil_image(image)
  621. assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
  622. def test_to_tensor(self):
  623. with pytest.warns(UserWarning, match=re.escape("The transform `ToTensor()` is deprecated")):
  624. prototype_transform = v2_transforms.ToTensor()
  625. legacy_transform = legacy_transforms.ToTensor()
  626. for image in make_images(extra_dims=[()]):
  627. image_pil = to_pil_image(image)
  628. image_numpy = np.array(image_pil)
  629. assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
  630. assert_equal(prototype_transform(image_numpy), legacy_transform(image_numpy))
  631. class TestAATransforms:
  632. @pytest.mark.parametrize(
  633. "inpt",
  634. [
  635. torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
  636. PIL.Image.new("RGB", (256, 256), 123),
  637. tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
  638. ],
  639. )
  640. @pytest.mark.parametrize(
  641. "interpolation",
  642. [
  643. v2_transforms.InterpolationMode.NEAREST,
  644. v2_transforms.InterpolationMode.BILINEAR,
  645. PIL.Image.NEAREST,
  646. ],
  647. )
  648. def test_randaug(self, inpt, interpolation, mocker):
  649. t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
  650. t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1)
  651. le = len(t._AUGMENTATION_SPACE)
  652. keys = list(t._AUGMENTATION_SPACE.keys())
  653. randint_values = []
  654. for i in range(le):
  655. # Stable API, op_index random call
  656. randint_values.append(i)
  657. # Stable API, if signed there is another random call
  658. if t._AUGMENTATION_SPACE[keys[i]][1]:
  659. randint_values.append(0)
  660. # New API, _get_random_item
  661. randint_values.append(i)
  662. randint_values = iter(randint_values)
  663. mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
  664. mocker.patch("torch.rand", return_value=1.0)
  665. for i in range(le):
  666. expected_output = t_ref(inpt)
  667. output = t(inpt)
  668. assert_close(expected_output, output, atol=1, rtol=0.1)
  669. @pytest.mark.parametrize(
  670. "interpolation",
  671. [
  672. v2_transforms.InterpolationMode.NEAREST,
  673. v2_transforms.InterpolationMode.BILINEAR,
  674. ],
  675. )
  676. @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
  677. def test_randaug_jit(self, interpolation, fill):
  678. inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
  679. t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill)
  680. t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill)
  681. tt_ref = torch.jit.script(t_ref)
  682. tt = torch.jit.script(t)
  683. torch.manual_seed(12)
  684. expected_output = tt_ref(inpt)
  685. torch.manual_seed(12)
  686. scripted_output = tt(inpt)
  687. assert_equal(scripted_output, expected_output)
  688. @pytest.mark.parametrize(
  689. "inpt",
  690. [
  691. torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
  692. PIL.Image.new("RGB", (256, 256), 123),
  693. tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
  694. ],
  695. )
  696. @pytest.mark.parametrize(
  697. "interpolation",
  698. [
  699. v2_transforms.InterpolationMode.NEAREST,
  700. v2_transforms.InterpolationMode.BILINEAR,
  701. PIL.Image.NEAREST,
  702. ],
  703. )
  704. def test_trivial_aug(self, inpt, interpolation, mocker):
  705. t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
  706. t = v2_transforms.TrivialAugmentWide(interpolation=interpolation)
  707. le = len(t._AUGMENTATION_SPACE)
  708. keys = list(t._AUGMENTATION_SPACE.keys())
  709. randint_values = []
  710. for i in range(le):
  711. # Stable API, op_index random call
  712. randint_values.append(i)
  713. key = keys[i]
  714. # Stable API, random magnitude
  715. aug_op = t._AUGMENTATION_SPACE[key]
  716. magnitudes = aug_op[0](2, 0, 0)
  717. if magnitudes is not None:
  718. randint_values.append(5)
  719. # Stable API, if signed there is another random call
  720. if aug_op[1]:
  721. randint_values.append(0)
  722. # New API, _get_random_item
  723. randint_values.append(i)
  724. # New API, random magnitude
  725. if magnitudes is not None:
  726. randint_values.append(5)
  727. randint_values = iter(randint_values)
  728. mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
  729. mocker.patch("torch.rand", return_value=1.0)
  730. for _ in range(le):
  731. expected_output = t_ref(inpt)
  732. output = t(inpt)
  733. assert_close(expected_output, output, atol=1, rtol=0.1)
  734. @pytest.mark.parametrize(
  735. "interpolation",
  736. [
  737. v2_transforms.InterpolationMode.NEAREST,
  738. v2_transforms.InterpolationMode.BILINEAR,
  739. ],
  740. )
  741. @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
  742. def test_trivial_aug_jit(self, interpolation, fill):
  743. inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
  744. t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill)
  745. t = v2_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill)
  746. tt_ref = torch.jit.script(t_ref)
  747. tt = torch.jit.script(t)
  748. torch.manual_seed(12)
  749. expected_output = tt_ref(inpt)
  750. torch.manual_seed(12)
  751. scripted_output = tt(inpt)
  752. assert_equal(scripted_output, expected_output)
  753. @pytest.mark.parametrize(
  754. "inpt",
  755. [
  756. torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
  757. PIL.Image.new("RGB", (256, 256), 123),
  758. tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
  759. ],
  760. )
  761. @pytest.mark.parametrize(
  762. "interpolation",
  763. [
  764. v2_transforms.InterpolationMode.NEAREST,
  765. v2_transforms.InterpolationMode.BILINEAR,
  766. PIL.Image.NEAREST,
  767. ],
  768. )
  769. def test_augmix(self, inpt, interpolation, mocker):
  770. t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
  771. t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1)
  772. t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
  773. t._sample_dirichlet = lambda t: t.softmax(dim=-1)
  774. le = len(t._AUGMENTATION_SPACE)
  775. keys = list(t._AUGMENTATION_SPACE.keys())
  776. randint_values = []
  777. for i in range(le):
  778. # Stable API, op_index random call
  779. randint_values.append(i)
  780. key = keys[i]
  781. # Stable API, random magnitude
  782. aug_op = t._AUGMENTATION_SPACE[key]
  783. magnitudes = aug_op[0](2, 0, 0)
  784. if magnitudes is not None:
  785. randint_values.append(5)
  786. # Stable API, if signed there is another random call
  787. if aug_op[1]:
  788. randint_values.append(0)
  789. # New API, _get_random_item
  790. randint_values.append(i)
  791. # New API, random magnitude
  792. if magnitudes is not None:
  793. randint_values.append(5)
  794. randint_values = iter(randint_values)
  795. mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
  796. mocker.patch("torch.rand", return_value=1.0)
  797. expected_output = t_ref(inpt)
  798. output = t(inpt)
  799. assert_equal(expected_output, output)
  800. @pytest.mark.parametrize(
  801. "interpolation",
  802. [
  803. v2_transforms.InterpolationMode.NEAREST,
  804. v2_transforms.InterpolationMode.BILINEAR,
  805. ],
  806. )
  807. @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
  808. def test_augmix_jit(self, interpolation, fill):
  809. inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
  810. t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill)
  811. t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill)
  812. tt_ref = torch.jit.script(t_ref)
  813. tt = torch.jit.script(t)
  814. torch.manual_seed(12)
  815. expected_output = tt_ref(inpt)
  816. torch.manual_seed(12)
  817. scripted_output = tt(inpt)
  818. assert_equal(scripted_output, expected_output)
  819. @pytest.mark.parametrize(
  820. "inpt",
  821. [
  822. torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
  823. PIL.Image.new("RGB", (256, 256), 123),
  824. tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
  825. ],
  826. )
  827. @pytest.mark.parametrize(
  828. "interpolation",
  829. [
  830. v2_transforms.InterpolationMode.NEAREST,
  831. v2_transforms.InterpolationMode.BILINEAR,
  832. PIL.Image.NEAREST,
  833. ],
  834. )
  835. def test_aa(self, inpt, interpolation):
  836. aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
  837. t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
  838. t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)
  839. torch.manual_seed(12)
  840. expected_output = t_ref(inpt)
  841. torch.manual_seed(12)
  842. output = t(inpt)
  843. assert_equal(expected_output, output)
  844. @pytest.mark.parametrize(
  845. "interpolation",
  846. [
  847. v2_transforms.InterpolationMode.NEAREST,
  848. v2_transforms.InterpolationMode.BILINEAR,
  849. ],
  850. )
  851. def test_aa_jit(self, interpolation):
  852. inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
  853. aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
  854. t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
  855. t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)
  856. tt_ref = torch.jit.script(t_ref)
  857. tt = torch.jit.script(t)
  858. torch.manual_seed(12)
  859. expected_output = tt_ref(inpt)
  860. torch.manual_seed(12)
  861. scripted_output = tt(inpt)
  862. assert_equal(scripted_output, expected_output)
  863. def import_transforms_from_references(reference):
  864. HERE = Path(__file__).parent
  865. PROJECT_ROOT = HERE.parent
  866. loader = importlib.machinery.SourceFileLoader(
  867. "transforms", str(PROJECT_ROOT / "references" / reference / "transforms.py")
  868. )
  869. spec = importlib.util.spec_from_loader("transforms", loader)
  870. module = importlib.util.module_from_spec(spec)
  871. loader.exec_module(module)
  872. return module
# Transforms module loaded from `references/detection/transforms.py`; used as the reference implementation below.
det_transforms = import_transforms_from_references("detection")
  874. class TestRefDetTransforms:
  875. def make_tv_tensors(self, with_mask=True):
  876. size = (600, 800)
  877. num_objects = 22
  878. def make_label(extra_dims, categories):
  879. return torch.randint(categories, extra_dims, dtype=torch.int64)
  880. pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
  881. target = {
  882. "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
  883. "labels": make_label(extra_dims=(num_objects,), categories=80),
  884. }
  885. if with_mask:
  886. target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
  887. yield (pil_image, target)
  888. tensor_image = torch.Tensor(make_image(size=size, color_space="RGB", dtype=torch.float32))
  889. target = {
  890. "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
  891. "labels": make_label(extra_dims=(num_objects,), categories=80),
  892. }
  893. if with_mask:
  894. target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
  895. yield (tensor_image, target)
  896. tv_tensor_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
  897. target = {
  898. "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
  899. "labels": make_label(extra_dims=(num_objects,), categories=80),
  900. }
  901. if with_mask:
  902. target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
  903. yield (tv_tensor_image, target)
  904. @pytest.mark.parametrize(
  905. "t_ref, t, data_kwargs",
  906. [
  907. (det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}),
  908. (
  909. det_transforms.RandomIoUCrop(),
  910. v2_transforms.Compose(
  911. [
  912. v2_transforms.RandomIoUCrop(),
  913. v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
  914. ]
  915. ),
  916. {"with_mask": False},
  917. ),
  918. (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
  919. (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024), antialias=True), {}),
  920. (
  921. det_transforms.RandomShortestSize(
  922. min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
  923. ),
  924. v2_transforms.RandomShortestSize(
  925. min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
  926. ),
  927. {},
  928. ),
  929. ],
  930. )
  931. def test_transform(self, t_ref, t, data_kwargs):
  932. for dp in self.make_tv_tensors(**data_kwargs):
  933. # We should use prototype transform first as reference transform performs inplace target update
  934. torch.manual_seed(12)
  935. output = t(dp)
  936. torch.manual_seed(12)
  937. expected_output = t_ref(*dp)
  938. assert_equal(expected_output, output)
# Transforms module loaded from `references/segmentation/transforms.py`; used as the reference implementation below.
seg_transforms = import_transforms_from_references("segmentation")
  940. # We need this transform for two reasons:
  941. # 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name
  942. # counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True`
  943. # 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
  944. class PadIfSmaller(v2_transforms.Transform):
  945. def __init__(self, size, fill=0):
  946. super().__init__()
  947. self.size = size
  948. self.fill = v2_transforms._geometry._setup_fill_arg(fill)
  949. def _get_params(self, sample):
  950. height, width = query_size(sample)
  951. padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
  952. needs_padding = any(padding)
  953. return dict(padding=padding, needs_padding=needs_padding)
  954. def _transform(self, inpt, params):
  955. if not params["needs_padding"]:
  956. return inpt
  957. fill = _get_fill(self.fill, type(inpt))
  958. return prototype_F.pad(inpt, padding=params["padding"], fill=fill)
  959. class TestRefSegTransforms:
  960. def make_tv_tensors(self, supports_pil=True, image_dtype=torch.uint8):
  961. size = (256, 460)
  962. num_categories = 21
  963. conv_fns = []
  964. if supports_pil:
  965. conv_fns.append(to_pil_image)
  966. conv_fns.extend([torch.Tensor, lambda x: x])
  967. for conv_fn in conv_fns:
  968. tv_tensor_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
  969. tv_tensor_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
  970. dp = (conv_fn(tv_tensor_image), tv_tensor_mask)
  971. dp_ref = (
  972. to_pil_image(tv_tensor_image) if supports_pil else tv_tensor_image.as_subclass(torch.Tensor),
  973. to_pil_image(tv_tensor_mask),
  974. )
  975. yield dp, dp_ref
  976. def set_seed(self, seed=12):
  977. torch.manual_seed(seed)
  978. random.seed(seed)
  979. def check(self, t, t_ref, data_kwargs=None):
  980. for dp, dp_ref in self.make_tv_tensors(**data_kwargs or dict()):
  981. self.set_seed()
  982. actual = actual_image, actual_mask = t(dp)
  983. self.set_seed()
  984. expected_image, expected_mask = t_ref(*dp_ref)
  985. if isinstance(actual_image, torch.Tensor) and not isinstance(expected_image, torch.Tensor):
  986. expected_image = legacy_F.pil_to_tensor(expected_image)
  987. expected_mask = legacy_F.pil_to_tensor(expected_mask).squeeze(0)
  988. expected = (expected_image, expected_mask)
  989. assert_equal(actual, expected)
  990. @pytest.mark.parametrize(
  991. ("t_ref", "t", "data_kwargs"),
  992. [
  993. (
  994. seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
  995. v2_transforms.RandomHorizontalFlip(p=1.0),
  996. dict(),
  997. ),
  998. (
  999. seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
  1000. v2_transforms.RandomHorizontalFlip(p=0.0),
  1001. dict(),
  1002. ),
  1003. (
  1004. seg_transforms.RandomCrop(size=480),
  1005. v2_transforms.Compose(
  1006. [
  1007. PadIfSmaller(size=480, fill={tv_tensors.Mask: 255, "others": 0}),
  1008. v2_transforms.RandomCrop(size=480),
  1009. ]
  1010. ),
  1011. dict(),
  1012. ),
  1013. (
  1014. seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
  1015. v2_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
  1016. dict(supports_pil=False, image_dtype=torch.float),
  1017. ),
  1018. ],
  1019. )
  1020. def test_common(self, t_ref, t, data_kwargs):
  1021. self.check(t, t_ref, data_kwargs)
  1022. @pytest.mark.parametrize(
  1023. ("legacy_dispatcher", "name_only_params"),
  1024. [
  1025. (legacy_F.get_dimensions, {}),
  1026. (legacy_F.get_image_size, {}),
  1027. (legacy_F.get_image_num_channels, {}),
  1028. (legacy_F.to_tensor, {}),
  1029. (legacy_F.pil_to_tensor, {}),
  1030. (legacy_F.convert_image_dtype, {}),
  1031. (legacy_F.to_pil_image, {}),
  1032. (legacy_F.normalize, {}),
  1033. (legacy_F.resize, {"interpolation"}),
  1034. (legacy_F.pad, {"padding", "fill"}),
  1035. (legacy_F.crop, {}),
  1036. (legacy_F.center_crop, {}),
  1037. (legacy_F.resized_crop, {"interpolation"}),
  1038. (legacy_F.hflip, {}),
  1039. (legacy_F.perspective, {"startpoints", "endpoints", "fill", "interpolation"}),
  1040. (legacy_F.vflip, {}),
  1041. (legacy_F.five_crop, {}),
  1042. (legacy_F.ten_crop, {}),
  1043. (legacy_F.adjust_brightness, {}),
  1044. (legacy_F.adjust_contrast, {}),
  1045. (legacy_F.adjust_saturation, {}),
  1046. (legacy_F.adjust_hue, {}),
  1047. (legacy_F.adjust_gamma, {}),
  1048. (legacy_F.rotate, {"center", "fill", "interpolation"}),
  1049. (legacy_F.affine, {"angle", "translate", "center", "fill", "interpolation"}),
  1050. (legacy_F.to_grayscale, {}),
  1051. (legacy_F.rgb_to_grayscale, {}),
  1052. (legacy_F.to_tensor, {}),
  1053. (legacy_F.erase, {}),
  1054. (legacy_F.gaussian_blur, {}),
  1055. (legacy_F.invert, {}),
  1056. (legacy_F.posterize, {}),
  1057. (legacy_F.solarize, {}),
  1058. (legacy_F.adjust_sharpness, {}),
  1059. (legacy_F.autocontrast, {}),
  1060. (legacy_F.equalize, {}),
  1061. (legacy_F.elastic_transform, {"fill", "interpolation"}),
  1062. ],
  1063. )
  1064. def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
  1065. legacy_signature = inspect.signature(legacy_dispatcher)
  1066. legacy_params = list(legacy_signature.parameters.values())[1:]
  1067. try:
  1068. prototype_dispatcher = getattr(prototype_F, legacy_dispatcher.__name__)
  1069. except AttributeError:
  1070. raise AssertionError(
  1071. f"Legacy dispatcher `F.{legacy_dispatcher.__name__}` has no prototype equivalent"
  1072. ) from None
  1073. prototype_signature = inspect.signature(prototype_dispatcher)
  1074. prototype_params = list(prototype_signature.parameters.values())[1:]
  1075. # Some dispatchers got extra parameters. This makes sure they have a default argument and thus are BC. We don't
  1076. # need to check if parameters were added in the middle rather than at the end, since that will be caught by the
  1077. # regular check below.
  1078. prototype_params, new_prototype_params = (
  1079. prototype_params[: len(legacy_params)],
  1080. prototype_params[len(legacy_params) :],
  1081. )
  1082. for param in new_prototype_params:
  1083. assert param.default is not param.empty
  1084. # Some annotations were changed mostly to supersets of what was there before. Plus, some legacy dispatchers had no
  1085. # annotations. In these cases we simply drop the annotation and default argument from the comparison
  1086. for prototype_param, legacy_param in zip(prototype_params, legacy_params):
  1087. if legacy_param.name in name_only_params:
  1088. prototype_param._annotation = prototype_param._default = inspect.Parameter.empty
  1089. legacy_param._annotation = legacy_param._default = inspect.Parameter.empty
  1090. elif legacy_param.annotation is inspect.Parameter.empty:
  1091. prototype_param._annotation = inspect.Parameter.empty
  1092. assert prototype_params == legacy_params