import math
import os
import random
import re
import textwrap
import warnings
from functools import partial

import numpy as np
import pytest
import torch
import torchvision.transforms as transforms
import torchvision.transforms._functional_tensor as F_t
import torchvision.transforms.functional as F
from PIL import Image
from torch._utils_internal import get_file_path_2

try:
    import accimage
except ImportError:
    accimage = None

try:
    from scipy import stats
except ImportError:
    stats = None

from common_utils import assert_equal, assert_run_python_script, cycle_over, float_dtypes, int_dtypes

GRACE_HOPPER = get_file_path_2(
    os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
)


def _get_grayscale_test_image(img, fill=None):
    img = img.convert("L")
    fill = (fill[0],) if isinstance(fill, tuple) else fill
    return img, fill


class TestConvertImageDtype:
    @pytest.mark.parametrize("input_dtype, output_dtype", cycle_over(float_dtypes()))
    def test_float_to_float(self, input_dtype, output_dtype):
        input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
        transform = transforms.ConvertImageDtype(output_dtype)
        transform_script = torch.jit.script(F.convert_image_dtype)

        output_image = transform(input_image)
        output_image_script = transform_script(input_image, output_dtype)

        torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6)

        actual_min, actual_max = output_image.tolist()
        desired_min, desired_max = 0.0, 1.0

        assert abs(actual_min - desired_min) < 1e-7
        assert abs(actual_max - desired_max) < 1e-7

    @pytest.mark.parametrize("input_dtype", float_dtypes())
    @pytest.mark.parametrize("output_dtype", int_dtypes())
    def test_float_to_int(self, input_dtype, output_dtype):
        input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
        transform = transforms.ConvertImageDtype(output_dtype)
        transform_script = torch.jit.script(F.convert_image_dtype)

        if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or (
            input_dtype == torch.float64 and output_dtype == torch.int64
        ):
            with pytest.raises(RuntimeError):
                transform(input_image)
        else:
            output_image = transform(input_image)
            output_image_script = transform_script(input_image, output_dtype)

            torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6)

            actual_min, actual_max = output_image.tolist()
            desired_min, desired_max = 0, torch.iinfo(output_dtype).max

            assert actual_min == desired_min
            assert actual_max == desired_max

    @pytest.mark.parametrize("input_dtype", int_dtypes())
    @pytest.mark.parametrize("output_dtype", float_dtypes())
    def test_int_to_float(self, input_dtype, output_dtype):
        input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype)
        transform = transforms.ConvertImageDtype(output_dtype)
        transform_script = torch.jit.script(F.convert_image_dtype)

        output_image = transform(input_image)
        output_image_script = transform_script(input_image, output_dtype)

        torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6)

        actual_min, actual_max = output_image.tolist()
        desired_min, desired_max = 0.0, 1.0

        assert abs(actual_min - desired_min) < 1e-7
        assert actual_min >= desired_min
        assert abs(actual_max - desired_max) < 1e-7
        assert actual_max <= desired_max

    @pytest.mark.parametrize("input_dtype, output_dtype", cycle_over(int_dtypes()))
    def test_dtype_int_to_int(self, input_dtype, output_dtype):
        input_max = torch.iinfo(input_dtype).max
        input_image = torch.tensor((0, input_max), dtype=input_dtype)
        output_max = torch.iinfo(output_dtype).max

        transform = transforms.ConvertImageDtype(output_dtype)
        transform_script = torch.jit.script(F.convert_image_dtype)

        output_image = transform(input_image)
        output_image_script = transform_script(input_image, output_dtype)

        torch.testing.assert_close(
            output_image_script,
            output_image,
            rtol=0.0,
            atol=1e-6,
            msg=f"{output_image_script} vs {output_image}",
        )

        actual_min, actual_max = output_image.tolist()
        desired_min, desired_max = 0, output_max

        # see https://github.com/pytorch/vision/pull/2078#issuecomment-641036236 for details
        if input_max >= output_max:
            error_term = 0
        else:
            error_term = 1 - (torch.iinfo(output_dtype).max + 1) // (torch.iinfo(input_dtype).max + 1)

        assert actual_min == desired_min
        assert actual_max == (desired_max + error_term)

    @pytest.mark.parametrize("input_dtype, output_dtype", cycle_over(int_dtypes()))
    def test_int_to_int_consistency(self, input_dtype, output_dtype):
        input_max = torch.iinfo(input_dtype).max
        input_image = torch.tensor((0, input_max), dtype=input_dtype)
        output_max = torch.iinfo(output_dtype).max
        if output_max <= input_max:
            return

        transform = transforms.ConvertImageDtype(output_dtype)
        inverse_transform = transforms.ConvertImageDtype(input_dtype)
        output_image = inverse_transform(transform(input_image))

        actual_min, actual_max = output_image.tolist()
        desired_min, desired_max = 0, input_max

        assert actual_min == desired_min
        assert actual_max == desired_max
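

# Illustrative sketch (the helper name below is ours, not part of the suite): the
# conversions exercised above boil down to rescaling between each dtype's nominal
# range, e.g. float [0.0, 1.0] <-> uint8 [0, 255].
def _convert_image_dtype_demo():
    assert F.convert_image_dtype(torch.tensor([0.0, 1.0]), torch.uint8).tolist() == [0, 255]
    assert F.convert_image_dtype(torch.tensor([0, 255], dtype=torch.uint8), torch.float32).tolist() == [0.0, 1.0]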


@pytest.mark.skipif(accimage is None, reason="accimage not available")
class TestAccImage:
    def test_accimage_to_tensor(self):
        trans = transforms.PILToTensor()

        expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB"))
        output = trans(accimage.Image(GRACE_HOPPER))

        torch.testing.assert_close(output, expected_output)

    def test_accimage_pil_to_tensor(self):
        trans = transforms.PILToTensor()

        expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB"))
        output = trans(accimage.Image(GRACE_HOPPER))

        assert expected_output.size() == output.size()
        torch.testing.assert_close(output, expected_output)

    def test_accimage_resize(self):
        trans = transforms.Compose(
            [
                transforms.Resize(256, interpolation=Image.BILINEAR),
                transforms.PILToTensor(),
                transforms.ConvertImageDtype(dtype=torch.float),
            ]
        )

        # Checking if Compose, Resize, PILToTensor and ConvertImageDtype can be printed as string
        trans.__repr__()

        expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB"))
        output = trans(accimage.Image(GRACE_HOPPER))

        assert expected_output.size() == output.size()
        assert np.abs((expected_output - output).mean()) < 1e-3
        assert (expected_output - output).var() < 1e-5
        # note the high absolute tolerance
        torch.testing.assert_close(output.numpy(), expected_output.numpy(), rtol=1e-5, atol=5e-2)

    def test_accimage_crop(self):
        trans = transforms.Compose(
            [transforms.CenterCrop(256), transforms.PILToTensor(), transforms.ConvertImageDtype(dtype=torch.float)]
        )

        # Checking if Compose, CenterCrop, PILToTensor and ConvertImageDtype can be printed as string
        trans.__repr__()

        expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB"))
        output = trans(accimage.Image(GRACE_HOPPER))

        assert expected_output.size() == output.size()
        torch.testing.assert_close(output, expected_output)


class TestToTensor:
    @pytest.mark.parametrize("channels", [1, 3, 4])
    def test_to_tensor(self, channels):
        height, width = 4, 4
        trans = transforms.ToTensor()
        np_rng = np.random.RandomState(0)

        input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255)
        img = transforms.ToPILImage()(input_data)
        output = trans(img)
        torch.testing.assert_close(output, input_data)

        ndarray = np_rng.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
        output = trans(ndarray)
        expected_output = ndarray.transpose((2, 0, 1)) / 255.0
        torch.testing.assert_close(output.numpy(), expected_output, check_dtype=False)

        ndarray = np_rng.rand(height, width, channels).astype(np.float32)
        output = trans(ndarray)
        expected_output = ndarray.transpose((2, 0, 1))
        torch.testing.assert_close(output.numpy(), expected_output, check_dtype=False)

        # separate test for mode '1' PIL images
        input_data = torch.ByteTensor(1, height, width).bernoulli_()
        img = transforms.ToPILImage()(input_data.mul(255)).convert("1")
        output = trans(img)
        torch.testing.assert_close(input_data, output, check_dtype=False)

    def test_to_tensor_errors(self):
        height, width = 4, 4
        trans = transforms.ToTensor()
        np_rng = np.random.RandomState(0)

        with pytest.raises(TypeError):
            trans(np_rng.rand(1, height, width).tolist())

        with pytest.raises(ValueError):
            trans(np_rng.rand(height))

        with pytest.raises(ValueError):
            trans(np_rng.rand(1, 1, height, width))

    @pytest.mark.parametrize("dtype", [torch.float16, torch.float, torch.double])
    def test_to_tensor_with_other_default_dtypes(self, dtype):
        np_rng = np.random.RandomState(0)
        current_def_dtype = torch.get_default_dtype()

        t = transforms.ToTensor()
        np_arr = np_rng.randint(0, 255, (32, 32, 3), dtype=np.uint8)
        img = Image.fromarray(np_arr)

        torch.set_default_dtype(dtype)
        res = t(img)
        assert res.dtype == dtype, f"{res.dtype} vs {dtype}"

        torch.set_default_dtype(current_def_dtype)

    @pytest.mark.parametrize("channels", [1, 3, 4])
    def test_pil_to_tensor(self, channels):
        height, width = 4, 4
        trans = transforms.PILToTensor()
        np_rng = np.random.RandomState(0)

        input_data = torch.ByteTensor(channels, height, width).random_(0, 255)
        img = transforms.ToPILImage()(input_data)
        output = trans(img)
        torch.testing.assert_close(input_data, output)

        input_data = np_rng.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
        img = transforms.ToPILImage()(input_data)
        output = trans(img)
        expected_output = input_data.transpose((2, 0, 1))
        torch.testing.assert_close(output.numpy(), expected_output)

        input_data = torch.as_tensor(np_rng.rand(channels, height, width).astype(np.float32))
        img = transforms.ToPILImage()(input_data)  # CHW -> HWC and (* 255).byte()
        output = trans(img)  # HWC -> CHW
        expected_output = (input_data * 255).byte()
        torch.testing.assert_close(output, expected_output)

        # separate test for mode '1' PIL images
        input_data = torch.ByteTensor(1, height, width).bernoulli_()
        img = transforms.ToPILImage()(input_data.mul(255)).convert("1")
        output = trans(img).view(torch.uint8).bool().to(torch.uint8)
        torch.testing.assert_close(input_data, output)

    def test_pil_to_tensor_errors(self):
        height, width = 4, 4
        trans = transforms.PILToTensor()
        np_rng = np.random.RandomState(0)

        with pytest.raises(TypeError):
            trans(np_rng.rand(1, height, width).tolist())

        with pytest.raises(TypeError):
            trans(np_rng.rand(1, height, width))
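

# Illustrative sketch (helper name is ours): the split exercised above -- ToTensor
# scales HWC uint8 input to a float CHW tensor in [0, 1], while PILToTensor keeps
# the raw uint8 values and only rearranges the layout.
def _to_tensor_vs_pil_to_tensor_demo():
    img = Image.new("RGB", (2, 2), color=(255, 0, 0))
    assert transforms.ToTensor()(img).dtype == torch.float32 and transforms.ToTensor()(img).max() == 1.0
    assert transforms.PILToTensor()(img).dtype == torch.uint8 and transforms.PILToTensor()(img).max() == 255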


def test_randomresized_params():
    height = random.randint(24, 32) * 2
    width = random.randint(24, 32) * 2
    img = torch.ones(3, height, width)
    to_pil_image = transforms.ToPILImage()
    img = to_pil_image(img)
    size = 100
    epsilon = 0.05
    min_scale = 0.25
    for _ in range(10):
        scale_min = max(round(random.random(), 2), min_scale)
        scale_range = (scale_min, scale_min + round(random.random(), 2))
        aspect_min = max(round(random.random(), 2), epsilon)
        aspect_ratio_range = (aspect_min, aspect_min + round(random.random(), 2))
        randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range, antialias=True)
        i, j, h, w = randresizecrop.get_params(img, scale_range, aspect_ratio_range)
        aspect_ratio_obtained = w / h
        assert (
            min(aspect_ratio_range) - epsilon <= aspect_ratio_obtained
            and aspect_ratio_obtained <= max(aspect_ratio_range) + epsilon
        ) or aspect_ratio_obtained == 1.0
        assert isinstance(i, int)
        assert isinstance(j, int)
        assert isinstance(h, int)
        assert isinstance(w, int)
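

# Illustrative sketch (helper name is ours): get_params returns the crop box as
# (top, left, height, width); a degenerate scale/ratio makes the box size exact.
def _randomresizedcrop_params_demo():
    img = transforms.ToPILImage()(torch.ones(3, 64, 64))
    i, j, h, w = transforms.RandomResizedCrop.get_params(img, scale=(0.5, 0.5), ratio=(1.0, 1.0))
    assert (h, w) == (45, 45)  # round(sqrt(0.5 * 64 * 64)) == 45
    assert 0 <= i <= 64 - h and 0 <= j <= 64 - w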


@pytest.mark.parametrize(
    "height, width",
    [
        # height, width
        # square image
        (28, 28),
        (27, 27),
        # rectangular image: h < w
        (28, 34),
        (29, 35),
        # rectangular image: h > w
        (34, 28),
        (35, 29),
    ],
)
@pytest.mark.parametrize(
    "osize",
    [
        # single integer
        22,
        27,
        28,
        36,
        # single integer in tuple/list
        [
            22,
        ],
        (27,),
    ],
)
@pytest.mark.parametrize("max_size", (None, 37, 1000))
def test_resize(height, width, osize, max_size):
    img = Image.new("RGB", size=(width, height), color=127)

    t = transforms.Resize(osize, max_size=max_size, antialias=True)
    result = t(img)

    msg = f"{height}, {width} - {osize} - {max_size}"
    osize = osize[0] if isinstance(osize, (list, tuple)) else osize
    # If size is an int, the smaller edge of the image will be matched to this number.
    # i.e., if height > width, then the image will be rescaled to (size * height / width, size).
    if height < width:
        exp_w, exp_h = (int(osize * width / height), osize)  # (w, h)
        if max_size is not None and max_size < exp_w:
            exp_w, exp_h = max_size, int(max_size * exp_h / exp_w)
        assert result.size == (exp_w, exp_h), msg
    elif width < height:
        exp_w, exp_h = (osize, int(osize * height / width))  # (w, h)
        if max_size is not None and max_size < exp_h:
            exp_w, exp_h = int(max_size * exp_w / exp_h), max_size
        assert result.size == (exp_w, exp_h), msg
    else:
        exp_w, exp_h = (osize, osize)  # (w, h)
        if max_size is not None and max_size < osize:
            exp_w, exp_h = max_size, max_size
        assert result.size == (exp_w, exp_h), msg
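

# Worked instance of the smaller-edge rule checked above (illustrative sketch,
# helper name is ours): a 28x34 (HxW) image resized with size=22 maps the short
# edge to 22 and the long edge to int(22 * 34 / 28) == 26.
def _resize_small_edge_demo():
    img = Image.new("RGB", size=(34, 28))  # PIL size is (W, H)
    assert transforms.Resize(22, antialias=True)(img).size == (26, 22)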


@pytest.mark.parametrize(
    "height, width",
    [
        # height, width
        # square image
        (28, 28),
        (27, 27),
        # rectangular image: h < w
        (28, 34),
        (29, 35),
        # rectangular image: h > w
        (34, 28),
        (35, 29),
    ],
)
@pytest.mark.parametrize(
    "osize",
    [
        # two integers sequence output
        [22, 22],
        [22, 28],
        [22, 36],
        [27, 22],
        [36, 22],
        [28, 28],
        [28, 37],
        [37, 27],
        [37, 37],
    ],
)
def test_resize_sequence_output(height, width, osize):
    img = Image.new("RGB", size=(width, height), color=127)
    oheight, owidth = osize

    t = transforms.Resize(osize, antialias=True)
    result = t(img)

    assert (owidth, oheight) == result.size


def test_resize_antialias_error():
    osize = [37, 37]
    img = Image.new("RGB", size=(35, 29), color=127)

    with pytest.warns(UserWarning, match=r"Anti-alias option is always applied for PIL Image input"):
        t = transforms.Resize(osize, antialias=False)
        t(img)


def test_resize_antialias_default_warning():
    img = Image.new("RGB", size=(10, 10), color=127)
    # We make sure we don't warn for PIL images since the default behaviour doesn't change
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        transforms.Resize((20, 20))(img)
        transforms.RandomResizedCrop((20, 20))(img)


@pytest.mark.parametrize("height, width", ((32, 64), (64, 32)))
def test_resize_size_equals_small_edge_size(height, width):
    # Non-regression test for https://github.com/pytorch/vision/issues/5405
    # max_size used to be ignored if size == small_edge_size
    max_size = 40
    img = Image.new("RGB", size=(width, height), color=127)

    small_edge = min(height, width)
    t = transforms.Resize(small_edge, max_size=max_size, antialias=True)
    result = t(img)
    assert max(result.size) == max_size


def test_resize_equal_input_output_sizes():
    # Regression test for https://github.com/pytorch/vision/issues/7518
    height, width = 28, 27
    img = Image.new("RGB", size=(width, height))

    t = transforms.Resize((height, width), antialias=True)
    result = t(img)
    assert result is img


class TestPad:
    @pytest.mark.parametrize("fill", [85, 85.0])
    def test_pad(self, fill):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        img = torch.ones(3, height, width, dtype=torch.uint8)
        padding = random.randint(1, 20)
        result = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Pad(padding, fill=fill),
                transforms.PILToTensor(),
            ]
        )(img)
        assert result.size(1) == height + 2 * padding
        assert result.size(2) == width + 2 * padding
        # check that all elements in the padded region correspond
        # to the pad value
        h_padded = result[:, :padding, :]
        w_padded = result[:, :, :padding]
        torch.testing.assert_close(h_padded, torch.full_like(h_padded, fill_value=fill), rtol=0.0, atol=0.0)
        torch.testing.assert_close(w_padded, torch.full_like(w_padded, fill_value=fill), rtol=0.0, atol=0.0)
        pytest.raises(ValueError, transforms.Pad(padding, fill=(1, 2)), transforms.ToPILImage()(img))

    def test_pad_with_tuple_of_pad_values(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        img = transforms.ToPILImage()(torch.ones(3, height, width))

        padding = tuple(random.randint(1, 20) for _ in range(2))
        output = transforms.Pad(padding)(img)
        assert output.size == (width + padding[0] * 2, height + padding[1] * 2)

        padding = [random.randint(1, 20) for _ in range(4)]
        output = transforms.Pad(padding)(img)
        assert output.size[0] == width + padding[0] + padding[2]
        assert output.size[1] == height + padding[1] + padding[3]

        # Checking if Padding can be printed as string
        transforms.Pad(padding).__repr__()

    def test_pad_with_non_constant_padding_modes(self):
        """Unit tests for edge, reflect, symmetric padding"""
        img = torch.zeros(3, 27, 27).byte()
        img[:, :, 0] = 1  # Constant value added to leftmost edge
        img = transforms.ToPILImage()(img)
        img = F.pad(img, 1, (200, 200, 200))

        # Pad 3 to all sides
        edge_padded_img = F.pad(img, 3, padding_mode="edge")
        # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # edge_pad, edge_pad, edge_pad, constant_pad, constant value added to leftmost edge, 0
        edge_middle_slice = np.asarray(edge_padded_img).transpose(2, 0, 1)[0][17][:6]
        assert_equal(edge_middle_slice, np.asarray([200, 200, 200, 200, 1, 0], dtype=np.uint8))
        assert transforms.PILToTensor()(edge_padded_img).size() == (3, 35, 35)

        # Pad 3 to left/right, 2 to top/bottom
        reflect_padded_img = F.pad(img, (3, 2), padding_mode="reflect")
        # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # reflect_pad, reflect_pad, reflect_pad, constant_pad, constant value added to leftmost edge, 0
        reflect_middle_slice = np.asarray(reflect_padded_img).transpose(2, 0, 1)[0][17][:6]
        assert_equal(reflect_middle_slice, np.asarray([0, 0, 1, 200, 1, 0], dtype=np.uint8))
        assert transforms.PILToTensor()(reflect_padded_img).size() == (3, 33, 35)

        # Pad 3 to left, 2 to top, 2 to right, 1 to bottom
        symmetric_padded_img = F.pad(img, (3, 2, 2, 1), padding_mode="symmetric")
        # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # sym_pad, sym_pad, sym_pad, constant_pad, constant value added to leftmost edge, 0
        symmetric_middle_slice = np.asarray(symmetric_padded_img).transpose(2, 0, 1)[0][17][:6]
        assert_equal(symmetric_middle_slice, np.asarray([0, 1, 200, 200, 1, 0], dtype=np.uint8))
        assert transforms.PILToTensor()(symmetric_padded_img).size() == (3, 32, 34)

        # Check negative padding explicitly for symmetric case, since it is not
        # implemented for tensor case to compare to
        # Crop 1 to left, pad 2 to top, pad 3 to right, crop 3 to bottom
        symmetric_padded_img_neg = F.pad(img, (-1, 2, 3, -3), padding_mode="symmetric")
        symmetric_neg_middle_left = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][:3]
        symmetric_neg_middle_right = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][-4:]
        assert_equal(symmetric_neg_middle_left, np.asarray([1, 0, 0], dtype=np.uint8))
        assert_equal(symmetric_neg_middle_right, np.asarray([200, 200, 0, 0], dtype=np.uint8))
        assert transforms.PILToTensor()(symmetric_padded_img_neg).size() == (3, 28, 31)

    def test_pad_raises_with_invalid_pad_sequence_len(self):
        with pytest.raises(ValueError):
            transforms.Pad(())

        with pytest.raises(ValueError):
            transforms.Pad((1, 2, 3))

        with pytest.raises(ValueError):
            transforms.Pad((1, 2, 3, 4, 5))

    def test_pad_with_mode_F_images(self):
        pad = 2
        transform = transforms.Pad(pad)

        img = Image.new("F", (10, 10))
        padded_img = transform(img)
        assert_equal(padded_img.size, [edge_size + 2 * pad for edge_size in img.size])
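

# Padding-mode cheat sheet for the assertions above (illustrative sketch, helper
# name is ours): for the row [1, 2, 3] padded by 2 on the left and right, reflect
# mirrors around the edge pixel while symmetric repeats it.
def _pad_modes_demo():
    row = torch.tensor([[[1, 2, 3]]])  # (C, H, W) = (1, 1, 3)
    assert F.pad(row, [2, 0], padding_mode="reflect")[0, 0].tolist() == [3, 2, 1, 2, 3, 2, 1]
    assert F.pad(row, [2, 0], padding_mode="symmetric")[0, 0].tolist() == [2, 1, 1, 2, 3, 3, 2]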


@pytest.mark.parametrize(
    "fn, trans, kwargs",
    [
        (F.invert, transforms.RandomInvert, {}),
        (F.posterize, transforms.RandomPosterize, {"bits": 4}),
        (F.solarize, transforms.RandomSolarize, {"threshold": 192}),
        (F.adjust_sharpness, transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}),
        (F.autocontrast, transforms.RandomAutocontrast, {}),
        (F.equalize, transforms.RandomEqualize, {}),
        (F.vflip, transforms.RandomVerticalFlip, {}),
        (F.hflip, transforms.RandomHorizontalFlip, {}),
        (partial(F.to_grayscale, num_output_channels=3), transforms.RandomGrayscale, {}),
    ],
)
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("p", (0, 1))
def test_randomness(fn, trans, kwargs, seed, p):
    torch.manual_seed(seed)
    img = transforms.ToPILImage()(torch.rand(3, 16, 18))

    expected_transformed_img = fn(img, **kwargs)
    randomly_transformed_img = trans(p=p, **kwargs)(img)

    if p == 0:
        assert randomly_transformed_img == img
    elif p == 1:
        assert randomly_transformed_img == expected_transformed_img

    trans(**kwargs).__repr__()
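

# The p=0 / p=1 trick above makes the stochastic transforms deterministic, so each
# Random* wrapper can be compared against its functional counterpart (illustrative
# sketch, helper name is ours):
def _random_p_demo():
    img = transforms.ToPILImage()(torch.rand(3, 8, 8))
    assert transforms.RandomHorizontalFlip(p=1.0)(img) == F.hflip(img)
    assert transforms.RandomHorizontalFlip(p=0.0)(img) == img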


def test_autocontrast_equal_minmax():
    img_tensor = torch.tensor([[[10]], [[128]], [[245]]], dtype=torch.uint8).expand(3, 32, 32)
    img_pil = F.to_pil_image(img_tensor)

    img_tensor = F.autocontrast(img_tensor)
    img_pil = F.autocontrast(img_pil)
    torch.testing.assert_close(img_tensor, F.pil_to_tensor(img_pil))


class TestToPil:
    def _get_1_channel_tensor_various_types():
        img_data_float = torch.Tensor(1, 4, 4).uniform_()
        expected_output = img_data_float.mul(255).int().float().div(255).numpy()
        yield img_data_float, expected_output, "L"

        img_data_byte = torch.ByteTensor(1, 4, 4).random_(0, 255)
        expected_output = img_data_byte.float().div(255.0).numpy()
        yield img_data_byte, expected_output, "L"

        img_data_short = torch.ShortTensor(1, 4, 4).random_()
        expected_output = img_data_short.numpy()
        yield img_data_short, expected_output, "I;16"

        img_data_int = torch.IntTensor(1, 4, 4).random_()
        expected_output = img_data_int.numpy()
        yield img_data_int, expected_output, "I"

    def _get_2d_tensor_various_types():
        img_data_float = torch.Tensor(4, 4).uniform_()
        expected_output = img_data_float.mul(255).int().float().div(255).numpy()
        yield img_data_float, expected_output, "L"

        img_data_byte = torch.ByteTensor(4, 4).random_(0, 255)
        expected_output = img_data_byte.float().div(255.0).numpy()
        yield img_data_byte, expected_output, "L"

        img_data_short = torch.ShortTensor(4, 4).random_()
        expected_output = img_data_short.numpy()
        yield img_data_short, expected_output, "I;16"

        img_data_int = torch.IntTensor(4, 4).random_()
        expected_output = img_data_int.numpy()
        yield img_data_int, expected_output, "I"

    @pytest.mark.parametrize("with_mode", [False, True])
    @pytest.mark.parametrize("img_data, expected_output, expected_mode", _get_1_channel_tensor_various_types())
    def test_1_channel_tensor_to_pil_image(self, with_mode, img_data, expected_output, expected_mode):
        transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
        to_tensor = transforms.ToTensor()

        img = transform(img_data)
        assert img.mode == expected_mode
        torch.testing.assert_close(expected_output, to_tensor(img).numpy())

    def test_1_channel_float_tensor_to_pil_image(self):
        img_data = torch.Tensor(1, 4, 4).uniform_()
        # 'F' mode for torch.FloatTensor
        img_F_mode = transforms.ToPILImage(mode="F")(img_data)
        assert img_F_mode.mode == "F"
        torch.testing.assert_close(
            np.array(Image.fromarray(img_data.squeeze(0).numpy(), mode="F")), np.array(img_F_mode)
        )

    @pytest.mark.parametrize("with_mode", [False, True])
    @pytest.mark.parametrize(
        "img_data, expected_mode",
        [
            (torch.Tensor(4, 4, 1).uniform_().numpy(), "F"),
            (torch.ByteTensor(4, 4, 1).random_(0, 255).numpy(), "L"),
            (torch.ShortTensor(4, 4, 1).random_().numpy(), "I;16"),
            (torch.IntTensor(4, 4, 1).random_().numpy(), "I"),
        ],
    )
    def test_1_channel_ndarray_to_pil_image(self, with_mode, img_data, expected_mode):
        transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
        img = transform(img_data)
        assert img.mode == expected_mode
        # note: we explicitly convert img's dtype because pytorch doesn't support uint16
        # and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array
        torch.testing.assert_close(img_data[:, :, 0], np.asarray(img).astype(img_data.dtype))

    @pytest.mark.parametrize("expected_mode", [None, "LA"])
    def test_2_channel_ndarray_to_pil_image(self, expected_mode):
        img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy()

        if expected_mode is None:
            img = transforms.ToPILImage()(img_data)
            assert img.mode == "LA"  # default should assume LA
        else:
            img = transforms.ToPILImage(mode=expected_mode)(img_data)
            assert img.mode == expected_mode

        split = img.split()
        for i in range(2):
            torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]))

    def test_2_channel_ndarray_to_pil_image_error(self):
        img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy()
        transforms.ToPILImage().__repr__()

        # should raise if we try a mode for 4 or 1 or 3 channel images
        with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"):
            transforms.ToPILImage(mode="RGBA")(img_data)
        with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"):
            transforms.ToPILImage(mode="P")(img_data)
        with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"):
            transforms.ToPILImage(mode="RGB")(img_data)

    @pytest.mark.parametrize("expected_mode", [None, "LA"])
    def test_2_channel_tensor_to_pil_image(self, expected_mode):
        img_data = torch.Tensor(2, 4, 4).uniform_()
        expected_output = img_data.mul(255).int().float().div(255)

        if expected_mode is None:
            img = transforms.ToPILImage()(img_data)
            assert img.mode == "LA"  # default should assume LA
        else:
            img = transforms.ToPILImage(mode=expected_mode)(img_data)
            assert img.mode == expected_mode

        split = img.split()
        for i in range(2):
            torch.testing.assert_close(expected_output[i].numpy(), F.to_tensor(split[i]).squeeze(0).numpy())

    def test_2_channel_tensor_to_pil_image_error(self):
        img_data = torch.Tensor(2, 4, 4).uniform_()

        # should raise if we try a mode for 4 or 1 or 3 channel images
        with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"):
            transforms.ToPILImage(mode="RGBA")(img_data)
        with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"):
            transforms.ToPILImage(mode="P")(img_data)
        with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"):
            transforms.ToPILImage(mode="RGB")(img_data)

    @pytest.mark.parametrize("with_mode", [False, True])
    @pytest.mark.parametrize("img_data, expected_output, expected_mode", _get_2d_tensor_various_types())
    def test_2d_tensor_to_pil_image(self, with_mode, img_data, expected_output, expected_mode):
        transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
        to_tensor = transforms.ToTensor()

        img = transform(img_data)
        assert img.mode == expected_mode
        torch.testing.assert_close(expected_output, to_tensor(img).numpy()[0])

    @pytest.mark.parametrize("with_mode", [False, True])
    @pytest.mark.parametrize(
        "img_data, expected_mode",
        [
            (torch.Tensor(4, 4).uniform_().numpy(), "F"),
            (torch.ByteTensor(4, 4).random_(0, 255).numpy(), "L"),
            (torch.ShortTensor(4, 4).random_().numpy(), "I;16"),
            (torch.IntTensor(4, 4).random_().numpy(), "I"),
        ],
    )
    def test_2d_ndarray_to_pil_image(self, with_mode, img_data, expected_mode):
        transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
        img = transform(img_data)
        assert img.mode == expected_mode
        np.testing.assert_allclose(img_data, img)

    @pytest.mark.parametrize("expected_mode", [None, "RGB", "HSV", "YCbCr"])
    def test_3_channel_tensor_to_pil_image(self, expected_mode):
        img_data = torch.Tensor(3, 4, 4).uniform_()
        expected_output = img_data.mul(255).int().float().div(255)

        if expected_mode is None:
            img = transforms.ToPILImage()(img_data)
            assert img.mode == "RGB"  # default should assume RGB
        else:
            img = transforms.ToPILImage(mode=expected_mode)(img_data)
            assert img.mode == expected_mode

        split = img.split()
        for i in range(3):
            torch.testing.assert_close(expected_output[i].numpy(), F.to_tensor(split[i]).squeeze(0).numpy())

    def test_3_channel_tensor_to_pil_image_error(self):
        img_data = torch.Tensor(3, 4, 4).uniform_()
        error_message_3d = r"Only modes \['RGB', 'YCbCr', 'HSV'\] are supported for 3D inputs"
        # should raise if we try a mode for 4 or 1 or 2 channel images
        with pytest.raises(ValueError, match=error_message_3d):
            transforms.ToPILImage(mode="RGBA")(img_data)
        with pytest.raises(ValueError, match=error_message_3d):
            transforms.ToPILImage(mode="P")(img_data)
        with pytest.raises(ValueError, match=error_message_3d):
            transforms.ToPILImage(mode="LA")(img_data)

        with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."):
            transforms.ToPILImage()(torch.Tensor(1, 3, 4, 4).uniform_())

    @pytest.mark.parametrize("expected_mode", [None, "RGB", "HSV", "YCbCr"])
    def test_3_channel_ndarray_to_pil_image(self, expected_mode):
        img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()

        if expected_mode is None:
            img = transforms.ToPILImage()(img_data)
            assert img.mode == "RGB"  # default should assume RGB
        else:
            img = transforms.ToPILImage(mode=expected_mode)(img_data)
            assert img.mode == expected_mode

        split = img.split()
        for i in range(3):
            torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]))

    def test_3_channel_ndarray_to_pil_image_error(self):
        img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()

        # Checking if ToPILImage can be printed as string
        transforms.ToPILImage().__repr__()

        error_message_3d = r"Only modes \['RGB', 'YCbCr', 'HSV'\] are supported for 3D inputs"
        # should raise if we try a mode for 4 or 1 or 2 channel images
        with pytest.raises(ValueError, match=error_message_3d):
            transforms.ToPILImage(mode="RGBA")(img_data)
        with pytest.raises(ValueError, match=error_message_3d):
            transforms.ToPILImage(mode="P")(img_data)
        with pytest.raises(ValueError, match=error_message_3d):
            transforms.ToPILImage(mode="LA")(img_data)

    @pytest.mark.parametrize("expected_mode", [None, "RGBA", "CMYK", "RGBX"])
    def test_4_channel_tensor_to_pil_image(self, expected_mode):
        img_data = torch.Tensor(4, 4, 4).uniform_()
        expected_output = img_data.mul(255).int().float().div(255)

        if expected_mode is None:
            img = transforms.ToPILImage()(img_data)
            assert img.mode == "RGBA"  # default should assume RGBA
        else:
            img = transforms.ToPILImage(mode=expected_mode)(img_data)
            assert img.mode == expected_mode

        split = img.split()
        for i in range(4):
            torch.testing.assert_close(expected_output[i].numpy(), F.to_tensor(split[i]).squeeze(0).numpy())

    def test_4_channel_tensor_to_pil_image_error(self):
        img_data = torch.Tensor(4, 4, 4).uniform_()

        error_message_4d = r"Only modes \['RGBA', 'CMYK', 'RGBX'\] are supported for 4D inputs"
        # should raise if we try a mode for 3 or 1 or 2 channel images
        with pytest.raises(ValueError, match=error_message_4d):
            transforms.ToPILImage(mode="RGB")(img_data)
        with pytest.raises(ValueError, match=error_message_4d):
            transforms.ToPILImage(mode="P")(img_data)
        with pytest.raises(ValueError, match=error_message_4d):
            transforms.ToPILImage(mode="LA")(img_data)

    @pytest.mark.parametrize("expected_mode", [None, "RGBA", "CMYK", "RGBX"])
    def test_4_channel_ndarray_to_pil_image(self, expected_mode):
        img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy()

        if expected_mode is None:
            img = transforms.ToPILImage()(img_data)
            assert img.mode == "RGBA"  # default should assume RGBA
        else:
            img = transforms.ToPILImage(mode=expected_mode)(img_data)
            assert img.mode == expected_mode

        split = img.split()
        for i in range(4):
            torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]))

    def test_4_channel_ndarray_to_pil_image_error(self):
        img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy()

        error_message_4d = r"Only modes \['RGBA', 'CMYK', 'RGBX'\] are supported for 4D inputs"
        # should raise if we try a mode for 3 or 1 or 2 channel images
        with pytest.raises(ValueError, match=error_message_4d):
            transforms.ToPILImage(mode="RGB")(img_data)
        with pytest.raises(ValueError, match=error_message_4d):
            transforms.ToPILImage(mode="P")(img_data)
        with pytest.raises(ValueError, match=error_message_4d):
            transforms.ToPILImage(mode="LA")(img_data)

    def test_ndarray_bad_types_to_pil_image(self):
        trans = transforms.ToPILImage()
        reg_msg = r"Input type \w+ is not supported"
        with pytest.raises(TypeError, match=reg_msg):
            trans(np.ones([4, 4, 1], np.int64))
        with pytest.raises(TypeError, match=reg_msg):
            trans(np.ones([4, 4, 1], np.uint16))
        with pytest.raises(TypeError, match=reg_msg):
            trans(np.ones([4, 4, 1], np.uint32))
        with pytest.raises(TypeError, match=reg_msg):
            trans(np.ones([4, 4, 1], np.float64))

        with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."):
            transforms.ToPILImage()(np.ones([1, 4, 4, 3]))
        with pytest.raises(ValueError, match=r"pic should not have > 4 channels. Got \d+ channels."):
            transforms.ToPILImage()(np.ones([4, 4, 6]))

    def test_tensor_bad_types_to_pil_image(self):
        with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."):
            transforms.ToPILImage()(torch.ones(1, 3, 4, 4))
        with pytest.raises(ValueError, match=r"pic should not have > 4 channels. Got \d+ channels."):
            transforms.ToPILImage()(torch.ones(6, 4, 4))
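

# Summary of the default mode inference exercised throughout TestToPil
# (illustrative sketch, helper name is ours): with mode=None, ToPILImage picks the
# mode from the channel count; float tensors are first scaled to uint8.
def _to_pil_default_mode_demo():
    assert transforms.ToPILImage()(torch.rand(1, 4, 4)).mode == "L"
    assert transforms.ToPILImage()(torch.rand(2, 4, 4)).mode == "LA"
    assert transforms.ToPILImage()(torch.rand(3, 4, 4)).mode == "RGB"
    assert transforms.ToPILImage()(torch.rand(4, 4, 4)).mode == "RGBA"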


def test_adjust_brightness():
    x_shape = [2, 2, 3]
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_pil = Image.fromarray(x_np, mode="RGB")

    # test 0
    y_pil = F.adjust_brightness(x_pil, 1)
    y_np = np.array(y_pil)
    torch.testing.assert_close(y_np, x_np)

    # test 1
    y_pil = F.adjust_brightness(x_pil, 0.5)
    y_np = np.array(y_pil)
    y_ans = [0, 2, 6, 27, 67, 113, 18, 4, 117, 45, 127, 0]
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
    torch.testing.assert_close(y_np, y_ans)

    # test 2
    y_pil = F.adjust_brightness(x_pil, 2)
    y_np = np.array(y_pil)
    y_ans = [0, 10, 26, 108, 255, 255, 74, 16, 255, 180, 255, 2]
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
    torch.testing.assert_close(y_np, y_ans)


def test_adjust_contrast():
    x_shape = [2, 2, 3]
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_pil = Image.fromarray(x_np, mode="RGB")

    # test 0
    y_pil = F.adjust_contrast(x_pil, 1)
    y_np = np.array(y_pil)
    torch.testing.assert_close(y_np, x_np)

    # test 1
    y_pil = F.adjust_contrast(x_pil, 0.5)
    y_np = np.array(y_pil)
    y_ans = [43, 45, 49, 70, 110, 156, 61, 47, 160, 88, 170, 43]
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
    torch.testing.assert_close(y_np, y_ans)

    # test 2
    y_pil = F.adjust_contrast(x_pil, 2)
    y_np = np.array(y_pil)
    y_ans = [0, 0, 0, 22, 184, 255, 0, 0, 255, 94, 255, 0]
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
    torch.testing.assert_close(y_np, y_ans)


def test_adjust_hue():
    x_shape = [2, 2, 3]
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_pil = Image.fromarray(x_np, mode="RGB")

    with pytest.raises(ValueError):
        F.adjust_hue(x_pil, -0.7)
        F.adjust_hue(x_pil, 1)

    # test 0: almost same as x_data but not exact.
    # probably because hsv <-> rgb floating point ops
    y_pil = F.adjust_hue(x_pil, 0)
    y_np = np.array(y_pil)
    y_ans = [0, 5, 13, 54, 139, 226, 35, 8, 234, 91, 255, 1]
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
    torch.testing.assert_close(y_np, y_ans)

    # test 1
    y_pil = F.adjust_hue(x_pil, 0.25)
    y_np = np.array(y_pil)
    y_ans = [13, 0, 12, 224, 54, 226, 234, 8, 99, 1, 222, 255]
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
    torch.testing.assert_close(y_np, y_ans)

    # test 2
    y_pil = F.adjust_hue(x_pil, -0.25)
    y_np = np.array(y_pil)
    y_ans = [0, 13, 2, 54, 226, 58, 8, 234, 152, 255, 43, 1]
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
    torch.testing.assert_close(y_np, y_ans)


def test_adjust_sharpness():
    x_shape = [4, 4, 3]
    x_data = [
        75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114,
        99, 104, 97, 0, 0, 65, 108, 101, 120, 97, 110, 100,
        101, 114, 32, 86, 114, 121, 110, 105, 111, 116, 105, 115,
        0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117,
    ]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_pil = Image.fromarray(x_np, mode="RGB")

    # test 0
    y_pil = F.adjust_sharpness(x_pil, 1)
    y_np = np.array(y_pil)
    torch.testing.assert_close(y_np, x_np)

    # test 1
    y_pil = F.adjust_sharpness(x_pil, 0.5)
    y_np = np.array(y_pil)
    y_ans = [
        75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114,
        99, 104, 97, 30, 30, 74, 103, 96, 114, 97, 110, 100,
        101, 114, 32, 81, 103, 108, 102, 101, 107, 116, 105, 115,
        0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117,
    ]
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
    torch.testing.assert_close(y_np, y_ans)

    # test 2
    y_pil = F.adjust_sharpness(x_pil, 2)
    y_np = np.array(y_pil)
    y_ans = [
        75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114,
        99, 104, 97, 0, 0, 46, 118, 111, 132, 97, 110, 100,
        101, 114, 32, 95, 135, 146, 126, 112, 119, 116, 105, 115,
        0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117,
    ]
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
    torch.testing.assert_close(y_np, y_ans)

    # test 3
    x_shape = [2, 2, 3]
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_pil = Image.fromarray(x_np, mode="RGB")
    x_th = torch.tensor(x_np.transpose(2, 0, 1))
    y_pil = F.adjust_sharpness(x_pil, 2)
    y_np = np.array(y_pil).transpose(2, 0, 1)
    y_th = F.adjust_sharpness(x_th, 2)
    torch.testing.assert_close(y_np, y_th.numpy())


def test_adjust_gamma():
    x_shape = [2, 2, 3]
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_pil = Image.fromarray(x_np, mode="RGB")

    # test 0
    y_pil = F.adjust_gamma(x_pil, 1)
    y_np = np.array(y_pil)
    torch.testing.assert_close(y_np, x_np)

    # test 1
    y_pil = F.adjust_gamma(x_pil, 0.5)
    y_np = np.array(y_pil)
    y_ans = [0, 35, 57, 117, 186, 241, 97, 45, 245, 152, 255, 16]
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
    torch.testing.assert_close(y_np, y_ans)

    # test 2
    y_pil = F.adjust_gamma(x_pil, 2)
    y_np = np.array(y_pil)
    y_ans = [0, 0, 0, 11, 71, 201, 5, 0, 215, 31, 255, 0]
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
    torch.testing.assert_close(y_np, y_ans)
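

# The expected values above follow the usual gamma curve,
# out ~ round(255 * gain * (in / 255) ** gamma) (illustrative check, helper name is ours):
def _adjust_gamma_demo():
    assert round(255 * (135 / 255) ** 0.5) == 186  # matches y_ans for gamma=0.5
    assert round(255 * (135 / 255) ** 2) == 71  # matches y_ans for gamma=2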


def test_adjusts_L_mode():
    x_shape = [2, 2, 3]
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_rgb = Image.fromarray(x_np, mode="RGB")

    x_l = x_rgb.convert("L")
    assert F.adjust_brightness(x_l, 2).mode == "L"
    assert F.adjust_saturation(x_l, 2).mode == "L"
    assert F.adjust_contrast(x_l, 2).mode == "L"
    assert F.adjust_hue(x_l, 0.4).mode == "L"
    assert F.adjust_sharpness(x_l, 2).mode == "L"
    assert F.adjust_gamma(x_l, 0.5).mode == "L"


def test_rotate():
    x = np.zeros((100, 100, 3), dtype=np.uint8)
    x[40, 40] = [255, 255, 255]

    with pytest.raises(TypeError, match=r"img should be PIL Image"):
        F.rotate(x, 10)

    img = F.to_pil_image(x)

    result = F.rotate(img, 45)
    assert result.size == (100, 100)
    r, c, ch = np.where(result)
    assert all(x in r for x in [49, 50])
    assert all(x in c for x in [36])
    assert all(x in ch for x in [0, 1, 2])

    result = F.rotate(img, 45, expand=True)
    assert result.size == (142, 142)  # bounding box of 100x100 rotated by 45 deg: 100 * sqrt(2) ~ 141.4, rounded up
    r, c, ch = np.where(result)
    assert all(x in r for x in [70, 71])
    assert all(x in c for x in [57])
    assert all(x in ch for x in [0, 1, 2])

    result = F.rotate(img, 45, center=(40, 40))
    assert result.size == (100, 100)
    r, c, ch = np.where(result)
    assert all(x in r for x in [40])
    assert all(x in c for x in [40])
    assert all(x in ch for x in [0, 1, 2])

    result_a = F.rotate(img, 90)
    result_b = F.rotate(img, -270)
    assert_equal(np.array(result_a), np.array(result_b))


@pytest.mark.parametrize("mode", ["L", "RGB", "F"])
def test_rotate_fill(mode):
    img = F.to_pil_image(np.ones((100, 100, 3), dtype=np.uint8) * 255, "RGB")

    num_bands = len(mode)
    wrong_num_bands = num_bands + 1
    fill = 127

    img_conv = img.convert(mode)
    img_rot = F.rotate(img_conv, 45.0, fill=fill)
    pixel = img_rot.getpixel((0, 0))

    if not isinstance(pixel, tuple):
        pixel = (pixel,)
    assert pixel == tuple([fill] * num_bands)

    with pytest.raises(ValueError):
        F.rotate(img_conv, 45.0, fill=tuple([fill] * wrong_num_bands))


def test_gaussian_blur_asserts():
    np_img = np.ones((100, 100, 3), dtype=np.uint8) * 255
    img = F.to_pil_image(np_img, "RGB")

    with pytest.raises(ValueError, match=r"If kernel_size is a sequence its length should be 2"):
        F.gaussian_blur(img, [3])
    with pytest.raises(ValueError, match=r"If kernel_size is a sequence its length should be 2"):
        F.gaussian_blur(img, [3, 3, 3])
    with pytest.raises(ValueError, match=r"Kernel size should be a tuple/list of two integers"):
        transforms.GaussianBlur([3, 3, 3])

    with pytest.raises(ValueError, match=r"kernel_size should have odd and positive integers"):
        F.gaussian_blur(img, [4, 4])
    with pytest.raises(ValueError, match=r"Kernel size value should be an odd and positive number"):
        transforms.GaussianBlur([4, 4])

    with pytest.raises(ValueError, match=r"kernel_size should have odd and positive integers"):
        F.gaussian_blur(img, [-3, -3])
    with pytest.raises(ValueError, match=r"Kernel size value should be an odd and positive number"):
        transforms.GaussianBlur([-3, -3])

    with pytest.raises(ValueError, match=r"If sigma is a sequence, its length should be 2"):
        F.gaussian_blur(img, 3, [1, 1, 1])
    with pytest.raises(ValueError, match=r"sigma should be a single number or a list/tuple with length 2"):
        transforms.GaussianBlur(3, [1, 1, 1])

    with pytest.raises(ValueError, match=r"sigma should have positive values"):
        F.gaussian_blur(img, 3, -1.0)
    with pytest.raises(ValueError, match=r"If sigma is a single number, it must be positive"):
        transforms.GaussianBlur(3, -1.0)

    with pytest.raises(TypeError, match=r"kernel_size should be int or a sequence of integers"):
        F.gaussian_blur(img, "kernel_size_string")
    with pytest.raises(ValueError, match=r"Kernel size should be a tuple/list of two integers"):
        transforms.GaussianBlur("kernel_size_string")

    with pytest.raises(TypeError, match=r"sigma should be either float or sequence of floats"):
        F.gaussian_blur(img, 3, "sigma_string")
    with pytest.raises(ValueError, match=r"sigma should be a single number or a list/tuple with length 2"):
        transforms.GaussianBlur(3, "sigma_string")


def test_lambda():
    trans = transforms.Lambda(lambda x: x.add(10))
    x = torch.randn(10)
    y = trans(x)
    assert_equal(y, torch.add(x, 10))

    trans = transforms.Lambda(lambda x: x.add_(10))
    x = torch.randn(10)
    y = trans(x)
    assert_equal(y, x)
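    # (x.add_(10) mutates x in place and returns it, so the transform's output
    # is the very same, already-mutated tensor.)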

    # Checking if Lambda can be printed as string
    trans.__repr__()


def test_to_grayscale():
    """Unit tests for grayscale transform"""
    x_shape = [2, 2, 3]
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_pil = Image.fromarray(x_np, mode="RGB")
    x_pil_2 = x_pil.convert("L")
    gray_np = np.array(x_pil_2)
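
    # PIL's convert("L") (ITU-R 601-2 luma: L = R * 299/1000 + G * 587/1000 + B * 114/1000)
    # is the reference output for every case below.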

    # Test Set: Grayscale an image with desired number of output channels
    # Case 1: RGB -> 1 channel grayscale
    trans1 = transforms.Grayscale(num_output_channels=1)
    gray_pil_1 = trans1(x_pil)
    gray_np_1 = np.array(gray_pil_1)
    assert gray_pil_1.mode == "L", "mode should be L"
    assert gray_np_1.shape == tuple(x_shape[0:2]), "should be 1 channel"
    assert_equal(gray_np, gray_np_1)

    # Case 2: RGB -> 3 channel grayscale
    trans2 = transforms.Grayscale(num_output_channels=3)
    gray_pil_2 = trans2(x_pil)
    gray_np_2 = np.array(gray_pil_2)
    assert gray_pil_2.mode == "RGB", "mode should be RGB"
    assert gray_np_2.shape == tuple(x_shape), "should be 3 channel"
    assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
    assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
    assert_equal(gray_np, gray_np_2[:, :, 0])

    # Case 3: 1 channel grayscale -> 1 channel grayscale
    trans3 = transforms.Grayscale(num_output_channels=1)
    gray_pil_3 = trans3(x_pil_2)
    gray_np_3 = np.array(gray_pil_3)
    assert gray_pil_3.mode == "L", "mode should be L"
    assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel"
    assert_equal(gray_np, gray_np_3)

    # Case 4: 1 channel grayscale -> 3 channel grayscale
    trans4 = transforms.Grayscale(num_output_channels=3)
    gray_pil_4 = trans4(x_pil_2)
    gray_np_4 = np.array(gray_pil_4)
    assert gray_pil_4.mode == "RGB", "mode should be RGB"
    assert gray_np_4.shape == tuple(x_shape), "should be 3 channel"
    assert_equal(gray_np_4[:, :, 0], gray_np_4[:, :, 1])
    assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2])
    assert_equal(gray_np, gray_np_4[:, :, 0])

    # Checking if Grayscale can be printed as string
    trans4.__repr__()


@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("p", (0, 1))
def test_random_apply(p, seed):
    torch.manual_seed(seed)
    random_apply_transform = transforms.RandomApply([transforms.RandomRotation((45, 50))], p=p)
    img = transforms.ToPILImage()(torch.rand(3, 30, 40))
    out = random_apply_transform(img)
    if p == 0:
        assert out == img
    elif p == 1:
        assert out != img

    # Checking if RandomApply can be printed as string
    random_apply_transform.__repr__()


@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("proba_passthrough", (0, 1))
def test_random_choice(proba_passthrough, seed):
    random.seed(seed)  # RandomChoice relies on python builtin random.choice, not pytorch
    random_choice_transform = transforms.RandomChoice(
        [
            lambda x: x,  # passthrough
            transforms.RandomRotation((45, 50)),
        ],
        p=[proba_passthrough, 1 - proba_passthrough],
    )
    img = transforms.ToPILImage()(torch.rand(3, 30, 40))
    out = random_choice_transform(img)
    if proba_passthrough == 1:
        assert out == img
    elif proba_passthrough == 0:
        assert out != img

    # Checking if RandomChoice can be printed as string
    random_choice_transform.__repr__()


@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_order():
    random_state = random.getstate()
    random.seed(42)
    random_order_transform = transforms.RandomOrder([transforms.Resize(20, antialias=True), transforms.CenterCrop(10)])
    img = transforms.ToPILImage()(torch.rand(3, 25, 25))
    num_samples = 250
    num_normal_order = 0
    resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20, antialias=True)(img))
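
    # With two transforms there are only two possible orders, each chosen with
    # probability 0.5, so the count of "resize then crop" outcomes should pass
    # a two-sided binomial test against p=0.5.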
    for _ in range(num_samples):
        out = random_order_transform(img)
        if out == resize_crop_out:
            num_normal_order += 1
    p_value = stats.binomtest(num_normal_order, num_samples, p=0.5).pvalue
    random.setstate(random_state)
    assert p_value > 0.0001

    # Checking if RandomOrder can be printed as string
    random_order_transform.__repr__()


def test_linear_transformation():
    num_samples = 1000
    x = torch.randn(num_samples, 3, 10, 10)
    flat_x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3))
    # compute principal components
    sigma = torch.mm(flat_x.t(), flat_x) / flat_x.size(0)
    u, s, _ = np.linalg.svd(sigma.numpy())
    zca_epsilon = 1e-10  # avoid division by 0
    d = torch.Tensor(np.diag(1.0 / np.sqrt(s + zca_epsilon)))
    u = torch.Tensor(u)
    principal_components = torch.mm(torch.mm(u, d), u.t())
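    # principal_components is the ZCA whitening matrix W = U diag(1/sqrt(s + eps)) U^T;
    # applying it should decorrelate the flattened features and give them unit variance.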
    mean_vector = torch.sum(flat_x, dim=0) / flat_x.size(0)
    # initialize whitening matrix
    whitening = transforms.LinearTransformation(principal_components, mean_vector)

    # estimate covariance and mean using weak law of large numbers
    num_features = flat_x.size(1)
    cov = 0.0
    mean = 0.0
    for i in x:
        xwhite = whitening(i)
        xwhite = xwhite.view(1, -1).numpy()
        cov += np.dot(xwhite, xwhite.T) / num_features
        mean += np.sum(xwhite) / num_features
    # if rtol for std = 1e-3 then rtol for cov = 2e-3 as std**2 = cov
    torch.testing.assert_close(
        cov / num_samples, np.identity(1), rtol=2e-3, atol=1e-8, check_dtype=False, msg="cov not close to 1"
    )
    torch.testing.assert_close(
        mean / num_samples, 0, rtol=1e-3, atol=1e-8, check_dtype=False, msg="mean not close to 0"
    )

    # Checking if LinearTransformation can be printed as string
    whitening.__repr__()


@pytest.mark.parametrize("dtype", int_dtypes())
def test_max_value(dtype):
    assert F_t._max_value(dtype) == torch.iinfo(dtype).max
    # remove float testing as it can lead to errors such as
    # runtime error: 5.7896e+76 is outside the range of representable values of type 'float'
    # for dtype in float_dtypes():
    #     self.assertGreater(F_t._max_value(dtype), torch.finfo(dtype).max)


@pytest.mark.xfail(
    reason="torch.iinfo() is not supported by torchscript. See https://github.com/pytorch/pytorch/issues/41492."
)
def test_max_value_iinfo():
    @torch.jit.script
    def max_value(image: torch.Tensor) -> int:
        return 1 if image.is_floating_point() else torch.iinfo(image.dtype).max


@pytest.mark.parametrize("should_vflip", [True, False])
@pytest.mark.parametrize("single_dim", [True, False])
def test_ten_crop(should_vflip, single_dim):
    to_pil_image = transforms.ToPILImage()
    h = random.randint(5, 25)
    w = random.randint(5, 25)
    crop_h = random.randint(1, h)
    crop_w = random.randint(1, w)
    if single_dim:
        crop_h = min(crop_h, crop_w)
        crop_w = crop_h
        transform = transforms.TenCrop(crop_h, vertical_flip=should_vflip)
        five_crop = transforms.FiveCrop(crop_h)
    else:
        transform = transforms.TenCrop((crop_h, crop_w), vertical_flip=should_vflip)
        five_crop = transforms.FiveCrop((crop_h, crop_w))

    img = to_pil_image(torch.FloatTensor(3, h, w).uniform_())
    results = transform(img)
    expected_output = five_crop(img)

    # Checking if FiveCrop and TenCrop can be printed as string
    transform.__repr__()
    five_crop.__repr__()

    if should_vflip:
        vflipped_img = img.transpose(Image.FLIP_TOP_BOTTOM)
        expected_output += five_crop(vflipped_img)
    else:
        hflipped_img = img.transpose(Image.FLIP_LEFT_RIGHT)
        expected_output += five_crop(hflipped_img)

    assert len(results) == 10
    assert results == expected_output


@pytest.mark.parametrize("single_dim", [True, False])
def test_five_crop(single_dim):
    to_pil_image = transforms.ToPILImage()
    h = random.randint(5, 25)
    w = random.randint(5, 25)
    crop_h = random.randint(1, h)
    crop_w = random.randint(1, w)
    if single_dim:
        crop_h = min(crop_h, crop_w)
        crop_w = crop_h
        transform = transforms.FiveCrop(crop_h)
    else:
        transform = transforms.FiveCrop((crop_h, crop_w))

    img = torch.FloatTensor(3, h, w).uniform_()
    results = transform(to_pil_image(img))
    assert len(results) == 5
    for crop in results:
        assert crop.size == (crop_w, crop_h)

    to_pil_image = transforms.ToPILImage()
    tl = to_pil_image(img[:, 0:crop_h, 0:crop_w])
    tr = to_pil_image(img[:, 0:crop_h, w - crop_w :])
    bl = to_pil_image(img[:, h - crop_h :, 0:crop_w])
    br = to_pil_image(img[:, h - crop_h :, w - crop_w :])
    center = transforms.CenterCrop((crop_h, crop_w))(to_pil_image(img))
    expected_output = (tl, tr, bl, br, center)
    assert results == expected_output


@pytest.mark.parametrize("policy", transforms.AutoAugmentPolicy)
@pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)])
@pytest.mark.parametrize("grayscale", [True, False])
def test_autoaugment(policy, fill, grayscale):
    random.seed(42)
    img = Image.open(GRACE_HOPPER)
    if grayscale:
        img, fill = _get_grayscale_test_image(img, fill)
    transform = transforms.AutoAugment(policy=policy, fill=fill)
    for _ in range(100):
        img = transform(img)
    transform.__repr__()


@pytest.mark.parametrize("num_ops", [1, 2, 3])
@pytest.mark.parametrize("magnitude", [7, 9, 11])
@pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)])
@pytest.mark.parametrize("grayscale", [True, False])
def test_randaugment(num_ops, magnitude, fill, grayscale):
    random.seed(42)
    img = Image.open(GRACE_HOPPER)
    if grayscale:
        img, fill = _get_grayscale_test_image(img, fill)
    transform = transforms.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
    for _ in range(100):
        img = transform(img)
    transform.__repr__()


@pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)])
@pytest.mark.parametrize("num_magnitude_bins", [10, 13, 30])
@pytest.mark.parametrize("grayscale", [True, False])
def test_trivialaugmentwide(fill, num_magnitude_bins, grayscale):
    random.seed(42)
    img = Image.open(GRACE_HOPPER)
    if grayscale:
        img, fill = _get_grayscale_test_image(img, fill)
    transform = transforms.TrivialAugmentWide(fill=fill, num_magnitude_bins=num_magnitude_bins)
    for _ in range(100):
        img = transform(img)
    transform.__repr__()


@pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)])
@pytest.mark.parametrize("severity", [1, 10])
@pytest.mark.parametrize("mixture_width", [1, 2])
@pytest.mark.parametrize("chain_depth", [-1, 2])
@pytest.mark.parametrize("all_ops", [True, False])
@pytest.mark.parametrize("grayscale", [True, False])
def test_augmix(fill, severity, mixture_width, chain_depth, all_ops, grayscale):
    random.seed(42)
    img = Image.open(GRACE_HOPPER)
    if grayscale:
        img, fill = _get_grayscale_test_image(img, fill)
    transform = transforms.AugMix(
        fill=fill, severity=severity, mixture_width=mixture_width, chain_depth=chain_depth, all_ops=all_ops
    )
    for _ in range(100):
        img = transform(img)
    transform.__repr__()


def test_random_crop():
    height = random.randint(10, 32) * 2
    width = random.randint(10, 32) * 2
    # use floor division so the bounds passed to randint stay integers
    oheight = random.randint(5, (height - 2) // 2) * 2
    owidth = random.randint(5, (width - 2) // 2) * 2
    img = torch.ones(3, height, width, dtype=torch.uint8)
    result = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.RandomCrop((oheight, owidth)),
            transforms.PILToTensor(),
        ]
    )(img)
    assert result.size(1) == oheight
    assert result.size(2) == owidth

    padding = random.randint(1, 20)
    result = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.RandomCrop((oheight, owidth), padding=padding),
            transforms.PILToTensor(),
        ]
    )(img)
    assert result.size(1) == oheight
    assert result.size(2) == owidth

    result = transforms.Compose(
        [transforms.ToPILImage(), transforms.RandomCrop((height, width)), transforms.PILToTensor()]
    )(img)
    assert result.size(1) == height
    assert result.size(2) == width
    torch.testing.assert_close(result, img)

    result = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.RandomCrop((height + 1, width + 1), pad_if_needed=True),
            transforms.PILToTensor(),
        ]
    )(img)
    assert result.size(1) == height + 1
    assert result.size(2) == width + 1

    t = transforms.RandomCrop(33)
    img = torch.ones(3, 32, 32)
    with pytest.raises(ValueError, match=r"Required crop size .+ is larger than input image size .+"):
        t(img)


def test_center_crop():
    height = random.randint(10, 32) * 2
    width = random.randint(10, 32) * 2
    # use floor division so the bounds passed to randint stay integers
    oheight = random.randint(5, (height - 2) // 2) * 2
    owidth = random.randint(5, (width - 2) // 2) * 2

    img = torch.ones(3, height, width, dtype=torch.uint8)
    oh1 = (height - oheight) // 2
    ow1 = (width - owidth) // 2
    imgnarrow = img[:, oh1 : oh1 + oheight, ow1 : ow1 + owidth]
    imgnarrow.fill_(0)
    result = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.CenterCrop((oheight, owidth)),
            transforms.PILToTensor(),
        ]
    )(img)
    assert result.sum() == 0

    oheight += 1
    owidth += 1
    result = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.CenterCrop((oheight, owidth)),
            transforms.PILToTensor(),
        ]
    )(img)
    sum1 = result.sum()
    assert sum1 > 1

    oheight += 1
    owidth += 1
    result = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.CenterCrop((oheight, owidth)),
            transforms.PILToTensor(),
        ]
    )(img)
    sum2 = result.sum()
    assert sum2 > 0
    assert sum2 > sum1


@pytest.mark.parametrize("odd_image_size", (True, False))
@pytest.mark.parametrize("delta", (1, 3, 5))
@pytest.mark.parametrize("delta_width", (-2, -1, 0, 1, 2))
@pytest.mark.parametrize("delta_height", (-2, -1, 0, 1, 2))
def test_center_crop_2(odd_image_size, delta, delta_width, delta_height):
    """Tests when center crop size is larger than image size, along any dimension"""

    # Since height is independent of width, we can ignore images with odd height and even width and vice-versa.
    input_image_size = (random.randint(10, 32) * 2, random.randint(10, 32) * 2)
    if odd_image_size:
        input_image_size = (input_image_size[0] + 1, input_image_size[1] + 1)

    delta_height *= delta
    delta_width *= delta

    img = torch.ones(3, *input_image_size, dtype=torch.uint8)
    crop_size = (input_image_size[0] + delta_height, input_image_size[1] + delta_width)

    # Test both transforms, one with PIL input and one with tensor
    output_pil = transforms.Compose(
        [transforms.ToPILImage(), transforms.CenterCrop(crop_size), transforms.PILToTensor()],
    )(img)
    assert output_pil.size()[1:3] == crop_size

    output_tensor = transforms.CenterCrop(crop_size)(img)
    assert output_tensor.size()[1:3] == crop_size

    # Ensure output for PIL and Tensor are equal
    assert_equal(
        output_tensor,
        output_pil,
        msg=f"image_size: {input_image_size} crop_size: {crop_size}",
    )

    # Check if content in center of both image and cropped output is same.
    center_size = (min(crop_size[0], input_image_size[0]), min(crop_size[1], input_image_size[1]))
    crop_center_tl, input_center_tl = [0, 0], [0, 0]
    for index in range(2):
        if crop_size[index] > input_image_size[index]:
            crop_center_tl[index] = (crop_size[index] - input_image_size[index]) // 2
        else:
            input_center_tl[index] = (input_image_size[index] - crop_size[index]) // 2

    output_center = output_pil[
        :,
        crop_center_tl[0] : crop_center_tl[0] + center_size[0],
        crop_center_tl[1] : crop_center_tl[1] + center_size[1],
    ]

    img_center = img[
        :,
        input_center_tl[0] : input_center_tl[0] + center_size[0],
        input_center_tl[1] : input_center_tl[1] + center_size[1],
    ]

    assert_equal(output_center, img_center)


def test_color_jitter():
    color_jitter = transforms.ColorJitter(2, 2, 2, 0.1)

    x_shape = [2, 2, 3]
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_pil = Image.fromarray(x_np, mode="RGB")
    x_pil_2 = x_pil.convert("L")

    for _ in range(10):
        y_pil = color_jitter(x_pil)
        assert y_pil.mode == x_pil.mode

        y_pil_2 = color_jitter(x_pil_2)
        assert y_pil_2.mode == x_pil_2.mode

    # Checking if ColorJitter can be printed as string
    color_jitter.__repr__()


@pytest.mark.parametrize("hue", [1, (-1, 1)])
def test_color_jitter_hue_out_of_bounds(hue):
    with pytest.raises(ValueError, match=re.escape("hue values should be between (-0.5, 0.5)")):
        transforms.ColorJitter(hue=hue)


@pytest.mark.parametrize("seed", range(10))
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_erasing(seed):
    torch.random.manual_seed(seed)
    img = torch.ones(3, 128, 128)

    t = transforms.RandomErasing(scale=(0.1, 0.1), ratio=(1 / 3, 3.0))
    y, x, h, w, v = t.get_params(
        img,
        t.scale,
        t.ratio,
        [
            t.value,
        ],
    )
    aspect_ratio = h / w
    # Add some tolerance due to the rounding and int conversion used in the transform
    tol = 0.05
    assert 1 / 3 - tol <= aspect_ratio <= 3 + tol

    # Make sure that h > w and h < w are equally likely (log-scale sampling)
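    # Sampling the ratio on a log scale makes a ratio r exactly as likely as
    # 1/r, so the count of sampled aspect ratios above 1 should pass a
    # two-sided binomial test against p=0.5.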
    aspect_ratios = []
    random.seed(42)
    trial = 1000
    for _ in range(trial):
        y, x, h, w, v = t.get_params(
            img,
            t.scale,
            t.ratio,
            [
                t.value,
            ],
        )
        aspect_ratios.append(h / w)

    count_bigger_than_ones = len([1 for aspect_ratio in aspect_ratios if aspect_ratio > 1])
    p_value = stats.binomtest(count_bigger_than_ones, trial, p=0.5).pvalue
    assert p_value > 0.0001

    # Checking if RandomErasing can be printed as string
    t.__repr__()


def test_random_rotation():
    with pytest.raises(ValueError):
        transforms.RandomRotation(-0.7)

    with pytest.raises(ValueError):
        transforms.RandomRotation([-0.7])

    with pytest.raises(ValueError):
        transforms.RandomRotation([-0.7, 0, 0.7])

    t = transforms.RandomRotation(0, fill=None)
    assert t.fill == 0

    t = transforms.RandomRotation(10)
    angle = t.get_params(t.degrees)
    assert -10 < angle < 10

    t = transforms.RandomRotation((-10, 10))
    angle = t.get_params(t.degrees)
    assert -10 < angle < 10

    # Checking if RandomRotation can be printed as string
    t.__repr__()

    t = transforms.RandomRotation((-10, 10), interpolation=Image.BILINEAR)
    assert t.interpolation == transforms.InterpolationMode.BILINEAR


def test_random_rotation_error():
    # assert fill being either a Sequence or a Number
    with pytest.raises(TypeError):
        transforms.RandomRotation(0, fill={})


def test_randomperspective():
    for _ in range(10):
        height = random.randint(24, 32) * 2
        width = random.randint(24, 32) * 2
        img = torch.ones(3, height, width)
        to_pil_image = transforms.ToPILImage()
        img = to_pil_image(img)
        perp = transforms.RandomPerspective()
        startpoints, endpoints = perp.get_params(width, height, 0.5)
        tr_img = F.perspective(img, startpoints, endpoints)
        tr_img2 = F.convert_image_dtype(F.pil_to_tensor(F.perspective(tr_img, endpoints, startpoints)))
        tr_img = F.convert_image_dtype(F.pil_to_tensor(tr_img))
        assert img.size[0] == width
        assert img.size[1] == height
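        # Warping back with (endpoints, startpoints) inverts the first warp, so
        # the round-tripped image should be closer to the original than the
        # one-way warp is; the 0.3 slack absorbs interpolation and fill error.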
        assert torch.nn.functional.mse_loss(
            tr_img, F.convert_image_dtype(F.pil_to_tensor(img))
        ) + 0.3 > torch.nn.functional.mse_loss(tr_img2, F.convert_image_dtype(F.pil_to_tensor(img)))


@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("mode", ["L", "RGB", "F"])
def test_randomperspective_fill(mode, seed):
    torch.random.manual_seed(seed)

    # assert fill being either a Sequence or a Number
    with pytest.raises(TypeError):
        transforms.RandomPerspective(fill={})

    t = transforms.RandomPerspective(fill=None)
    assert t.fill == 0

    height = 100
    width = 100
    img = torch.ones(3, height, width)
    to_pil_image = transforms.ToPILImage()
    img = to_pil_image(img)

    fill = 127
    num_bands = len(mode)

    img_conv = img.convert(mode)
    perspective = transforms.RandomPerspective(p=1, fill=fill)
    tr_img = perspective(img_conv)
    pixel = tr_img.getpixel((0, 0))

    if not isinstance(pixel, tuple):
        pixel = (pixel,)
    assert pixel == tuple([fill] * num_bands)

    startpoints, endpoints = transforms.RandomPerspective.get_params(width, height, 0.5)
    tr_img = F.perspective(img_conv, startpoints, endpoints, fill=fill)
    pixel = tr_img.getpixel((0, 0))

    if not isinstance(pixel, tuple):
        pixel = (pixel,)
    assert pixel == tuple([fill] * num_bands)

    wrong_num_bands = num_bands + 1
    with pytest.raises(ValueError):
        F.perspective(img_conv, startpoints, endpoints, fill=tuple([fill] * wrong_num_bands))


@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_normalize():
    def samples_from_standard_normal(tensor):
        p_value = stats.kstest(list(tensor.view(-1)), "norm", args=(0, 1)).pvalue
        return p_value > 0.0001
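
    # Normalizing with the per-channel sample mean and std should make the
    # values approximately standard normal, which a Kolmogorov-Smirnov test
    # against N(0, 1) can check without hard-coding expected outputs.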
    random_state = random.getstate()
    random.seed(42)

    for channels in [1, 3]:
        img = torch.rand(channels, 10, 10)
        mean = [img[c].mean() for c in range(channels)]
        std = [img[c].std() for c in range(channels)]
        normalized = transforms.Normalize(mean, std)(img)
        assert samples_from_standard_normal(normalized)
    random.setstate(random_state)

    # Checking if Normalize can be printed as string
    transforms.Normalize(mean, std).__repr__()

    # Checking the optional in-place behaviour
    tensor = torch.rand((1, 16, 16))
    tensor_inplace = transforms.Normalize((0.5,), (0.5,), inplace=True)(tensor)
    assert_equal(tensor, tensor_inplace)


@pytest.mark.parametrize("dtype1", [torch.float32, torch.float64])
@pytest.mark.parametrize("dtype2", [torch.int64, torch.float32, torch.float64])
def test_normalize_different_dtype(dtype1, dtype2):
    img = torch.rand(3, 10, 10, dtype=dtype1)
    mean = torch.tensor([1, 2, 3], dtype=dtype2)
    std = torch.tensor([1, 2, 1], dtype=dtype2)
    # checks that it doesn't crash
    transforms.functional.normalize(img, mean, std)


def test_normalize_3d_tensor():
    torch.manual_seed(28)
    n_channels = 3
    img_size = 10
    mean = torch.rand(n_channels)
    std = torch.rand(n_channels)
    img = torch.rand(n_channels, img_size, img_size)
    target = F.normalize(img, mean, std)

    mean_unsqueezed = mean.view(-1, 1, 1)
    std_unsqueezed = std.view(-1, 1, 1)
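    # F.normalize should broadcast: stats given as shape (C,), as (C, 1, 1),
    # or fully expanded to (C, H, W) must all produce the same result.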
    result1 = F.normalize(img, mean_unsqueezed, std_unsqueezed)
    result2 = F.normalize(
        img, mean_unsqueezed.repeat(1, img_size, img_size), std_unsqueezed.repeat(1, img_size, img_size)
    )
    torch.testing.assert_close(target, result1)
    torch.testing.assert_close(target, result2)


class TestAffine:
    @pytest.fixture(scope="class")
    def input_img(self):
        input_img = np.zeros((40, 40, 3), dtype=np.uint8)
        for pt in [(16, 16), (20, 16), (20, 20)]:
            for i in range(-5, 5):
                for j in range(-5, 5):
                    input_img[pt[0] + i, pt[1] + j, :] = [255, 155, 55]
        return input_img

    def test_affine_translate_seq(self, input_img):
        with pytest.raises(TypeError, match=r"Argument translate should be a sequence"):
            F.affine(input_img, 10, translate=0, scale=1, shear=1)

    @pytest.fixture(scope="class")
    def pil_image(self, input_img):
        return F.to_pil_image(input_img)

    def _to_3x3_inv(self, inv_result_matrix):
        result_matrix = np.zeros((3, 3))
        result_matrix[:2, :] = np.array(inv_result_matrix).reshape((2, 3))
        result_matrix[2, 2] = 1
        return np.linalg.inv(result_matrix)
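    # (F._get_inverse_affine_matrix returns the 2x3 coefficients of the inverse
    # mapping; _to_3x3_inv lifts them into a homogeneous 3x3 matrix and inverts
    # it to recover the forward transform checked below.)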

    def _test_transformation(self, angle, translate, scale, shear, pil_image, input_img, center=None):
        a_rad = math.radians(angle)
        s_rad = [math.radians(sh_) for sh_ in shear]
        cnt = [20, 20] if center is None else center
        cx, cy = cnt
        tx, ty = translate
        sx, sy = s_rad
        rot = a_rad

        # 1) Check transformation matrix:
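        # The expected forward matrix applies the rotation/scale/shear RSS
        # about the chosen center (conjugation by C), then translates by T:
        #     M = T @ C @ RSS @ C^-1,  with  RSS = RS @ SHy @ SHx
        # and is compared against the inverse returned by the functional API.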
        C = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
        T = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
        Cinv = np.linalg.inv(C)

        RS = np.array(
            [
                [scale * math.cos(rot), -scale * math.sin(rot), 0],
                [scale * math.sin(rot), scale * math.cos(rot), 0],
                [0, 0, 1],
            ]
        )

        SHx = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]])
        SHy = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])
        RSS = np.matmul(RS, np.matmul(SHy, SHx))
        true_matrix = np.matmul(T, np.matmul(C, np.matmul(RSS, Cinv)))
        result_matrix = self._to_3x3_inv(
            F._get_inverse_affine_matrix(center=cnt, angle=angle, translate=translate, scale=scale, shear=shear)
        )
        assert np.sum(np.abs(true_matrix - result_matrix)) < 1e-10

        # 2) Perform inverse mapping:
        true_result = np.zeros((40, 40, 3), dtype=np.uint8)
        inv_true_matrix = np.linalg.inv(true_matrix)
        for y in range(true_result.shape[0]):
            for x in range(true_result.shape[1]):
                # Same as for PIL:
                # https://github.com/python-pillow/Pillow/blob/71f8ec6a0cfc1008076a023c0756542539d057ab/
                # src/libImaging/Geometry.c#L1060
                input_pt = np.array([x + 0.5, y + 0.5, 1.0])
                res = np.floor(np.dot(inv_true_matrix, input_pt)).astype(int)
                _x, _y = res[:2]
                if 0 <= _x < input_img.shape[1] and 0 <= _y < input_img.shape[0]:
                    true_result[y, x, :] = input_img[_y, _x, :]

        result = F.affine(pil_image, angle=angle, translate=translate, scale=scale, shear=shear, center=center)
        assert result.size == pil_image.size

        # Compute number of different pixels:
        np_result = np.array(result)
        n_diff_pixels = np.sum(np_result != true_result) / 3
        # Accept 3 wrong pixels
        error_msg = (
            f"angle={angle}, translate={translate}, scale={scale}, shear={shear}\nn diff pixels={n_diff_pixels}\n"
        )
        assert n_diff_pixels < 3, error_msg

    def test_transformation_discrete(self, pil_image, input_img):
        # Test rotation
        angle = 45
        self._test_transformation(
            angle=angle, translate=(0, 0), scale=1.0, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img
        )

        # Test rotation about the top-left corner as center
        angle = 45
        self._test_transformation(
            angle=angle,
            translate=(0, 0),
            scale=1.0,
            shear=(0.0, 0.0),
            pil_image=pil_image,
            input_img=input_img,
            center=[0, 0],
        )

        # Test translation
        translate = [10, 15]
        self._test_transformation(
            angle=0.0, translate=translate, scale=1.0, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img
        )

        # Test scale
        scale = 1.2
        self._test_transformation(
            angle=0.0, translate=(0.0, 0.0), scale=scale, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img
        )

        # Test shear
        shear = [45.0, 25.0]
        self._test_transformation(
            angle=0.0, translate=(0.0, 0.0), scale=1.0, shear=shear, pil_image=pil_image, input_img=input_img
        )

        # Test shear with top-left as center
        shear = [45.0, 25.0]
        self._test_transformation(
            angle=0.0,
            translate=(0.0, 0.0),
            scale=1.0,
            shear=shear,
            pil_image=pil_image,
            input_img=input_img,
            center=[0, 0],
        )

    @pytest.mark.parametrize("angle", range(-90, 90, 36))
    @pytest.mark.parametrize("translate", range(-10, 10, 5))
    @pytest.mark.parametrize("scale", [0.77, 1.0, 1.27])
    @pytest.mark.parametrize("shear", range(-15, 15, 5))
    def test_transformation_range(self, angle, translate, scale, shear, pil_image, input_img):
        self._test_transformation(
            angle=angle,
            translate=(translate, translate),
            scale=scale,
            shear=(shear, shear),
            pil_image=pil_image,
            input_img=input_img,
        )


def test_random_affine():
    with pytest.raises(ValueError):
        transforms.RandomAffine(-0.7)
    with pytest.raises(ValueError):
        transforms.RandomAffine([-0.7])
    with pytest.raises(ValueError):
        transforms.RandomAffine([-0.7, 0, 0.7])
    with pytest.raises(TypeError):
        transforms.RandomAffine([-90, 90], translate=2.0)
    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[-1.0, 1.0])
    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[-1.0, 0.0, 1.0])
    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.0])
    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[-1.0, 1.0])
    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, -0.5])
    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 3.0, -0.5])
    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=-7)
    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10])
    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10])
    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10, 0, 10])

    # assert fill being either a Sequence or a Number
    with pytest.raises(TypeError):
        transforms.RandomAffine(0, fill={})

    t = transforms.RandomAffine(0, fill=None)
    assert t.fill == 0

    x = np.zeros((100, 100, 3), dtype=np.uint8)
    img = F.to_pil_image(x)

    t = transforms.RandomAffine(10, translate=[0.5, 0.3], scale=[0.7, 1.3], shear=[-10, 10, 20, 40])
    for _ in range(100):
        angle, translations, scale, shear = t.get_params(t.degrees, t.translate, t.scale, t.shear, img_size=img.size)
        assert -10 < angle < 10
        assert -img.size[0] * 0.5 <= translations[0] <= img.size[0] * 0.5
        assert -img.size[1] * 0.5 <= translations[1] <= img.size[1] * 0.5
        assert 0.7 < scale < 1.3
        assert -10 < shear[0] < 10
        assert -20 < shear[1] < 40

    # Checking if RandomAffine can be printed as string
    t.__repr__()

    t = transforms.RandomAffine(10, interpolation=transforms.InterpolationMode.BILINEAR)
    assert "bilinear" in t.__repr__()

    t = transforms.RandomAffine(10, interpolation=Image.BILINEAR)
    assert t.interpolation == transforms.InterpolationMode.BILINEAR


def test_elastic_transformation():
    with pytest.raises(TypeError, match=r"alpha should be float or a sequence of floats"):
        transforms.ElasticTransform(alpha=True, sigma=2.0)
    with pytest.raises(TypeError, match=r"alpha should be a sequence of floats"):
        transforms.ElasticTransform(alpha=[1.0, True], sigma=2.0)
    with pytest.raises(ValueError, match=r"alpha is a sequence its length should be 2"):
        transforms.ElasticTransform(alpha=[1.0, 0.0, 1.0], sigma=2.0)

    with pytest.raises(TypeError, match=r"sigma should be float or a sequence of floats"):
        transforms.ElasticTransform(alpha=2.0, sigma=True)
    with pytest.raises(TypeError, match=r"sigma should be a sequence of floats"):
        transforms.ElasticTransform(alpha=2.0, sigma=[1.0, True])
    with pytest.raises(ValueError, match=r"sigma is a sequence its length should be 2"):
        transforms.ElasticTransform(alpha=2.0, sigma=[1.0, 0.0, 1.0])

    t = transforms.ElasticTransform(alpha=2.0, sigma=2.0, interpolation=Image.BILINEAR)
    assert t.interpolation == transforms.InterpolationMode.BILINEAR

    with pytest.raises(TypeError, match=r"fill should be int or float"):
        transforms.ElasticTransform(alpha=1.0, sigma=1.0, fill={})

    x = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
    img = F.to_pil_image(x)
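
    # alpha=0 scales the random displacement field to zero, so the transform
    # should act as an exact identity on the input.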
    t = transforms.ElasticTransform(alpha=0.0, sigma=0.0)
    transformed_img = t(img)
    assert transformed_img == img

    # Smoke test on PIL images
    t = transforms.ElasticTransform(alpha=0.5, sigma=0.23)
    transformed_img = t(img)
    assert isinstance(transformed_img, Image.Image)

    # Checking if ElasticTransform can be printed as string
    t.__repr__()


def test_random_grayscale_with_grayscale_input():
    transform = transforms.RandomGrayscale(p=1.0)

    image_tensor = torch.randint(0, 256, (1, 16, 16), dtype=torch.uint8)
    output_tensor = transform(image_tensor)
    torch.testing.assert_close(output_tensor, image_tensor)

    image_pil = F.to_pil_image(image_tensor)
    output_pil = transform(image_pil)
    torch.testing.assert_close(F.pil_to_tensor(output_pil), image_tensor)


# TODO: remove in 0.17 when we can delete functional_pil.py and functional_tensor.py
@pytest.mark.parametrize(
    "import_statement",
    (
        "from torchvision.transforms import functional_pil",
        "from torchvision.transforms import functional_tensor",
        "from torchvision.transforms.functional_tensor import resize",
        "from torchvision.transforms.functional_pil import resize",
    ),
)
@pytest.mark.parametrize("from_private", (True, False))
def test_functional_deprecation_warning(import_statement, from_private):
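    # The private _functional_* modules must import silently (any warning is
    # escalated to an error), while the public names must warn. Each snippet
    # runs in a fresh subprocess because Python caches modules, so the warning
    # only fires on the first import in a given interpreter.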
    if from_private:
        import_statement = import_statement.replace("functional", "_functional")
        source = f"""
        import warnings
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            {import_statement}
        """
    else:
        source = f"""
        import pytest
        with pytest.warns(UserWarning, match="removed in 0.17"):
            {import_statement}
        """
    assert_run_python_script(textwrap.dedent(source))


if __name__ == "__main__":
    pytest.main([__file__])