# test_ops.py
import math
import os
from abc import ABC, abstractmethod
from functools import lru_cache
from itertools import product
from typing import Callable, List, Tuple

import numpy as np
import pytest
import torch
import torch.fx
import torch.nn.functional as F
from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
from PIL import Image
from torch import nn, Tensor
from torch.autograd import gradcheck
from torch.nn.modules.utils import _pair
from torchvision import models, ops
from torchvision.models.feature_extraction import get_graph_node_names


# Context manager for setting deterministic flag and automatically
# resetting it to its original value
class DeterministicGuard:
    def __init__(self, deterministic, *, warn_only=False):
        self.deterministic = deterministic
        self.warn_only = warn_only

    def __enter__(self):
        self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
        self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
        torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only)

    def __exit__(self, exception_type, exception_value, traceback):
        torch.use_deterministic_algorithms(self.deterministic_restore, warn_only=self.warn_only_restore)
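

# Illustrative sketch (not part of the original suite): how DeterministicGuard is
# meant to be used. The guard saves the global deterministic-algorithms flags on
# entry and restores them on exit, so a test cannot leak the setting.
def _sketch_deterministic_guard_usage():
    before = torch.are_deterministic_algorithms_enabled()
    with DeterministicGuard(True, warn_only=True):
        # Inside the guard, deterministic algorithms are requested globally.
        assert torch.are_deterministic_algorithms_enabled()
    # On exit, the previous global setting is restored.
    assert torch.are_deterministic_algorithms_enabled() == before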


class RoIOpTesterModuleWrapper(nn.Module):
    def __init__(self, obj):
        super().__init__()
        self.layer = obj
        self.n_inputs = 2

    def forward(self, a, b):
        self.layer(a, b)


class MultiScaleRoIAlignModuleWrapper(nn.Module):
    def __init__(self, obj):
        super().__init__()
        self.layer = obj
        self.n_inputs = 3

    def forward(self, a, b, c):
        self.layer(a, b, c)


class DeformConvModuleWrapper(nn.Module):
    def __init__(self, obj):
        super().__init__()
        self.layer = obj
        self.n_inputs = 3

    def forward(self, a, b, c):
        self.layer(a, b, c)


class StochasticDepthWrapper(nn.Module):
    def __init__(self, obj):
        super().__init__()
        self.layer = obj
        self.n_inputs = 1

    def forward(self, a):
        self.layer(a)


class DropBlockWrapper(nn.Module):
    def __init__(self, obj):
        super().__init__()
        self.layer = obj
        self.n_inputs = 1

    def forward(self, a):
        self.layer(a)


class PoolWrapper(nn.Module):
    def __init__(self, pool: nn.Module):
        super().__init__()
        self.pool = pool

    def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor:
        return self.pool(imgs, boxes)


class RoIOpTester(ABC):
    dtype = torch.float64
    mps_dtype = torch.float32
    mps_backward_atol = 2e-2

    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize(
        "x_dtype",
        (
            torch.float16,
            torch.float32,
            torch.float64,
        ),
        ids=str,
    )
    def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs):
        if device == "mps" and x_dtype is torch.float64:
            pytest.skip("MPS does not support float64")

        rois_dtype = x_dtype if rois_dtype is None else rois_dtype

        tol = 1e-5
        if x_dtype is torch.half:
            if device == "mps":
                tol = 5e-3
            else:
                tol = 4e-3

        pool_size = 5
        # n_channels % (pool_size ** 2) == 0 required for PS operations.
        n_channels = 2 * (pool_size**2)
        x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device)
        if not contiguous:
            x = x.permute(0, 1, 3, 2)
        rois = torch.tensor(
            [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]],  # format is (xyxy)
            dtype=rois_dtype,
            device=device,
        )

        pool_h, pool_w = pool_size, pool_size
        with DeterministicGuard(deterministic):
            y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
        # The following should be true whether we're running an autocast test or not.
        assert y.dtype == x.dtype
        gt_y = self.expected_fn(
            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs
        )

        torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_is_leaf_node(self, device):
        op_obj = self.make_obj(wrap=True).to(device=device)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.float):
        op_obj = self.make_obj().to(device=device)
        graph_module = torch.fx.symbolic_trace(op_obj)
        pool_size = 5
        n_channels = 2 * (pool_size**2)
        x = torch.rand(2, n_channels, 5, 5, dtype=x_dtype, device=device)
        rois = torch.tensor(
            [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]],  # format is (xyxy)
            dtype=rois_dtype,
            device=device,
        )
        output_gt = op_obj(x, rois)
        assert output_gt.dtype == x.dtype
        output_fx = graph_module(x, rois)
        assert output_fx.dtype == x.dtype
        tol = 1e-5
        torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol)

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
    @pytest.mark.parametrize("contiguous", (True, False))
    def test_backward(self, seed, device, contiguous, deterministic=False):
        atol = self.mps_backward_atol if device == "mps" else 1e-05
        dtype = self.mps_dtype if device == "mps" else self.dtype

        torch.random.manual_seed(seed)
        pool_size = 2
        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=dtype, device=device, requires_grad=True)
        if not contiguous:
            x = x.permute(0, 1, 3, 2)
        rois = torch.tensor(
            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=dtype, device=device  # format is (xyxy)
        )

        def func(z):
            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)

        script_func = self.get_script_fn(rois, pool_size)

        with DeterministicGuard(deterministic):
            gradcheck(func, (x,), atol=atol)

        gradcheck(script_func, (x,), atol=atol)

    @needs_mps
    def test_mps_error_inputs(self):
        pool_size = 2
        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=torch.float16, device="mps", requires_grad=True)
        rois = torch.tensor(
            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=torch.float16, device="mps"  # format is (xyxy)
        )

        def func(z):
            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)

        with pytest.raises(
            RuntimeError, match="MPS does not support (?:ps_)?roi_(?:align|pool)? backward with float16 inputs."
        ):
            gradcheck(func, (x,))

    @needs_cuda
    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
    def test_autocast(self, x_dtype, rois_dtype):
        with torch.cuda.amp.autocast():
            self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)

    def _helper_boxes_shape(self, func):
        # test boxes as Tensor[N, 5]
        with pytest.raises(AssertionError):
            a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
            boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
            func(a, boxes, output_size=(2, 2))

        # test boxes as List[Tensor[N, 4]]
        with pytest.raises(AssertionError):
            a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
            boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
            ops.roi_pool(a, [boxes], output_size=(2, 2))

    def _helper_jit_boxes_list(self, model):
        x = torch.rand(2, 1, 10, 10)
        roi = torch.tensor([[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], dtype=torch.float).t()
        rois = [roi, roi]
        scripted = torch.jit.script(model)
        y = scripted(x, rois)
        assert y.shape == (10, 1, 3, 3)

    @abstractmethod
    def fn(*args, **kwargs):
        pass

    @abstractmethod
    def make_obj(*args, **kwargs):
        pass

    @abstractmethod
    def get_script_fn(*args, **kwargs):
        pass

    @abstractmethod
    def expected_fn(*args, **kwargs):
        pass


class TestRoiPool(RoIOpTester):
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        return ops.RoIPool((pool_h, pool_w), spatial_scale)(x, rois)

    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
        obj = ops.RoIPool((pool_h, pool_w), spatial_scale)
        return RoIOpTesterModuleWrapper(obj) if wrap else obj

    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.roi_pool)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(
        self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64
    ):
        if device is None:
            device = torch.device("cpu")

        n_channels = x.size(1)
        y = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)

        def get_slice(k, block):
            return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))

        for roi_idx, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:])
            roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1]

            roi_h, roi_w = roi_x.shape[-2:]
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                for j in range(0, pool_w):
                    bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)]
                    if bin_x.numel() > 0:
                        y[roi_idx, :, i, j] = bin_x.reshape(n_channels, -1).max(dim=1)[0]
        return y
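
    @staticmethod
    def _sketch_get_slice_binning():
        # Illustrative sketch (not part of the original suite): with roi_h = 5
        # and pool_h = 2 above, bin_h = 2.5, so the floor/ceil slices are [0:3]
        # and [2:5]. Adjacent bins may overlap by one row or column, and the
        # max is taken over each (possibly overlapping) bin.
        def get_slice(k, block):
            return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))

        assert (get_slice(0, 2.5), get_slice(1, 2.5)) == (slice(0, 3), slice(2, 5))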

    def test_boxes_shape(self):
        self._helper_boxes_shape(ops.roi_pool)

    def test_jit_boxes_list(self):
        model = PoolWrapper(ops.RoIPool(output_size=[3, 3], spatial_scale=1.0))
        self._helper_jit_boxes_list(model)


class TestPSRoIPool(RoIOpTester):
    mps_backward_atol = 5e-2

    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois)

    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
        obj = ops.PSRoIPool((pool_h, pool_w), spatial_scale)
        return RoIOpTesterModuleWrapper(obj) if wrap else obj

    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.ps_roi_pool)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(
        self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64
    ):
        if device is None:
            device = torch.device("cpu")
        n_input_channels = x.size(1)
        assert n_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw"
        n_output_channels = int(n_input_channels / (pool_h * pool_w))
        y = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)

        def get_slice(k, block):
            return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))

        for roi_idx, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:])
            roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1]

            roi_height = max(i_end - i_begin, 1)
            roi_width = max(j_end - j_begin, 1)
            bin_h, bin_w = roi_height / float(pool_h), roi_width / float(pool_w)

            for i in range(0, pool_h):
                for j in range(0, pool_w):
                    bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)]
                    if bin_x.numel() > 0:
                        area = bin_x.size(-2) * bin_x.size(-1)
                        for c_out in range(0, n_output_channels):
                            c_in = c_out * (pool_h * pool_w) + pool_w * i + j
                            t = torch.sum(bin_x[c_in, :, :])
                            y[roi_idx, c_out, i, j] = t / area
        return y

    def test_boxes_shape(self):
        self._helper_boxes_shape(ops.ps_roi_pool)


def bilinear_interpolate(data, y, x, snap_border=False):
    height, width = data.shape

    if snap_border:
        if -1 < y <= 0:
            y = 0
        elif height - 1 <= y < height:
            y = height - 1

        if -1 < x <= 0:
            x = 0
        elif width - 1 <= x < width:
            x = width - 1

    y_low = int(math.floor(y))
    x_low = int(math.floor(x))
    y_high = y_low + 1
    x_high = x_low + 1

    wy_h = y - y_low
    wx_h = x - x_low
    wy_l = 1 - wy_h
    wx_l = 1 - wx_h

    val = 0
    for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
        for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
            if 0 <= yp < height and 0 <= xp < width:
                val += wx * wy * data[yp, xp]
    return val
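

# Illustrative sketch (not part of the original suite): a worked example of the
# reference bilinear interpolation above. At the center of a 2x2 grid all four
# corner weights are 0.25, so the result is the mean of the four corner values.
def _sketch_bilinear_interpolate_example():
    data = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
    val = bilinear_interpolate(data, 0.5, 0.5)
    assert math.isclose(float(val), (0.0 + 1.0 + 2.0 + 3.0) / 4)  # == 1.5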


class TestRoIAlign(RoIOpTester):
    mps_backward_atol = 6e-2

    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs):
        return ops.RoIAlign(
            (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
        )(x, rois)

    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, aligned=False, wrap=False):
        obj = ops.RoIAlign(
            (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
        )
        return RoIOpTesterModuleWrapper(obj) if wrap else obj

    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.roi_align)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(
        self,
        in_data,
        rois,
        pool_h,
        pool_w,
        spatial_scale=1,
        sampling_ratio=-1,
        aligned=False,
        device=None,
        dtype=torch.float64,
    ):
        if device is None:
            device = torch.device("cpu")
        n_channels = in_data.size(1)
        out_data = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)

        offset = 0.5 if aligned else 0.0

        for r, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (x.item() * spatial_scale - offset for x in roi[1:])

            roi_h = i_end - i_begin
            roi_w = j_end - j_begin
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                start_h = i_begin + i * bin_h
                grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
                for j in range(0, pool_w):
                    start_w = j_begin + j * bin_w
                    grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))

                    for channel in range(0, n_channels):
                        val = 0
                        for iy in range(0, grid_h):
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
                                val += bilinear_interpolate(in_data[batch_idx, channel, :, :], y, x, snap_border=True)
                        val /= grid_h * grid_w

                        out_data[r, channel, i, j] = val
        return out_data

    def test_boxes_shape(self):
        self._helper_boxes_shape(ops.roi_align)

    @pytest.mark.parametrize("aligned", (True, False))
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
    @pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64), ids=str)
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("deterministic", (True, False))
    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois_dtype=None):
        if deterministic and device == "cpu":
            pytest.skip("cpu is always deterministic, don't retest")
        super().test_forward(
            device=device,
            contiguous=contiguous,
            deterministic=deterministic,
            x_dtype=x_dtype,
            rois_dtype=rois_dtype,
            aligned=aligned,
        )

    @needs_cuda
    @pytest.mark.parametrize("aligned", (True, False))
    @pytest.mark.parametrize("deterministic", (True, False))
    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
    def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
        with torch.cuda.amp.autocast():
            self.test_forward(
                torch.device("cuda"),
                contiguous=False,
                deterministic=deterministic,
                aligned=aligned,
                x_dtype=x_dtype,
                rois_dtype=rois_dtype,
            )

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("deterministic", (True, False))
    def test_backward(self, seed, device, contiguous, deterministic):
        if deterministic and device == "cpu":
            pytest.skip("cpu is always deterministic, don't retest")
        super().test_backward(seed, device, contiguous, deterministic)

    def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
        rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
        rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,))  # set batch index
        rois[:, 3:] += rois[:, 1:3]  # make sure boxes aren't degenerate
        return rois

    @pytest.mark.parametrize("aligned", (True, False))
    @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 10), (0.1, 50)))
    @pytest.mark.parametrize("qdtype", (torch.qint8, torch.quint8, torch.qint32))
    def test_qroialign(self, aligned, scale, zero_point, qdtype):
        """Make sure the quantized version of RoIAlign is close to the float version."""
        pool_size = 5
        img_size = 10
        n_channels = 2
        num_imgs = 1
        dtype = torch.float

        x = torch.randint(50, 100, size=(num_imgs, n_channels, img_size, img_size)).to(dtype)
        qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=qdtype)

        rois = self._make_rois(img_size, num_imgs, dtype)
        qrois = torch.quantize_per_tensor(rois, scale=scale, zero_point=zero_point, dtype=qdtype)

        x, rois = qx.dequantize(), qrois.dequantize()  # we want to pass the same inputs

        y = ops.roi_align(
            x,
            rois,
            output_size=pool_size,
            spatial_scale=1,
            sampling_ratio=-1,
            aligned=aligned,
        )
        qy = ops.roi_align(
            qx,
            qrois,
            output_size=pool_size,
            spatial_scale=1,
            sampling_ratio=-1,
            aligned=aligned,
        )

        # The output qy is itself a quantized tensor and there might have been a loss of info when it was
        # quantized. For a fair comparison we need to quantize y as well.
        quantized_float_y = torch.quantize_per_tensor(y, scale=scale, zero_point=zero_point, dtype=qdtype)

        try:
            # Ideally, we would assert this, which passes with (scale, zero) == (1, 0)
            assert (qy == quantized_float_y).all()
        except AssertionError:
            # But because the computations aren't exactly the same between the 2 RoIAlign procedures, some
            # rounding error may lead to a difference of 2 in the output.
            # For example with (scale, zero) = (2, 10), 45.00000... will be quantized to 44
            # but 45.00000001 will be rounded to 46 (illustrated in the sketch below). We make sure below that:
            # - such discrepancies between qy and quantized_float_y are very rare (less than 5%)
            # - any difference between qy and quantized_float_y is == scale
            diff_idx = torch.where(qy != quantized_float_y)
            num_diff = diff_idx[0].numel()
            assert num_diff / qy.numel() < 0.05

            abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize())
            t_scale = torch.full_like(abs_diff, fill_value=scale)
            torch.testing.assert_close(abs_diff, t_scale, rtol=1e-5, atol=1e-5)
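
    @staticmethod
    def _sketch_quantization_rounding():
        # Illustrative sketch (not part of the original suite) of the rounding
        # discrepancy discussed above: with scale=2, zero_point=10, the value
        # 45.0 sits exactly between two quantization levels (44 and 46), so a
        # tiny numerical difference in the float inputs can flip the quantized
        # result by one full scale step.
        lo = torch.quantize_per_tensor(torch.tensor([44.9999]), 2, 10, torch.quint8)
        hi = torch.quantize_per_tensor(torch.tensor([45.0001]), 2, 10, torch.quint8)
        assert float(hi.dequantize() - lo.dequantize()) == 2.0  # one scale step apart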

    def test_qroi_align_multiple_images(self):
        dtype = torch.float
        x = torch.randint(50, 100, size=(2, 3, 10, 10)).to(dtype)
        qx = torch.quantize_per_tensor(x, scale=1, zero_point=0, dtype=torch.qint8)
        rois = self._make_rois(img_size=10, num_imgs=2, dtype=dtype, num_rois=10)
        qrois = torch.quantize_per_tensor(rois, scale=1, zero_point=0, dtype=torch.qint8)
        with pytest.raises(RuntimeError, match="Only one image per batch is allowed"):
            ops.roi_align(qx, qrois, output_size=5)

    def test_jit_boxes_list(self):
        model = PoolWrapper(ops.RoIAlign(output_size=[3, 3], spatial_scale=1.0, sampling_ratio=-1))
        self._helper_jit_boxes_list(model)


class TestPSRoIAlign(RoIOpTester):
    mps_backward_atol = 5e-2

    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois)

    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, wrap=False):
        obj = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)
        return RoIOpTesterModuleWrapper(obj) if wrap else obj

    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.ps_roi_align)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(
        self, in_data, rois, pool_h, pool_w, device=None, spatial_scale=1, sampling_ratio=-1, dtype=torch.float64
    ):
        if device is None:
            device = torch.device("cpu")
        n_input_channels = in_data.size(1)
        assert n_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw"
        n_output_channels = int(n_input_channels / (pool_h * pool_w))
        out_data = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)

        for r, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (x.item() * spatial_scale - 0.5 for x in roi[1:])

            roi_h = i_end - i_begin
            roi_w = j_end - j_begin
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                start_h = i_begin + i * bin_h
                grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
                for j in range(0, pool_w):
                    start_w = j_begin + j * bin_w
                    grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))
                    for c_out in range(0, n_output_channels):
                        c_in = c_out * (pool_h * pool_w) + pool_w * i + j
                        val = 0
                        for iy in range(0, grid_h):
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
                                val += bilinear_interpolate(in_data[batch_idx, c_in, :, :], y, x, snap_border=True)
                        val /= grid_h * grid_w

                        out_data[r, c_out, i, j] = val
        return out_data

    def test_boxes_shape(self):
        self._helper_boxes_shape(ops.ps_roi_align)


class TestMultiScaleRoIAlign:
    def make_obj(self, fmap_names=None, output_size=(7, 7), sampling_ratio=2, wrap=False):
        if fmap_names is None:
            fmap_names = ["0"]
        obj = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio)
        return MultiScaleRoIAlignModuleWrapper(obj) if wrap else obj

    def test_msroialign_repr(self):
        fmap_names = ["0"]
        output_size = (7, 7)
        sampling_ratio = 2
        # Pass mock feature map names
        t = self.make_obj(fmap_names, output_size, sampling_ratio, wrap=False)

        # Check integrity of object __repr__ attribute
        expected_string = (
            f"MultiScaleRoIAlign(featmap_names={fmap_names}, output_size={output_size}, "
            f"sampling_ratio={sampling_ratio})"
        )
        assert repr(t) == expected_string

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_is_leaf_node(self, device):
        op_obj = self.make_obj(wrap=True).to(device=device)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs


class TestNMS:
    def _reference_nms(self, boxes, scores, iou_threshold):
        """
        Args:
            boxes: boxes in corner-form
            scores: probabilities
            iou_threshold: intersection over union threshold
        Returns:
            picked: a list of indexes of the kept boxes
        """
        picked = []
        _, indexes = scores.sort(descending=True)
        while len(indexes) > 0:
            current = indexes[0]
            picked.append(current.item())
            if len(indexes) == 1:
                break
            current_box = boxes[current, :]
            indexes = indexes[1:]
            rest_boxes = boxes[indexes, :]
            iou = ops.box_iou(rest_boxes, current_box.unsqueeze(0)).squeeze(1)
            indexes = indexes[iou <= iou_threshold]
        return torch.as_tensor(picked)

    def _create_tensors_with_iou(self, N, iou_thresh):
        # force last box to have a pre-defined iou with the first box
        # let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
        # then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
        # we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
        # (see the sketch below for a worked example).
        # Adjust the threshold upward a bit with the intent of creating
        # at least one box that exceeds (barely) the threshold and so
        # should be suppressed.
        boxes = torch.rand(N, 4) * 100
        boxes[:, 2:] += boxes[:, :2]
        boxes[-1, :] = boxes[0, :]
        x0, y0, x1, y1 = boxes[-1].tolist()
        iou_thresh += 1e-5
        boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
        scores = torch.rand(N)
        return boxes, scores
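
    @staticmethod
    def _sketch_iou_construction():
        # Illustrative sketch (not part of the original suite) verifying the
        # comment above: widening b0 = [0, 0, 10, 10] by d = (x1 - x0) * (1 - t) / t
        # with t = 0.5 gives d = 10, and IoU(b0, b1) = 100 / 200 = 0.5 exactly.
        b0 = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
        b1 = torch.tensor([[0.0, 0.0, 20.0, 10.0]])
        assert float(ops.box_iou(b0, b1)) == 0.5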

    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("seed", range(10))
    def test_nms_ref(self, iou, seed):
        torch.random.manual_seed(seed)
        err_msg = "NMS incompatible between CPU and reference implementation for IoU={}"
        boxes, scores = self._create_tensors_with_iou(1000, iou)
        keep_ref = self._reference_nms(boxes, scores, iou)
        keep = ops.nms(boxes, scores, iou)
        torch.testing.assert_close(keep, keep_ref, msg=err_msg.format(iou))

    def test_nms_input_errors(self):
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(4), torch.rand(3), 0.5)
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(3, 5), torch.rand(3), 0.5)
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(3, 4), torch.rand(3, 2), 0.5)
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(3, 4), torch.rand(4), 0.5)

    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 50), (3, 10)))
    def test_qnms(self, iou, scale, zero_point):
        # Note: we compare qnms vs nms instead of qnms vs the reference implementation.
        # This is because with the int conversion, the trick used in _create_tensors_with_iou
        # doesn't really work (in fact, nms vs the reference implementation will also fail with ints)
        err_msg = "NMS and QNMS give different results for IoU={}"
        boxes, scores = self._create_tensors_with_iou(1000, iou)
        scores *= 100  # otherwise most scores would be 0 or 1 after int conversion

        qboxes = torch.quantize_per_tensor(boxes, scale=scale, zero_point=zero_point, dtype=torch.quint8)
        qscores = torch.quantize_per_tensor(scores, scale=scale, zero_point=zero_point, dtype=torch.quint8)

        boxes = qboxes.dequantize()
        scores = qscores.dequantize()

        keep = ops.nms(boxes, scores, iou)
        qkeep = ops.nms(qboxes, qscores, iou)

        torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))

    @pytest.mark.parametrize(
        "device",
        (
            pytest.param("cuda", marks=pytest.mark.needs_cuda),
            pytest.param("mps", marks=pytest.mark.needs_mps),
        ),
    )
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    def test_nms_gpu(self, iou, device, dtype=torch.float64):
        dtype = torch.float32 if device == "mps" else dtype
        tol = 1e-3 if dtype is torch.half else 1e-5
        err_msg = "NMS incompatible between CPU and GPU for IoU={}"

        boxes, scores = self._create_tensors_with_iou(1000, iou)
        r_cpu = ops.nms(boxes, scores, iou)
        r_gpu = ops.nms(boxes.to(device), scores.to(device), iou)

        is_eq = torch.allclose(r_cpu, r_gpu.cpu())
        if not is_eq:
            # if the indices are not the same, ensure that it's because the scores
            # are duplicates
            is_eq = torch.allclose(scores[r_cpu], scores[r_gpu.cpu()], rtol=tol, atol=tol)
        assert is_eq, err_msg.format(iou)

    @needs_cuda
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
    def test_autocast(self, iou, dtype):
        with torch.cuda.amp.autocast():
            self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")

    @pytest.mark.parametrize(
        "device",
        (
            pytest.param("cuda", marks=pytest.mark.needs_cuda),
            pytest.param("mps", marks=pytest.mark.needs_mps),
        ),
    )
    def test_nms_float16(self, device):
        boxes = torch.tensor(
            [
                [285.3538, 185.5758, 1193.5110, 851.4551],
                [285.1472, 188.7374, 1192.4984, 851.0669],
                [279.2440, 197.9812, 1189.4746, 849.2019],
            ]
        ).to(device)
        scores = torch.tensor([0.6370, 0.7569, 0.3966]).to(device)

        iou_thres = 0.2
        keep32 = ops.nms(boxes, scores, iou_thres)
        keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
        assert_equal(keep32, keep16)

    @pytest.mark.parametrize("seed", range(10))
    def test_batched_nms_implementations(self, seed):
        """Make sure that both implementations of batched_nms yield identical results."""
        torch.random.manual_seed(seed)

        num_boxes = 1000
        iou_threshold = 0.9

        boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
        assert max(boxes[:, 0]) < min(boxes[:, 2])  # x1 < x2
        assert max(boxes[:, 1]) < min(boxes[:, 3])  # y1 < y2

        scores = torch.rand(num_boxes)
        idxs = torch.randint(0, 4, size=(num_boxes,))
        keep_vanilla = ops.boxes._batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
        keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)

        torch.testing.assert_close(
            keep_vanilla, keep_trick, msg="The vanilla and the trick implementation yield different nms outputs."
        )

        # Also make sure an empty tensor is returned if boxes is empty
        empty = torch.empty((0,), dtype=torch.int64)
        torch.testing.assert_close(empty, ops.batched_nms(empty, None, None, None))
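
    @staticmethod
    def _sketch_coordinate_trick():
        # Illustrative sketch (not part of the original suite) of the idea behind
        # _batched_nms_coordinate_trick: offset each box by idx * (max_coord + 1)
        # so that boxes from different categories can never overlap, then run a
        # single plain nms over all boxes at once.
        boxes, scores = torch.rand(10, 4), torch.rand(10)
        boxes[:, 2:] += boxes[:, :2]
        idxs = torch.randint(0, 3, (10,))
        offsets = idxs.to(boxes) * (boxes.max() + 1)
        keep = ops.nms(boxes + offsets[:, None], scores, iou_threshold=0.5)
        assert torch.equal(keep, ops.batched_nms(boxes, scores, idxs, iou_threshold=0.5))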


class TestDeformConv:
    dtype = torch.float64

    def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
        stride_h, stride_w = _pair(stride)
        pad_h, pad_w = _pair(padding)
        dil_h, dil_w = _pair(dilation)
        weight_h, weight_w = weight.shape[-2:]

        n_batches, n_in_channels, in_h, in_w = x.shape
        n_out_channels = weight.shape[0]

        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1

        n_offset_grps = offset.shape[1] // (2 * weight_h * weight_w)
        in_c_per_offset_grp = n_in_channels // n_offset_grps

        n_weight_grps = n_in_channels // weight.shape[1]
        in_c_per_weight_grp = weight.shape[1]
        out_c_per_weight_grp = n_out_channels // n_weight_grps

        out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
        for b in range(n_batches):
            for c_out in range(n_out_channels):
                for i in range(out_h):
                    for j in range(out_w):
                        for di in range(weight_h):
                            for dj in range(weight_w):
                                for c in range(in_c_per_weight_grp):
                                    weight_grp = c_out // out_c_per_weight_grp
                                    c_in = weight_grp * in_c_per_weight_grp + c

                                    offset_grp = c_in // in_c_per_offset_grp
                                    mask_idx = offset_grp * (weight_h * weight_w) + di * weight_w + dj
                                    offset_idx = 2 * mask_idx

                                    pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
                                    pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]

                                    mask_value = 1.0
                                    if mask is not None:
                                        mask_value = mask[b, mask_idx, i, j]

                                    out[b, c_out, i, j] += (
                                        mask_value
                                        * weight[c_out, c, di, dj]
                                        * bilinear_interpolate(x[b, c_in, :, :], pi, pj)
                                    )
        out += bias.view(1, n_out_channels, 1, 1)
        return out
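
    @staticmethod
    def _sketch_output_size_arithmetic():
        # Illustrative sketch (not part of the original suite): the standard
        # convolution output-size formula used above, worked through with the
        # values from get_fn_args below: in_h=5, pad_h=1, dil_h=2, weight_h=3,
        # stride_h=2 gives out_h = (5 + 2*1 - (2*(3-1) + 1)) // 2 + 1 = 2.
        out_h = (5 + 2 * 1 - (2 * (3 - 1) + 1)) // 2 + 1
        assert out_h == 2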

    @lru_cache(maxsize=None)
    def get_fn_args(self, device, contiguous, batch_sz, dtype):
        n_in_channels = 6
        n_out_channels = 2
        n_weight_grps = 2
        n_offset_grps = 3

        stride = (2, 1)
        pad = (1, 0)
        dilation = (2, 1)

        stride_h, stride_w = stride
        pad_h, pad_w = pad
        dil_h, dil_w = dilation
        weight_h, weight_w = (3, 2)
        in_h, in_w = (5, 4)

        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1

        x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=dtype, requires_grad=True)

        offset = torch.randn(
            batch_sz,
            n_offset_grps * 2 * weight_h * weight_w,
            out_h,
            out_w,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )

        mask = torch.randn(
            batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w, device=device, dtype=dtype, requires_grad=True
        )

        weight = torch.randn(
            n_out_channels,
            n_in_channels // n_weight_grps,
            weight_h,
            weight_w,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )

        bias = torch.randn(n_out_channels, device=device, dtype=dtype, requires_grad=True)

        if not contiguous:
            x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)

        return x, weight, offset, mask, bias, stride, pad, dilation

    def make_obj(self, in_channels=6, out_channels=2, kernel_size=(3, 2), groups=2, wrap=False):
        obj = ops.DeformConv2d(
            in_channels, out_channels, kernel_size, stride=(2, 1), padding=(1, 0), dilation=(2, 1), groups=groups
        )
        return DeformConvModuleWrapper(obj) if wrap else obj

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_is_leaf_node(self, device):
        op_obj = self.make_obj(wrap=True).to(device=device)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("batch_sz", (0, 33))
    def test_forward(self, device, contiguous, batch_sz, dtype=None):
        dtype = dtype or self.dtype
        x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
        in_channels = 6
        out_channels = 2
        kernel_size = (3, 2)
        groups = 2
        tol = 2e-3 if dtype is torch.half else 1e-5

        layer = self.make_obj(in_channels, out_channels, kernel_size, groups, wrap=False).to(
            device=x.device, dtype=dtype
        )
        res = layer(x, offset, mask)

        weight = layer.weight.data
        bias = layer.bias.data
        expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)

        torch.testing.assert_close(
            res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
        )

        # no modulation test
        res = layer(x, offset)
        expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)

        torch.testing.assert_close(
            res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
        )

    def test_wrong_sizes(self):
        in_channels = 6
        out_channels = 2
        kernel_size = (3, 2)
        groups = 2
        x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(
            "cpu", contiguous=True, batch_sz=10, dtype=self.dtype
        )
        layer = ops.DeformConv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups
        )
        with pytest.raises(RuntimeError, match="the shape of the offset"):
            wrong_offset = torch.rand_like(offset[:, :2])
            layer(x, wrong_offset)

        with pytest.raises(RuntimeError, match=r"mask.shape\[1\] is not valid"):
            wrong_mask = torch.rand_like(mask[:, :2])
            layer(x, offset, wrong_mask)

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("batch_sz", (0, 33))
    def test_backward(self, device, contiguous, batch_sz):
        x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(
            device, contiguous, batch_sz, self.dtype
        )

        def func(x_, offset_, mask_, weight_, bias_):
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=mask_
            )

        gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True)

        def func_no_mask(x_, offset_, weight_, bias_):
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=None
            )

        gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True)

        @torch.jit.script
        def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
            # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=mask_
            )

        gradcheck(
            lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation),
            (x, offset, mask, weight, bias),
            nondet_tol=1e-5,
            fast_mode=True,
        )

        @torch.jit.script
        def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
            # type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=None
            )

        gradcheck(
            lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation),
            (x, offset, weight, bias),
            nondet_tol=1e-5,
            fast_mode=True,
        )

    @needs_cuda
    @pytest.mark.parametrize("contiguous", (True, False))
    def test_compare_cpu_cuda_grads(self, contiguous):
        # Test from https://github.com/pytorch/vision/issues/2598
        # Run on CUDA only
        # compare grads computed on CUDA with grads computed on CPU
        true_cpu_grads = None

        init_weight = torch.randn(9, 9, 3, 3, requires_grad=True)
        img = torch.randn(8, 9, 1000, 110)
        offset = torch.rand(8, 2 * 3 * 3, 1000, 110)
        mask = torch.rand(8, 3 * 3, 1000, 110)

        if not contiguous:
            img = img.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            weight = init_weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
        else:
            weight = init_weight

        for d in ["cpu", "cuda"]:
            out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
            out.mean().backward()
            if true_cpu_grads is None:
                true_cpu_grads = init_weight.grad
                assert true_cpu_grads is not None
            else:
                assert init_weight.grad is not None
                res_grads = init_weight.grad.to("cpu")
                torch.testing.assert_close(true_cpu_grads, res_grads)

    @needs_cuda
    @pytest.mark.parametrize("batch_sz", (0, 33))
    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
    def test_autocast(self, batch_sz, dtype):
        with torch.cuda.amp.autocast():
            self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype)

    def test_forward_scriptability(self):
        # Non-regression test for https://github.com/pytorch/vision/issues/4078
        torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))


class TestFrozenBNT:
    def test_frozenbatchnorm2d_repr(self):
        num_features = 32
        eps = 1e-5
        t = ops.misc.FrozenBatchNorm2d(num_features, eps=eps)

        # Check integrity of object __repr__ attribute
        expected_string = f"FrozenBatchNorm2d({num_features}, eps={eps})"
        assert repr(t) == expected_string

    @pytest.mark.parametrize("seed", range(10))
    def test_frozenbatchnorm2d_eps(self, seed):
        torch.random.manual_seed(seed)
        sample_size = (4, 32, 28, 28)
        x = torch.rand(sample_size)
        state_dict = dict(
            weight=torch.rand(sample_size[1]),
            bias=torch.rand(sample_size[1]),
            running_mean=torch.rand(sample_size[1]),
            running_var=torch.rand(sample_size[1]),
            num_batches_tracked=torch.tensor(100),
        )

        # Check that default eps is equal to the one of BN
        fbn = ops.misc.FrozenBatchNorm2d(sample_size[1])
        fbn.load_state_dict(state_dict, strict=False)
        bn = torch.nn.BatchNorm2d(sample_size[1]).eval()
        bn.load_state_dict(state_dict)
        # Difference is expected to fall in an acceptable range
        torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6)

        # Check computation for eps > 0
        fbn = ops.misc.FrozenBatchNorm2d(sample_size[1], eps=1e-5)
        fbn.load_state_dict(state_dict, strict=False)
        bn = torch.nn.BatchNorm2d(sample_size[1], eps=1e-5).eval()
        bn.load_state_dict(state_dict)
        torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6)


class TestBoxConversionToRoi:
    def _get_box_sequences():
        # Define here the argument types of `boxes` supported by region pooling operations
        box_tensor = torch.tensor([[0, 0, 0, 100, 100], [1, 0, 0, 100, 100]], dtype=torch.float)
        box_list = [
            torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
            torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
        ]
        box_tuple = tuple(box_list)
        return box_tensor, box_list, box_tuple

    @pytest.mark.parametrize("box_sequence", _get_box_sequences())
    def test_check_roi_boxes_shape(self, box_sequence):
        # Ensure common sequences of tensors are supported
        ops._utils.check_roi_boxes_shape(box_sequence)

    @pytest.mark.parametrize("box_sequence", _get_box_sequences())
    def test_convert_boxes_to_roi_format(self, box_sequence):
        # Ensure common sequences of tensors yield the same result: the tensor form
        # is already in ROI format, so it serves as the reference that the converted
        # list and tuple forms must match.
        ref_tensor = TestBoxConversionToRoi._get_box_sequences()[0]
        if isinstance(box_sequence, torch.Tensor):
            assert_equal(ref_tensor, box_sequence)
        else:
            assert_equal(ref_tensor, ops._utils.convert_boxes_to_roi_format(box_sequence))


class TestBoxConvert:
    def test_bbox_same(self):
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )

        exp_xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)

        assert exp_xyxy.size() == torch.Size([4, 4])
        assert_equal(ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xyxy"), exp_xyxy)
        assert_equal(ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="xywh"), exp_xyxy)
        assert_equal(ops.box_convert(box_tensor, in_fmt="cxcywh", out_fmt="cxcywh"), exp_xyxy)

    def test_bbox_xyxy_xywh(self):
        # Simple test converting boxes to xywh and back, making sure they are the same.
        # box_tensor is in x1 y1 x2 y2 format.
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )
        exp_xywh = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float)

        assert exp_xywh.size() == torch.Size([4, 4])
        box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
        assert_equal(box_xywh, exp_xywh)

        # Reverse conversion
        box_xyxy = ops.box_convert(box_xywh, in_fmt="xywh", out_fmt="xyxy")
        assert_equal(box_xyxy, box_tensor)

    def test_bbox_xyxy_cxcywh(self):
        # Simple test converting boxes to cxcywh and back, making sure they are the same.
        # box_tensor is in x1 y1 x2 y2 format.
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )
        exp_cxcywh = torch.tensor(
            [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float
        )

        assert exp_cxcywh.size() == torch.Size([4, 4])
        box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
        assert_equal(box_cxcywh, exp_cxcywh)

        # Reverse conversion
        box_xyxy = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xyxy")
        assert_equal(box_xyxy, box_tensor)

    def test_bbox_xywh_cxcywh(self):
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float
        )
        exp_cxcywh = torch.tensor(
            [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float
        )

        assert exp_cxcywh.size() == torch.Size([4, 4])
        box_cxcywh = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="cxcywh")
        assert_equal(box_cxcywh, exp_cxcywh)

        # Reverse conversion
        box_xywh = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xywh")
        assert_equal(box_xywh, box_tensor)
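
    @staticmethod
    def _sketch_format_arithmetic():
        # Illustrative sketch (not part of the original suite) of the arithmetic
        # behind the expected tensors above, for the box [10, 15, 30, 35] (xyxy):
        #   xywh:   w = 30 - 10 = 20, h = 35 - 15 = 20               -> [10, 15, 20, 20]
        #   cxcywh: cx = (10 + 30) / 2 = 20, cy = (15 + 35) / 2 = 25 -> [20, 25, 20, 20]
        xyxy = torch.tensor([[10.0, 15.0, 30.0, 35.0]])
        assert_equal(ops.box_convert(xyxy, in_fmt="xyxy", out_fmt="xywh"), torch.tensor([[10.0, 15.0, 20.0, 20.0]]))
        assert_equal(ops.box_convert(xyxy, in_fmt="xyxy", out_fmt="cxcywh"), torch.tensor([[20.0, 25.0, 20.0, 20.0]]))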
  1008. @pytest.mark.parametrize("inv_infmt", ["xwyh", "cxwyh"])
  1009. @pytest.mark.parametrize("inv_outfmt", ["xwcx", "xhwcy"])
  1010. def test_bbox_invalid(self, inv_infmt, inv_outfmt):
  1011. box_tensor = torch.tensor(
  1012. [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float
  1013. )
  1014. with pytest.raises(ValueError):
  1015. ops.box_convert(box_tensor, inv_infmt, inv_outfmt)
  1016. def test_bbox_convert_jit(self):
  1017. box_tensor = torch.tensor(
  1018. [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
  1019. )
  1020. scripted_fn = torch.jit.script(ops.box_convert)
  1021. box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
  1022. scripted_xywh = scripted_fn(box_tensor, "xyxy", "xywh")
  1023. torch.testing.assert_close(scripted_xywh, box_xywh)
  1024. box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
  1025. scripted_cxcywh = scripted_fn(box_tensor, "xyxy", "cxcywh")
  1026. torch.testing.assert_close(scripted_cxcywh, box_cxcywh)
  1027. class TestBoxArea:
  1028. def area_check(self, box, expected, atol=1e-4):
  1029. out = ops.box_area(box)
  1030. torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol)
  1031. @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
  1032. def test_int_boxes(self, dtype):
  1033. box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
  1034. expected = torch.tensor([10000, 0], dtype=torch.int32)
  1035. self.area_check(box_tensor, expected)
  1036. @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
  1037. def test_float_boxes(self, dtype):
  1038. box_tensor = torch.tensor(FLOAT_BOXES, dtype=dtype)
  1039. expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype)
  1040. self.area_check(box_tensor, expected)
  1041. def test_float16_box(self):
  1042. box_tensor = torch.tensor(
  1043. [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]], dtype=torch.float16
  1044. )
  1045. expected = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16)
  1046. self.area_check(box_tensor, expected, atol=0.01)
  1047. def test_box_area_jit(self):
  1048. box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
  1049. expected = ops.box_area(box_tensor)
  1050. scripted_fn = torch.jit.script(ops.box_area)
  1051. scripted_area = scripted_fn(box_tensor)
  1052. torch.testing.assert_close(scripted_area, expected)
  1053. INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]]
  1054. INT_BOXES2 = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
  1055. FLOAT_BOXES = [
  1056. [285.3538, 185.5758, 1193.5110, 851.4551],
  1057. [285.1472, 188.7374, 1192.4984, 851.0669],
  1058. [279.2440, 197.9812, 1189.4746, 849.2019],
  1059. ]
def gen_box(size, dtype=torch.float):
    xy1 = torch.rand((size, 2), dtype=dtype)
    xy2 = xy1 + torch.rand((size, 2), dtype=dtype)
    return torch.cat([xy1, xy2], dim=-1)
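

# Shared harness for the box-similarity ops below: _run_test checks expected
# values across dtypes, _run_jit_test checks TorchScript parity, and
# _run_cartesian_test compares the batched NxM result against a naive
# one-pair-at-a-time reference implementation.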
class TestIouBase:
    @staticmethod
    def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
        for dtype in dtypes:
            actual_box1 = torch.tensor(actual_box1, dtype=dtype)
            actual_box2 = torch.tensor(actual_box2, dtype=dtype)
            expected_box = torch.tensor(expected)
            out = target_fn(actual_box1, actual_box2)
            torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)

    @staticmethod
    def _run_jit_test(target_fn: Callable, actual_box: List):
        box_tensor = torch.tensor(actual_box, dtype=torch.float)
        expected = target_fn(box_tensor, box_tensor)
        scripted_fn = torch.jit.script(target_fn)
        scripted_out = scripted_fn(box_tensor, box_tensor)
        torch.testing.assert_close(scripted_out, expected)

    @staticmethod
    def _cartesian_product(boxes1, boxes2, target_fn: Callable):
        N = boxes1.size(0)
        M = boxes2.size(0)
        result = torch.zeros((N, M))
        for i in range(N):
            for j in range(M):
                result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
        return result

    @staticmethod
    def _run_cartesian_test(target_fn: Callable):
        boxes1 = gen_box(5)
        boxes2 = gen_box(7)
        a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
        b = target_fn(boxes1, boxes2)
        torch.testing.assert_close(a, b)
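

# Sanity check on int_expected below, e.g. for [0, 0, 100, 100] vs [0, 0, 50, 50]:
# intersection = 2500, union = 10000, so IoU = 0.25; disjoint boxes give 0.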
class TestBoxIou(TestIouBase):
    int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0625, 0.25, 0.0]]
    float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        self._run_test(ops.box_iou, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        self._run_jit_test(ops.box_iou, INT_BOXES)

    def test_iou_cartesian(self):
        self._run_cartesian_test(ops.box_iou)
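

# Generalized IoU subtracts a penalty for the empty space in the smallest
# enclosing box: GIoU = IoU - (enclosing_area - union) / enclosing_area, so it
# goes negative for disjoint boxes. E.g. [0, 0, 100, 100] vs [200, 200, 300, 300]:
# enclosing area = 90000, union = 20000, GIoU = 0 - 70000 / 90000 = -0.7778.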
class TestGeneralizedBoxIou(TestIouBase):
    int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]]
    float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        self._run_test(ops.generalized_box_iou, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        self._run_jit_test(ops.generalized_box_iou, INT_BOXES)

    def test_iou_cartesian(self):
        self._run_cartesian_test(ops.generalized_box_iou)
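

# Distance IoU penalizes center offset instead: DIoU = IoU - d^2 / c^2, where d
# is the distance between the box centers and c the diagonal of the enclosing
# box. E.g. [0, 0, 100, 100] vs [0, 0, 50, 50]: IoU = 0.25, d^2 = 1250,
# c^2 = 20000, so DIoU = 0.25 - 0.0625 = 0.1875.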
class TestDistanceBoxIoU(TestIouBase):
    int_expected = [
        [1.0000, 0.1875, -0.4444],
        [0.1875, 1.0000, -0.5625],
        [-0.4444, -0.5625, 1.0000],
        [-0.0781, 0.1875, -0.6267],
    ]
    float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        self._run_test(ops.distance_box_iou, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        self._run_jit_test(ops.distance_box_iou, INT_BOXES)

    def test_iou_cartesian(self):
        self._run_cartesian_test(ops.distance_box_iou)
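

# Complete IoU adds an aspect-ratio consistency term on top of DIoU. All boxes
# in INT_BOXES are squares, so that term vanishes here and int_expected matches
# TestDistanceBoxIoU exactly.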
class TestCompleteBoxIou(TestIouBase):
    int_expected = [
        [1.0000, 0.1875, -0.4444],
        [0.1875, 1.0000, -0.5625],
        [-0.4444, -0.5625, 1.0000],
        [-0.0781, 0.1875, -0.6267],
    ]
    float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        self._run_test(ops.complete_box_iou, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        self._run_jit_test(ops.complete_box_iou, INT_BOXES)

    def test_iou_cartesian(self):
        self._run_cartesian_test(ops.complete_box_iou)
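

# Fixtures for the IoU-loss tests: box2 covers one quadrant of box1 (a quarter
# of its area), box3 sits directly below box2 sharing an edge, and box4 touches
# box2 only at a corner. box1s/box2s stack these into batches of two.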
def get_boxes(dtype, device):
    box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
    box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
    box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
    box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)

    box1s = torch.stack([box2, box2], dim=0)
    box2s = torch.stack([box3, box4], dim=0)

    return box1, box2, box3, box4, box1s, box2s


def assert_iou_loss(iou_fn, box1, box2, expected_loss, device, reduction="none"):
    computed_loss = iou_fn(box1, box2, reduction=reduction)
    expected_loss = torch.tensor(expected_loss, device=device)
    torch.testing.assert_close(computed_loss, expected_loss)
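

# The IoU losses must be well-defined for empty (0, 4) inputs: "mean" reduction
# should yield a zero that still backpropagates, and "none" an empty tensor.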
def assert_empty_loss(iou_fn, dtype, device):
    box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
    box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
    loss = iou_fn(box1, box2, reduction="mean")
    loss.backward()
    torch.testing.assert_close(loss, torch.tensor(0.0, device=device))
    assert box1.grad is not None, "box1.grad should not be None after backward is called"
    assert box2.grad is not None, "box2.grad should not be None after backward is called"
    loss = iou_fn(box1, box2, reduction="none")
    assert loss.numel() == 0, f"{str(iou_fn)} for two empty boxes should be empty"


class TestGeneralizedBoxIouLoss:
    # We refer to the original test:
    # https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_giou_loss(self, dtype, device):
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

        # Identical boxes should have a loss of 0
        assert_iou_loss(ops.generalized_box_iou_loss, box1, box1, 0.0, device=device)

        # A quarter-size box inside the other box gives an IoU of 0.25
        assert_iou_loss(ops.generalized_box_iou_loss, box1, box2, 0.75, device=device)

        # Two side-by-side boxes: the enclosing area equals the union,
        # so IoU = 0 and GIoU = 0 (loss 1.0)
        assert_iou_loss(ops.generalized_box_iou_loss, box2, box3, 1.0, device=device)

        # Two diagonally adjacent boxes: the enclosing area is twice the union,
        # so IoU = 0 and GIoU = -0.5 (loss 1.5)
        assert_iou_loss(ops.generalized_box_iou_loss, box2, box4, 1.5, device=device)

        # Test batched loss and reductions
        assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 2.5, device=device, reduction="sum")
        assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 1.25, device=device, reduction="mean")

        # A reduction other than "none", "mean" or "sum" should raise a ValueError
        with pytest.raises(ValueError, match="Invalid"):
            ops.generalized_box_iou_loss(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_inputs(self, dtype, device):
        assert_empty_loss(ops.generalized_box_iou_loss, dtype, device)


class TestCompleteBoxIouLoss:
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_ciou_loss(self, dtype, device):
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)
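
        # All fixture boxes are squares, so the CIoU aspect-ratio term is zero
        # and these values reduce to 1 - DIoU. E.g. for box1 vs box2: IoU = 0.25,
        # squared center distance = 0.5, squared enclosing diagonal = 8, so the
        # loss is 1 - (0.25 - 0.5 / 8) = 0.8125.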
        assert_iou_loss(ops.complete_box_iou_loss, box1, box1, 0.0, device=device)
        assert_iou_loss(ops.complete_box_iou_loss, box1, box2, 0.8125, device=device)
        assert_iou_loss(ops.complete_box_iou_loss, box1, box3, 1.1923, device=device)
        assert_iou_loss(ops.complete_box_iou_loss, box1, box4, 1.2500, device=device)
        assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 1.2250, device=device, reduction="mean")
        assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 2.4500, device=device, reduction="sum")

        with pytest.raises(ValueError, match="Invalid"):
            ops.complete_box_iou_loss(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_inputs(self, dtype, device):
        assert_empty_loss(ops.complete_box_iou_loss, dtype, device)


class TestDistanceBoxIouLoss:
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_distance_iou_loss(self, dtype, device):
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

        assert_iou_loss(ops.distance_box_iou_loss, box1, box1, 0.0, device=device)
        assert_iou_loss(ops.distance_box_iou_loss, box1, box2, 0.8125, device=device)
        assert_iou_loss(ops.distance_box_iou_loss, box1, box3, 1.1923, device=device)
        assert_iou_loss(ops.distance_box_iou_loss, box1, box4, 1.2500, device=device)
        assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 1.2250, device=device, reduction="mean")
        assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 2.4500, device=device, reduction="sum")

        with pytest.raises(ValueError, match="Invalid"):
            ops.distance_box_iou_loss(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_distance_iou_inputs(self, dtype, device):
        assert_empty_loss(ops.distance_box_iou_loss, dtype, device)


class TestFocalLoss:
    def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
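        # logit is the inverse of sigmoid: probabilities sampled in the ranges
        # below map to confidently-negative, confidently-positive or uncertain
        # inputs once passed through it.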
        def logit(p):
            return torch.log(p / (1 - p))

        def generate_tensor_with_range_type(shape, range_type, **kwargs):
            if range_type != "random_binary":
                low, high = {
                    "small": (0.0, 0.2),
                    "big": (0.8, 1.0),
                    "zeros": (0.0, 0.0),
                    "ones": (1.0, 1.0),
                    "random": (0.0, 1.0),
                }[range_type]
                return torch.testing.make_tensor(shape, low=low, high=high, **kwargs)
            else:
                return torch.randint(0, 2, shape, **kwargs)

        # This function will return inputs and targets with shape: (shape[0]*9, shape[1])
        inputs = []
        targets = []
        for input_range_type, target_range_type in [
            ("small", "zeros"),
            ("small", "ones"),
            ("small", "random_binary"),
            ("big", "zeros"),
            ("big", "ones"),
            ("big", "random_binary"),
            ("random", "zeros"),
            ("random", "ones"),
            ("random", "random_binary"),
        ]:
            inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs)))
            targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs))

        return torch.cat(inputs), torch.cat(targets)

    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
    @pytest.mark.parametrize("gamma", [0, 2])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [0, 1])
    def test_correct_ratio(self, alpha, gamma, device, dtype, seed):
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        # For testing the ratio with a manual calculation, we require the reduction to be "none"
        reduction = "none"
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
        focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction)
        assert torch.all(
            focal_loss <= ce_loss
        ), "focal loss must be less than or equal to cross entropy loss with the same input"
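
        # Elementwise, sigmoid focal loss is alpha_t * (1 - p_t) ** gamma * CE,
        # where p_t is the predicted probability of the true class, so the ratio
        # focal_loss / ce_loss should equal the modulating factor recomputed here.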
        loss_ratio = (focal_loss / ce_loss).squeeze()
        prob = torch.sigmoid(inputs)
        p_t = prob * targets + (1 - prob) * (1 - targets)
        correct_ratio = (1.0 - p_t) ** gamma
        if alpha >= 0:
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            correct_ratio = correct_ratio * alpha_t

        tol = 1e-3 if dtype is torch.half else 1e-5
        torch.testing.assert_close(correct_ratio, loss_ratio, atol=tol, rtol=tol)

    @pytest.mark.parametrize("reduction", ["mean", "sum"])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [2, 3])
    def test_equal_ce_loss(self, reduction, device, dtype, seed):
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        # focal loss should equal ce_loss when alpha=-1 and gamma=0
        alpha = -1
        gamma = 0
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
        inputs_fl = inputs.clone().requires_grad_()
        targets_fl = targets.clone()
        inputs_ce = inputs.clone().requires_grad_()
        targets_ce = targets.clone()
        focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction)
        ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction)
        torch.testing.assert_close(focal_loss, ce_loss)

        focal_loss.backward()
        ce_loss.backward()
        torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad)

    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
    @pytest.mark.parametrize("gamma", [0, 2])
    @pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [4, 5])
    def test_jit(self, alpha, gamma, reduction, device, dtype, seed):
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        script_fn = torch.jit.script(ops.sigmoid_focal_loss)
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
        focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
        scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
        tol = 1e-3 if dtype is torch.half else 1e-5
        torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)

    # Raise ValueError for an invalid reduction mode
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_reduction_mode(self, device, dtype, reduction="xyz"):
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        torch.random.manual_seed(0)
        inputs, targets = self._generate_diverse_input_target_pair(device=device, dtype=dtype)
        with pytest.raises(ValueError, match="Invalid"):
            ops.sigmoid_focal_loss(inputs, targets, 0.25, 2, reduction)
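

# masks_to_boxes converts a stack of (N, H, W) instance masks into (N, 4) xyxy
# bounding boxes; the expected values below are the boxes of the seven mask
# frames stored in assets/masks.tiff.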
class TestMasksToBoxes:
    def test_masks_box(self):
        def masks_box_check(masks, expected, atol=1e-4):
            out = ops.masks_to_boxes(masks)
            assert out.dtype == torch.float
            torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=True, atol=atol)

        # Check for int type boxes.
        def _get_image():
            assets_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
            mask_path = os.path.join(assets_directory, "masks.tiff")
            image = Image.open(mask_path)
            return image

        def _create_masks(image, masks):
            for index in range(image.n_frames):
                image.seek(index)
                frame = np.array(image)
                masks[index] = torch.tensor(frame)

            return masks

        expected = torch.tensor(
            [
                [127, 2, 165, 40],
                [2, 50, 44, 92],
                [56, 63, 98, 100],
                [139, 68, 175, 104],
                [160, 112, 198, 145],
                [49, 138, 99, 182],
                [108, 148, 152, 213],
            ],
            dtype=torch.float,
        )

        image = _get_image()
        for dtype in [torch.float16, torch.float32, torch.float64]:
            masks = torch.zeros((image.n_frames, image.height, image.width), dtype=dtype)
            masks = _create_masks(image, masks)
            masks_box_check(masks, expected)


class TestStochasticDepth:
    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("p", [0.2, 0.5, 0.8])
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_stochastic_depth_random(self, seed, mode, p):
        torch.manual_seed(seed)
        stats = pytest.importorskip("scipy.stats")
        batch_size = 5
        x = torch.ones(size=(batch_size, 3, 4, 4))
        layer = ops.StochasticDepth(p=p, mode=mode)
        layer.__repr__()
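
        # Each forward pass zeroes the whole batch ("batch" mode) or individual
        # rows ("row" mode) with probability p, so the number of zeroed samples
        # over the trials follows Binomial(num_samples, p); a two-sided binomial
        # test then checks the observed drop rate is compatible with p.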
        trials = 250
        num_samples = 0
        counts = 0
        for _ in range(trials):
            out = layer(x)
            non_zero_count = out.sum(dim=(1, 2, 3)).nonzero().size(0)
            if mode == "batch":
                if non_zero_count == 0:
                    counts += 1
                num_samples += 1
            elif mode == "row":
                counts += batch_size - non_zero_count
                num_samples += batch_size

        p_value = stats.binomtest(counts, num_samples, p=p).pvalue
        assert p_value > 0.01

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("p", (0, 1))
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_stochastic_depth(self, seed, mode, p):
        torch.manual_seed(seed)
        batch_size = 5
        x = torch.ones(size=(batch_size, 3, 4, 4))
        layer = ops.StochasticDepth(p=p, mode=mode)

        out = layer(x)
        if p == 0:
            assert out.equal(x)
        elif p == 1:
            assert out.equal(torch.zeros_like(x))

    def make_obj(self, p, mode, wrap=False):
        obj = ops.StochasticDepth(p, mode)
        return StochasticDepthWrapper(obj) if wrap else obj

    @pytest.mark.parametrize("p", (0, 1))
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_is_leaf_node(self, p, mode):
        op_obj = self.make_obj(p, mode, wrap=True)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
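

# split_normalization_params partitions a model's parameters into
# normalization-layer parameters and the rest (typically so the former can be
# exempted from weight decay); the counts below are the expected totals for
# mobilenet_v3_large.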
class TestUtils:
    @pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm])
    def test_split_normalization_params(self, norm_layer):
        model = models.mobilenet_v3_large(norm_layer=norm_layer)
        params = ops._utils.split_normalization_params(model, None if norm_layer is None else [norm_layer])

        assert len(params[0]) == 92
        assert len(params[1]) == 82


class TestDropBlock:
    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("dim", [2, 3])
    @pytest.mark.parametrize("p", [0, 0.5])
    @pytest.mark.parametrize("block_size", [5, 11])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_drop_block(self, seed, dim, p, block_size, inplace):
        torch.manual_seed(seed)
        batch_size = 5
        channels = 3
        height = 11
        width = height
        depth = height
        if dim == 2:
            x = torch.ones(size=(batch_size, channels, height, width))
            layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
            feature_size = height * width
        elif dim == 3:
            x = torch.ones(size=(batch_size, channels, depth, height, width))
            layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)
            feature_size = depth * height * width
        layer.__repr__()

        out = layer(x)
        if p == 0:
            assert out.equal(x)
        if block_size == height:
            for b, c in product(range(batch_size), range(channels)):
                assert out[b, c].count_nonzero() in (0, feature_size)

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("dim", [2, 3])
    @pytest.mark.parametrize("p", [0.1, 0.2])
    @pytest.mark.parametrize("block_size", [3])
    @pytest.mark.parametrize("inplace", [False])
    def test_drop_block_random(self, seed, dim, p, block_size, inplace):
        torch.manual_seed(seed)
        batch_size = 5
        channels = 3
        height = 11
        width = height
        depth = height
        if dim == 2:
            x = torch.ones(size=(batch_size, channels, height, width))
            layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
        elif dim == 3:
            x = torch.ones(size=(batch_size, channels, depth, height, width))
            layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)
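
        # DropBlock zeroes contiguous block_size-sided regions so that, on
        # average, a fraction p of activations is dropped; over many trials the
        # observed zeroed fraction should land within 15% relative error of p.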
        trials = 250
        num_samples = 0
        counts = 0
        cell_numel = torch.tensor(x.shape).prod()
        for _ in range(trials):
            with torch.no_grad():
                out = layer(x)
            non_zero_count = out.nonzero().size(0)
            counts += cell_numel - non_zero_count
            num_samples += cell_numel

        assert abs(p - counts / num_samples) / p < 0.15

    def make_obj(self, dim, p, block_size, inplace, wrap=False):
        if dim == 2:
            obj = ops.DropBlock2d(p, block_size, inplace)
        elif dim == 3:
            obj = ops.DropBlock3d(p, block_size, inplace)
        return DropBlockWrapper(obj) if wrap else obj

    @pytest.mark.parametrize("dim", (2, 3))
    @pytest.mark.parametrize("p", [0, 1])
    @pytest.mark.parametrize("block_size", [5, 7])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_is_leaf_node(self, dim, p, block_size, inplace):
        op_obj = self.make_obj(dim, p, block_size, inplace, wrap=True)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs


if __name__ == "__main__":
    pytest.main([__file__])