test_transforms_tensor.py 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909
  1. import os
  2. import sys
  3. import warnings
  4. import numpy as np
  5. import PIL.Image
  6. import pytest
  7. import torch
  8. from common_utils import (
  9. _assert_approx_equal_tensor_to_pil,
  10. _assert_equal_tensor_to_pil,
  11. _create_data,
  12. _create_data_batch,
  13. assert_equal,
  14. cpu_and_cuda,
  15. float_dtypes,
  16. get_tmp_dir,
  17. int_dtypes,
  18. )
  19. from torchvision import transforms as T
  20. from torchvision.transforms import functional as F, InterpolationMode
  21. from torchvision.transforms.autoaugment import _apply_op
  22. NEAREST, NEAREST_EXACT, BILINEAR, BICUBIC = (
  23. InterpolationMode.NEAREST,
  24. InterpolationMode.NEAREST_EXACT,
  25. InterpolationMode.BILINEAR,
  26. InterpolationMode.BICUBIC,
  27. )
  28. def _test_transform_vs_scripted(transform, s_transform, tensor, msg=None):
  29. torch.manual_seed(12)
  30. out1 = transform(tensor)
  31. torch.manual_seed(12)
  32. out2 = s_transform(tensor)
  33. assert_equal(out1, out2, msg=msg)
  34. def _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors, msg=None):
  35. torch.manual_seed(12)
  36. transformed_batch = transform(batch_tensors)
  37. for i in range(len(batch_tensors)):
  38. img_tensor = batch_tensors[i, ...]
  39. torch.manual_seed(12)
  40. transformed_img = transform(img_tensor)
  41. assert_equal(transformed_img, transformed_batch[i, ...], msg=msg)
  42. torch.manual_seed(12)
  43. s_transformed_batch = s_transform(batch_tensors)
  44. assert_equal(transformed_batch, s_transformed_batch, msg=msg)
  45. def _test_functional_op(f, device, channels=3, fn_kwargs=None, test_exact_match=True, **match_kwargs):
  46. fn_kwargs = fn_kwargs or {}
  47. tensor, pil_img = _create_data(height=10, width=10, channels=channels, device=device)
  48. transformed_tensor = f(tensor, **fn_kwargs)
  49. transformed_pil_img = f(pil_img, **fn_kwargs)
  50. if test_exact_match:
  51. _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
  52. else:
  53. _assert_approx_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
  54. def _test_class_op(transform_cls, device, channels=3, meth_kwargs=None, test_exact_match=True, **match_kwargs):
  55. meth_kwargs = meth_kwargs or {}
  56. # test for class interface
  57. f = transform_cls(**meth_kwargs)
  58. scripted_fn = torch.jit.script(f)
  59. tensor, pil_img = _create_data(26, 34, channels, device=device)
  60. # set seed to reproduce the same transformation for tensor and PIL image
  61. torch.manual_seed(12)
  62. transformed_tensor = f(tensor)
  63. torch.manual_seed(12)
  64. transformed_pil_img = f(pil_img)
  65. if test_exact_match:
  66. _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
  67. else:
  68. _assert_approx_equal_tensor_to_pil(transformed_tensor.float(), transformed_pil_img, **match_kwargs)
  69. torch.manual_seed(12)
  70. transformed_tensor_script = scripted_fn(tensor)
  71. assert_equal(transformed_tensor, transformed_tensor_script)
  72. batch_tensors = _create_data_batch(height=23, width=34, channels=channels, num_samples=4, device=device)
  73. _test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
  74. with get_tmp_dir() as tmp_dir:
  75. scripted_fn.save(os.path.join(tmp_dir, f"t_{transform_cls.__name__}.pt"))
  76. def _test_op(func, method, device, channels=3, fn_kwargs=None, meth_kwargs=None, test_exact_match=True, **match_kwargs):
  77. _test_functional_op(func, device, channels, fn_kwargs, test_exact_match=test_exact_match, **match_kwargs)
  78. _test_class_op(method, device, channels, meth_kwargs, test_exact_match=test_exact_match, **match_kwargs)
  79. def _test_fn_save_load(fn, tmpdir):
  80. scripted_fn = torch.jit.script(fn)
  81. p = os.path.join(tmpdir, f"t_op_list_{getattr(fn, '__name__', fn.__class__.__name__)}.pt")
  82. scripted_fn.save(p)
  83. _ = torch.jit.load(p)
  84. @pytest.mark.parametrize("device", cpu_and_cuda())
  85. @pytest.mark.parametrize(
  86. "func,method,fn_kwargs,match_kwargs",
  87. [
  88. (F.hflip, T.RandomHorizontalFlip, None, {}),
  89. (F.vflip, T.RandomVerticalFlip, None, {}),
  90. (F.invert, T.RandomInvert, None, {}),
  91. (F.posterize, T.RandomPosterize, {"bits": 4}, {}),
  92. (F.solarize, T.RandomSolarize, {"threshold": 192.0}, {}),
  93. (F.adjust_sharpness, T.RandomAdjustSharpness, {"sharpness_factor": 2.0}, {}),
  94. (
  95. F.autocontrast,
  96. T.RandomAutocontrast,
  97. None,
  98. {"test_exact_match": False, "agg_method": "max", "tol": (1 + 1e-5), "allowed_percentage_diff": 0.05},
  99. ),
  100. (F.equalize, T.RandomEqualize, None, {}),
  101. ],
  102. )
  103. @pytest.mark.parametrize("channels", [1, 3])
  104. def test_random(func, method, device, channels, fn_kwargs, match_kwargs):
  105. _test_op(func, method, device, channels, fn_kwargs, fn_kwargs, **match_kwargs)
  106. @pytest.mark.parametrize("seed", range(10))
  107. @pytest.mark.parametrize("device", cpu_and_cuda())
  108. @pytest.mark.parametrize("channels", [1, 3])
  109. class TestColorJitter:
  110. @pytest.fixture(autouse=True)
  111. def set_random_seed(self, seed):
  112. torch.random.manual_seed(seed)
  113. @pytest.mark.parametrize("brightness", [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]])
  114. def test_color_jitter_brightness(self, brightness, device, channels):
  115. tol = 1.0 + 1e-10
  116. meth_kwargs = {"brightness": brightness}
  117. _test_class_op(
  118. T.ColorJitter,
  119. meth_kwargs=meth_kwargs,
  120. test_exact_match=False,
  121. device=device,
  122. tol=tol,
  123. agg_method="max",
  124. channels=channels,
  125. )
  126. @pytest.mark.parametrize("contrast", [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]])
  127. def test_color_jitter_contrast(self, contrast, device, channels):
  128. tol = 1.0 + 1e-10
  129. meth_kwargs = {"contrast": contrast}
  130. _test_class_op(
  131. T.ColorJitter,
  132. meth_kwargs=meth_kwargs,
  133. test_exact_match=False,
  134. device=device,
  135. tol=tol,
  136. agg_method="max",
  137. channels=channels,
  138. )
  139. @pytest.mark.parametrize("saturation", [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]])
  140. def test_color_jitter_saturation(self, saturation, device, channels):
  141. tol = 1.0 + 1e-10
  142. meth_kwargs = {"saturation": saturation}
  143. _test_class_op(
  144. T.ColorJitter,
  145. meth_kwargs=meth_kwargs,
  146. test_exact_match=False,
  147. device=device,
  148. tol=tol,
  149. agg_method="max",
  150. channels=channels,
  151. )
  152. @pytest.mark.parametrize("hue", [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]])
  153. def test_color_jitter_hue(self, hue, device, channels):
  154. meth_kwargs = {"hue": hue}
  155. _test_class_op(
  156. T.ColorJitter,
  157. meth_kwargs=meth_kwargs,
  158. test_exact_match=False,
  159. device=device,
  160. tol=16.1,
  161. agg_method="max",
  162. channels=channels,
  163. )
  164. def test_color_jitter_all(self, device, channels):
  165. # All 4 parameters together
  166. meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
  167. _test_class_op(
  168. T.ColorJitter,
  169. meth_kwargs=meth_kwargs,
  170. test_exact_match=False,
  171. device=device,
  172. tol=12.1,
  173. agg_method="max",
  174. channels=channels,
  175. )
  176. @pytest.mark.parametrize("device", cpu_and_cuda())
  177. @pytest.mark.parametrize("m", ["constant", "edge", "reflect", "symmetric"])
  178. @pytest.mark.parametrize("mul", [1, -1])
  179. def test_pad(m, mul, device):
  180. fill = 127 if m == "constant" else 0
  181. # Test functional.pad (PIL and Tensor) with padding as single int
  182. _test_functional_op(F.pad, fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m}, device=device)
  183. # Test functional.pad and transforms.Pad with padding as [int, ]
  184. fn_kwargs = meth_kwargs = {
  185. "padding": [mul * 2],
  186. "fill": fill,
  187. "padding_mode": m,
  188. }
  189. _test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
  190. # Test functional.pad and transforms.Pad with padding as list
  191. fn_kwargs = meth_kwargs = {"padding": [mul * 4, 4], "fill": fill, "padding_mode": m}
  192. _test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
  193. # Test functional.pad and transforms.Pad with padding as tuple
  194. fn_kwargs = meth_kwargs = {"padding": (mul * 2, 2, 2, mul * 2), "fill": fill, "padding_mode": m}
  195. _test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
  196. @pytest.mark.parametrize("device", cpu_and_cuda())
  197. def test_crop(device):
  198. fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
  199. # Test transforms.RandomCrop with size and padding as tuple
  200. meth_kwargs = {
  201. "size": (4, 5),
  202. "padding": (4, 4),
  203. "pad_if_needed": True,
  204. }
  205. _test_op(F.crop, T.RandomCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
  206. # Test transforms.functional.crop including outside the image area
  207. fn_kwargs = {"top": -2, "left": 3, "height": 4, "width": 5} # top
  208. _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device)
  209. fn_kwargs = {"top": 1, "left": -3, "height": 4, "width": 5} # left
  210. _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device)
  211. fn_kwargs = {"top": 7, "left": 3, "height": 4, "width": 5} # bottom
  212. _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device)
  213. fn_kwargs = {"top": 3, "left": 8, "height": 4, "width": 5} # right
  214. _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device)
  215. fn_kwargs = {"top": -3, "left": -3, "height": 15, "width": 15} # all
  216. _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device)
  217. @pytest.mark.parametrize("device", cpu_and_cuda())
  218. @pytest.mark.parametrize(
  219. "padding_config",
  220. [
  221. {"padding_mode": "constant", "fill": 0},
  222. {"padding_mode": "constant", "fill": 10},
  223. {"padding_mode": "edge"},
  224. {"padding_mode": "reflect"},
  225. ],
  226. )
  227. @pytest.mark.parametrize("pad_if_needed", [True, False])
  228. @pytest.mark.parametrize("padding", [[5], [5, 4], [1, 2, 3, 4]])
  229. @pytest.mark.parametrize("size", [5, [5], [6, 6]])
  230. def test_random_crop(size, padding, pad_if_needed, padding_config, device):
  231. config = dict(padding_config)
  232. config["size"] = size
  233. config["padding"] = padding
  234. config["pad_if_needed"] = pad_if_needed
  235. _test_class_op(T.RandomCrop, device, meth_kwargs=config)
  236. def test_random_crop_save_load(tmpdir):
  237. fn = T.RandomCrop(32, [4], pad_if_needed=True)
  238. _test_fn_save_load(fn, tmpdir)
  239. @pytest.mark.parametrize("device", cpu_and_cuda())
  240. def test_center_crop(device, tmpdir):
  241. fn_kwargs = {"output_size": (4, 5)}
  242. meth_kwargs = {"size": (4, 5)}
  243. _test_op(F.center_crop, T.CenterCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
  244. fn_kwargs = {"output_size": (5,)}
  245. meth_kwargs = {"size": (5,)}
  246. _test_op(F.center_crop, T.CenterCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
  247. tensor = torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8, device=device)
  248. # Test torchscript of transforms.CenterCrop with size as int
  249. f = T.CenterCrop(size=5)
  250. scripted_fn = torch.jit.script(f)
  251. scripted_fn(tensor)
  252. # Test torchscript of transforms.CenterCrop with size as [int, ]
  253. f = T.CenterCrop(size=[5])
  254. scripted_fn = torch.jit.script(f)
  255. scripted_fn(tensor)
  256. # Test torchscript of transforms.CenterCrop with size as tuple
  257. f = T.CenterCrop(size=(6, 6))
  258. scripted_fn = torch.jit.script(f)
  259. scripted_fn(tensor)
  260. def test_center_crop_save_load(tmpdir):
  261. fn = T.CenterCrop(size=[5])
  262. _test_fn_save_load(fn, tmpdir)
  263. @pytest.mark.parametrize("device", cpu_and_cuda())
  264. @pytest.mark.parametrize(
  265. "fn, method, out_length",
  266. [
  267. # test_five_crop
  268. (F.five_crop, T.FiveCrop, 5),
  269. # test_ten_crop
  270. (F.ten_crop, T.TenCrop, 10),
  271. ],
  272. )
  273. @pytest.mark.parametrize("size", [(5,), [5], (4, 5), [4, 5]])
  274. def test_x_crop(fn, method, out_length, size, device):
  275. meth_kwargs = fn_kwargs = {"size": size}
  276. scripted_fn = torch.jit.script(fn)
  277. tensor, pil_img = _create_data(height=20, width=20, device=device)
  278. transformed_t_list = fn(tensor, **fn_kwargs)
  279. transformed_p_list = fn(pil_img, **fn_kwargs)
  280. assert len(transformed_t_list) == len(transformed_p_list)
  281. assert len(transformed_t_list) == out_length
  282. for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
  283. _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img)
  284. transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
  285. assert len(transformed_t_list) == len(transformed_t_list_script)
  286. assert len(transformed_t_list_script) == out_length
  287. for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
  288. assert_equal(transformed_tensor, transformed_tensor_script)
  289. # test for class interface
  290. fn = method(**meth_kwargs)
  291. scripted_fn = torch.jit.script(fn)
  292. output = scripted_fn(tensor)
  293. assert len(output) == len(transformed_t_list_script)
  294. # test on batch of tensors
  295. batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=device)
  296. torch.manual_seed(12)
  297. transformed_batch_list = fn(batch_tensors)
  298. for i in range(len(batch_tensors)):
  299. img_tensor = batch_tensors[i, ...]
  300. torch.manual_seed(12)
  301. transformed_img_list = fn(img_tensor)
  302. for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
  303. assert_equal(transformed_img, transformed_batch[i, ...])
  304. @pytest.mark.parametrize("method", ["FiveCrop", "TenCrop"])
  305. def test_x_crop_save_load(method, tmpdir):
  306. fn = getattr(T, method)(size=[5])
  307. _test_fn_save_load(fn, tmpdir)
  308. class TestResize:
  309. @pytest.mark.parametrize("size", [32, 34, 35, 36, 38])
  310. def test_resize_int(self, size):
  311. # TODO: Minimal check for bug-fix, improve this later
  312. x = torch.rand(3, 32, 46)
  313. t = T.Resize(size=size, antialias=True)
  314. y = t(x)
  315. # If size is an int, smaller edge of the image will be matched to this number.
  316. # i.e, if height > width, then image will be rescaled to (size * height / width, size).
  317. assert isinstance(y, torch.Tensor)
  318. assert y.shape[1] == size
  319. assert y.shape[2] == int(size * 46 / 32)
  320. @pytest.mark.parametrize("device", cpu_and_cuda())
  321. @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64])
  322. @pytest.mark.parametrize("size", [[32], [32, 32], (32, 32), [34, 35]])
  323. @pytest.mark.parametrize("max_size", [None, 35, 1000])
  324. @pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST, NEAREST_EXACT])
  325. def test_resize_scripted(self, dt, size, max_size, interpolation, device):
  326. tensor, _ = _create_data(height=34, width=36, device=device)
  327. batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
  328. if dt is not None:
  329. # This is a trivial cast to float of uint8 data to test all cases
  330. tensor = tensor.to(dt)
  331. if max_size is not None and len(size) != 1:
  332. pytest.skip("Size should be an int or a sequence of length 1 if max_size is specified")
  333. transform = T.Resize(size=size, interpolation=interpolation, max_size=max_size, antialias=True)
  334. s_transform = torch.jit.script(transform)
  335. _test_transform_vs_scripted(transform, s_transform, tensor)
  336. _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
  337. def test_resize_save_load(self, tmpdir):
  338. fn = T.Resize(size=[32], antialias=True)
  339. _test_fn_save_load(fn, tmpdir)
  340. @pytest.mark.parametrize("device", cpu_and_cuda())
  341. @pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]])
  342. @pytest.mark.parametrize("ratio", [(0.75, 1.333), [0.75, 1.333]])
  343. @pytest.mark.parametrize("size", [(32,), [44], [32], [32, 32], (32, 32), [44, 55]])
  344. @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC, NEAREST_EXACT])
  345. @pytest.mark.parametrize("antialias", [None, True, False])
  346. def test_resized_crop(self, scale, ratio, size, interpolation, antialias, device):
  347. if antialias and interpolation in {NEAREST, NEAREST_EXACT}:
  348. pytest.skip(f"Can not resize if interpolation mode is {interpolation} and antialias=True")
  349. tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
  350. batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
  351. transform = T.RandomResizedCrop(
  352. size=size, scale=scale, ratio=ratio, interpolation=interpolation, antialias=antialias
  353. )
  354. s_transform = torch.jit.script(transform)
  355. _test_transform_vs_scripted(transform, s_transform, tensor)
  356. _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
  357. def test_resized_crop_save_load(self, tmpdir):
  358. fn = T.RandomResizedCrop(size=[32], antialias=True)
  359. _test_fn_save_load(fn, tmpdir)
  360. def test_antialias_default_warning(self):
  361. img = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8)
  362. match = "The default value of the antialias"
  363. with pytest.warns(UserWarning, match=match):
  364. T.Resize((20, 20))(img)
  365. with pytest.warns(UserWarning, match=match):
  366. T.RandomResizedCrop((20, 20))(img)
  367. # For modes that aren't bicubic or bilinear, don't throw a warning
  368. with warnings.catch_warnings():
  369. warnings.simplefilter("error")
  370. T.Resize((20, 20), interpolation=NEAREST)(img)
  371. T.RandomResizedCrop((20, 20), interpolation=NEAREST)(img)
  372. def _test_random_affine_helper(device, **kwargs):
  373. tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
  374. batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
  375. transform = T.RandomAffine(**kwargs)
  376. s_transform = torch.jit.script(transform)
  377. _test_transform_vs_scripted(transform, s_transform, tensor)
  378. _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
  379. def test_random_affine_save_load(tmpdir):
  380. fn = T.RandomAffine(degrees=45.0)
  381. _test_fn_save_load(fn, tmpdir)
  382. @pytest.mark.parametrize("device", cpu_and_cuda())
  383. @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
  384. @pytest.mark.parametrize("shear", [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]])
  385. def test_random_affine_shear(device, interpolation, shear):
  386. _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, shear=shear)
  387. @pytest.mark.parametrize("device", cpu_and_cuda())
  388. @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
  389. @pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]])
  390. def test_random_affine_scale(device, interpolation, scale):
  391. _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, scale=scale)
  392. @pytest.mark.parametrize("device", cpu_and_cuda())
  393. @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
  394. @pytest.mark.parametrize("translate", [(0.1, 0.2), [0.2, 0.1]])
  395. def test_random_affine_translate(device, interpolation, translate):
  396. _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, translate=translate)
  397. @pytest.mark.parametrize("device", cpu_and_cuda())
  398. @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
  399. @pytest.mark.parametrize("degrees", [45, 35.0, (-45, 45), [-90.0, 90.0]])
  400. def test_random_affine_degrees(device, interpolation, degrees):
  401. _test_random_affine_helper(device, degrees=degrees, interpolation=interpolation)
  402. @pytest.mark.parametrize("device", cpu_and_cuda())
  403. @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
  404. @pytest.mark.parametrize("fill", [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
  405. def test_random_affine_fill(device, interpolation, fill):
  406. _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, fill=fill)
  407. @pytest.mark.parametrize("device", cpu_and_cuda())
  408. @pytest.mark.parametrize("center", [(0, 0), [10, 10], None, (56, 44)])
  409. @pytest.mark.parametrize("expand", [True, False])
  410. @pytest.mark.parametrize("degrees", [45, 35.0, (-45, 45), [-90.0, 90.0]])
  411. @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
  412. @pytest.mark.parametrize("fill", [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
  413. def test_random_rotate(device, center, expand, degrees, interpolation, fill):
  414. tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
  415. batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
  416. transform = T.RandomRotation(degrees=degrees, interpolation=interpolation, expand=expand, center=center, fill=fill)
  417. s_transform = torch.jit.script(transform)
  418. _test_transform_vs_scripted(transform, s_transform, tensor)
  419. _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
  420. def test_random_rotate_save_load(tmpdir):
  421. fn = T.RandomRotation(degrees=45.0)
  422. _test_fn_save_load(fn, tmpdir)
  423. @pytest.mark.parametrize("device", cpu_and_cuda())
  424. @pytest.mark.parametrize("distortion_scale", np.linspace(0.1, 1.0, num=20))
  425. @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
  426. @pytest.mark.parametrize("fill", [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
  427. def test_random_perspective(device, distortion_scale, interpolation, fill):
  428. tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
  429. batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
  430. transform = T.RandomPerspective(distortion_scale=distortion_scale, interpolation=interpolation, fill=fill)
  431. s_transform = torch.jit.script(transform)
  432. _test_transform_vs_scripted(transform, s_transform, tensor)
  433. _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
  434. def test_random_perspective_save_load(tmpdir):
  435. fn = T.RandomPerspective()
  436. _test_fn_save_load(fn, tmpdir)
  437. @pytest.mark.parametrize("device", cpu_and_cuda())
  438. @pytest.mark.parametrize(
  439. "Klass, meth_kwargs",
  440. [(T.Grayscale, {"num_output_channels": 1}), (T.Grayscale, {"num_output_channels": 3}), (T.RandomGrayscale, {})],
  441. )
  442. def test_to_grayscale(device, Klass, meth_kwargs):
  443. tol = 1.0 + 1e-10
  444. _test_class_op(Klass, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, tol=tol, agg_method="max")
  445. @pytest.mark.parametrize("device", cpu_and_cuda())
  446. @pytest.mark.parametrize("in_dtype", int_dtypes() + float_dtypes())
  447. @pytest.mark.parametrize("out_dtype", int_dtypes() + float_dtypes())
  448. def test_convert_image_dtype(device, in_dtype, out_dtype):
  449. tensor, _ = _create_data(26, 34, device=device)
  450. batch_tensors = torch.rand(4, 3, 44, 56, device=device)
  451. in_tensor = tensor.to(in_dtype)
  452. in_batch_tensors = batch_tensors.to(in_dtype)
  453. fn = T.ConvertImageDtype(dtype=out_dtype)
  454. scripted_fn = torch.jit.script(fn)
  455. if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or (
  456. in_dtype == torch.float64 and out_dtype == torch.int64
  457. ):
  458. with pytest.raises(RuntimeError, match=r"cannot be performed safely"):
  459. _test_transform_vs_scripted(fn, scripted_fn, in_tensor)
  460. with pytest.raises(RuntimeError, match=r"cannot be performed safely"):
  461. _test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
  462. return
  463. _test_transform_vs_scripted(fn, scripted_fn, in_tensor)
  464. _test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
  465. def test_convert_image_dtype_save_load(tmpdir):
  466. fn = T.ConvertImageDtype(dtype=torch.uint8)
  467. _test_fn_save_load(fn, tmpdir)
  468. @pytest.mark.parametrize("device", cpu_and_cuda())
  469. @pytest.mark.parametrize("policy", [policy for policy in T.AutoAugmentPolicy])
  470. @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
  471. def test_autoaugment(device, policy, fill):
  472. tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
  473. batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
  474. transform = T.AutoAugment(policy=policy, fill=fill)
  475. s_transform = torch.jit.script(transform)
  476. for _ in range(25):
  477. _test_transform_vs_scripted(transform, s_transform, tensor)
  478. _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
  479. @pytest.mark.parametrize("device", cpu_and_cuda())
  480. @pytest.mark.parametrize("num_ops", [1, 2, 3])
  481. @pytest.mark.parametrize("magnitude", [7, 9, 11])
  482. @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
  483. def test_randaugment(device, num_ops, magnitude, fill):
  484. tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
  485. batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
  486. transform = T.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
  487. s_transform = torch.jit.script(transform)
  488. for _ in range(25):
  489. _test_transform_vs_scripted(transform, s_transform, tensor)
  490. _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
  491. @pytest.mark.parametrize("device", cpu_and_cuda())
  492. @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
  493. def test_trivialaugmentwide(device, fill):
  494. tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
  495. batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
  496. transform = T.TrivialAugmentWide(fill=fill)
  497. s_transform = torch.jit.script(transform)
  498. for _ in range(25):
  499. _test_transform_vs_scripted(transform, s_transform, tensor)
  500. _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
  501. @pytest.mark.parametrize("device", cpu_and_cuda())
  502. @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
  503. def test_augmix(device, fill):
  504. tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
  505. batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
  506. class DeterministicAugMix(T.AugMix):
  507. def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
  508. # patch the method to ensure that the order of rand calls doesn't affect the outcome
  509. return params.softmax(dim=-1)
  510. transform = DeterministicAugMix(fill=fill)
  511. s_transform = torch.jit.script(transform)
  512. for _ in range(25):
  513. _test_transform_vs_scripted(transform, s_transform, tensor)
  514. _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
  515. @pytest.mark.parametrize("augmentation", [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide, T.AugMix])
  516. def test_autoaugment_save_load(augmentation, tmpdir):
  517. fn = augmentation()
  518. _test_fn_save_load(fn, tmpdir)
  519. @pytest.mark.parametrize("interpolation", [F.InterpolationMode.NEAREST, F.InterpolationMode.BILINEAR])
  520. @pytest.mark.parametrize("mode", ["X", "Y"])
  521. def test_autoaugment__op_apply_shear(interpolation, mode):
  522. # We check that torchvision's implementation of shear is equivalent
  523. # to official CIFAR10 autoaugment implementation:
  524. # https://github.com/tensorflow/models/blob/885fda091c46c59d6c7bb5c7e760935eacc229da/research/autoaugment/augmentation_transforms.py#L273-L290
  525. image_size = 32
  526. def shear(pil_img, level, mode, resample):
  527. if mode == "X":
  528. matrix = (1, level, 0, 0, 1, 0)
  529. elif mode == "Y":
  530. matrix = (1, 0, 0, level, 1, 0)
  531. return pil_img.transform((image_size, image_size), PIL.Image.AFFINE, matrix, resample=resample)
  532. t_img, pil_img = _create_data(image_size, image_size)
  533. resample_pil = {
  534. F.InterpolationMode.NEAREST: PIL.Image.NEAREST,
  535. F.InterpolationMode.BILINEAR: PIL.Image.BILINEAR,
  536. }[interpolation]
  537. level = 0.3
  538. expected_out = shear(pil_img, level, mode=mode, resample=resample_pil)
  539. # Check pil output vs expected pil
  540. out = _apply_op(pil_img, op_name=f"Shear{mode}", magnitude=level, interpolation=interpolation, fill=0)
  541. assert out == expected_out
  542. if interpolation == F.InterpolationMode.BILINEAR:
  543. # We skip bilinear mode for tensors as
  544. # affine transformation results are not exactly the same
  545. # between tensors and pil images
  546. # MAE as around 1.40
  547. # Max Abs error can be 163 or 170
  548. return
  549. # Check tensor output vs expected pil
  550. out = _apply_op(t_img, op_name=f"Shear{mode}", magnitude=level, interpolation=interpolation, fill=0)
  551. _assert_approx_equal_tensor_to_pil(out, expected_out)
  552. @pytest.mark.parametrize("device", cpu_and_cuda())
  553. @pytest.mark.parametrize(
  554. "config",
  555. [
  556. {},
  557. {"value": 1},
  558. {"value": 0.2},
  559. {"value": "random"},
  560. {"value": (1, 1, 1)},
  561. {"value": (0.2, 0.2, 0.2)},
  562. {"value": [1, 1, 1]},
  563. {"value": [0.2, 0.2, 0.2]},
  564. {"value": "random", "ratio": (0.1, 0.2)},
  565. ],
  566. )
  567. def test_random_erasing(device, config):
  568. tensor, _ = _create_data(24, 32, channels=3, device=device)
  569. batch_tensors = torch.rand(4, 3, 44, 56, device=device)
  570. fn = T.RandomErasing(**config)
  571. scripted_fn = torch.jit.script(fn)
  572. _test_transform_vs_scripted(fn, scripted_fn, tensor)
  573. _test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
  574. def test_random_erasing_save_load(tmpdir):
  575. fn = T.RandomErasing(value=0.2)
  576. _test_fn_save_load(fn, tmpdir)
  577. def test_random_erasing_with_invalid_data():
  578. img = torch.rand(3, 60, 60)
  579. # Test Set 0: invalid value
  580. random_erasing = T.RandomErasing(value=(0.1, 0.2, 0.3, 0.4), p=1.0)
  581. with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value or 3"):
  582. random_erasing(img)
  583. @pytest.mark.parametrize("device", cpu_and_cuda())
  584. def test_normalize(device, tmpdir):
  585. fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  586. tensor, _ = _create_data(26, 34, device=device)
  587. with pytest.raises(TypeError, match="Input tensor should be a float tensor"):
  588. fn(tensor)
  589. batch_tensors = torch.rand(4, 3, 44, 56, device=device)
  590. tensor = tensor.to(dtype=torch.float32) / 255.0
  591. # test for class interface
  592. scripted_fn = torch.jit.script(fn)
  593. _test_transform_vs_scripted(fn, scripted_fn, tensor)
  594. _test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
  595. scripted_fn.save(os.path.join(tmpdir, "t_norm.pt"))
  596. @pytest.mark.parametrize("device", cpu_and_cuda())
  597. def test_linear_transformation(device, tmpdir):
  598. c, h, w = 3, 24, 32
  599. tensor, _ = _create_data(h, w, channels=c, device=device)
  600. matrix = torch.rand(c * h * w, c * h * w, device=device)
  601. mean_vector = torch.rand(c * h * w, device=device)
  602. fn = T.LinearTransformation(matrix, mean_vector)
  603. scripted_fn = torch.jit.script(fn)
  604. _test_transform_vs_scripted(fn, scripted_fn, tensor)
  605. batch_tensors = torch.rand(4, c, h, w, device=device)
  606. # We skip some tests from _test_transform_vs_scripted_on_batch as
  607. # results for scripted and non-scripted transformations are not exactly the same
  608. torch.manual_seed(12)
  609. transformed_batch = fn(batch_tensors)
  610. torch.manual_seed(12)
  611. s_transformed_batch = scripted_fn(batch_tensors)
  612. assert_equal(transformed_batch, s_transformed_batch)
  613. scripted_fn.save(os.path.join(tmpdir, "t_norm.pt"))
  614. @pytest.mark.parametrize("device", cpu_and_cuda())
  615. def test_compose(device):
  616. tensor, _ = _create_data(26, 34, device=device)
  617. tensor = tensor.to(dtype=torch.float32) / 255.0
  618. transforms = T.Compose(
  619. [
  620. T.CenterCrop(10),
  621. T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  622. ]
  623. )
  624. s_transforms = torch.nn.Sequential(*transforms.transforms)
  625. scripted_fn = torch.jit.script(s_transforms)
  626. torch.manual_seed(12)
  627. transformed_tensor = transforms(tensor)
  628. torch.manual_seed(12)
  629. transformed_tensor_script = scripted_fn(tensor)
  630. assert_equal(transformed_tensor, transformed_tensor_script, msg=f"{transforms}")
  631. t = T.Compose(
  632. [
  633. lambda x: x,
  634. ]
  635. )
  636. with pytest.raises(RuntimeError, match="cannot call a value of type 'Tensor'"):
  637. torch.jit.script(t)
  638. @pytest.mark.parametrize("device", cpu_and_cuda())
  639. def test_random_apply(device):
  640. tensor, _ = _create_data(26, 34, device=device)
  641. tensor = tensor.to(dtype=torch.float32) / 255.0
  642. transforms = T.RandomApply(
  643. [
  644. T.RandomHorizontalFlip(),
  645. T.ColorJitter(),
  646. ],
  647. p=0.4,
  648. )
  649. s_transforms = T.RandomApply(
  650. torch.nn.ModuleList(
  651. [
  652. T.RandomHorizontalFlip(),
  653. T.ColorJitter(),
  654. ]
  655. ),
  656. p=0.4,
  657. )
  658. scripted_fn = torch.jit.script(s_transforms)
  659. torch.manual_seed(12)
  660. transformed_tensor = transforms(tensor)
  661. torch.manual_seed(12)
  662. transformed_tensor_script = scripted_fn(tensor)
  663. assert_equal(transformed_tensor, transformed_tensor_script, msg=f"{transforms}")
  664. if device == "cpu":
  665. # Can't check this twice, otherwise
  666. # "Can't redefine method: forward on class: __torch__.torchvision.transforms.transforms.RandomApply"
  667. transforms = T.RandomApply(
  668. [
  669. T.ColorJitter(),
  670. ],
  671. p=0.3,
  672. )
  673. with pytest.raises(RuntimeError, match="Module 'RandomApply' has no attribute 'transforms'"):
  674. torch.jit.script(transforms)
  675. @pytest.mark.parametrize("device", cpu_and_cuda())
  676. @pytest.mark.parametrize(
  677. "meth_kwargs",
  678. [
  679. {"kernel_size": 3, "sigma": 0.75},
  680. {"kernel_size": 23, "sigma": [0.1, 2.0]},
  681. {"kernel_size": 23, "sigma": (0.1, 2.0)},
  682. {"kernel_size": [3, 3], "sigma": (1.0, 1.0)},
  683. {"kernel_size": (3, 3), "sigma": (0.1, 2.0)},
  684. {"kernel_size": [23], "sigma": 0.75},
  685. ],
  686. )
  687. @pytest.mark.parametrize("channels", [1, 3])
  688. def test_gaussian_blur(device, channels, meth_kwargs):
  689. if all(
  690. [
  691. device == "cuda",
  692. channels == 1,
  693. meth_kwargs["kernel_size"] in [23, [23]],
  694. torch.version.cuda == "11.3",
  695. sys.platform in ("win32", "cygwin"),
  696. ]
  697. ):
  698. pytest.skip("Fails on Windows, see https://github.com/pytorch/vision/issues/5464")
  699. tol = 1.0 + 1e-10
  700. torch.manual_seed(12)
  701. _test_class_op(
  702. T.GaussianBlur,
  703. meth_kwargs=meth_kwargs,
  704. channels=channels,
  705. test_exact_match=False,
  706. device=device,
  707. agg_method="max",
  708. tol=tol,
  709. )
  710. @pytest.mark.parametrize("device", cpu_and_cuda())
  711. @pytest.mark.parametrize(
  712. "fill",
  713. [
  714. 1,
  715. 1.0,
  716. [1],
  717. [1.0],
  718. (1,),
  719. (1.0,),
  720. [1, 2, 3],
  721. [1.0, 2.0, 3.0],
  722. (1, 2, 3),
  723. (1.0, 2.0, 3.0),
  724. ],
  725. )
  726. @pytest.mark.parametrize("channels", [1, 3])
  727. def test_elastic_transform(device, channels, fill):
  728. if isinstance(fill, (list, tuple)) and len(fill) > 1 and channels == 1:
  729. # For this the test would correctly fail, since the number of channels in the image does not match `fill`.
  730. # Thus, this is not an issue in the transform, but rather a problem of parametrization that just gives the
  731. # product of `fill` and `channels`.
  732. return
  733. _test_class_op(
  734. T.ElasticTransform,
  735. meth_kwargs=dict(fill=fill),
  736. channels=channels,
  737. device=device,
  738. )