# transforms_v2_dispatcher_infos.py
  1. import pytest
  2. import torchvision.transforms.v2.functional as F
  3. from torchvision import tv_tensors
  4. from transforms_v2_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition
  5. from transforms_v2_legacy_utils import InfoBase, TestMark
  6. __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
  7. class PILKernelInfo(InfoBase):
  8. def __init__(
  9. self,
  10. kernel,
  11. *,
  12. # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
  13. # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
  14. kernel_name=None,
  15. ):
  16. super().__init__(id=kernel_name or kernel.__name__)
  17. self.kernel = kernel
  18. class DispatcherInfo(InfoBase):
  19. _KERNEL_INFO_MAP = {info.kernel: info for info in KERNEL_INFOS}
  20. def __init__(
  21. self,
  22. dispatcher,
  23. *,
  24. # Dictionary of types that map to the kernel the dispatcher dispatches to.
  25. kernels,
  26. # If omitted, no PIL dispatch test will be performed.
  27. pil_kernel_info=None,
  28. # See InfoBase
  29. test_marks=None,
  30. # See InfoBase
  31. closeness_kwargs=None,
  32. ):
  33. super().__init__(id=dispatcher.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
  34. self.dispatcher = dispatcher
  35. self.kernels = kernels
  36. self.pil_kernel_info = pil_kernel_info
  37. kernel_infos = {}
  38. for tv_tensor_type, kernel in self.kernels.items():
  39. kernel_info = self._KERNEL_INFO_MAP.get(kernel)
  40. if not kernel_info:
  41. raise pytest.UsageError(
  42. f"Can't register {kernel.__name__} for type {tv_tensor_type} since there is no `KernelInfo` for it. "
  43. f"Please add a `KernelInfo` for it in `transforms_v2_kernel_infos.py`."
  44. )
  45. kernel_infos[tv_tensor_type] = kernel_info
  46. self.kernel_infos = kernel_infos
  47. def sample_inputs(self, *tv_tensor_types, filter_metadata=True):
  48. for tv_tensor_type in tv_tensor_types or self.kernel_infos.keys():
  49. kernel_info = self.kernel_infos.get(tv_tensor_type)
  50. if not kernel_info:
  51. raise pytest.UsageError(f"There is no kernel registered for type {type.__name__}")
  52. sample_inputs = kernel_info.sample_inputs_fn()
  53. if not filter_metadata:
  54. yield from sample_inputs
  55. return
  56. import itertools
  57. for args_kwargs in sample_inputs:
  58. if hasattr(tv_tensor_type, "__annotations__"):
  59. for name in itertools.chain(
  60. tv_tensor_type.__annotations__.keys(),
  61. # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
  62. # per-dispatcher level. However, so far there is no option for that.
  63. (f"old_{name}" for name in tv_tensor_type.__annotations__.keys()),
  64. ):
  65. if name in args_kwargs.kwargs:
  66. del args_kwargs.kwargs[name]
  67. yield args_kwargs
  68. def xfail_jit(reason, *, condition=None):
  69. return TestMark(
  70. ("TestDispatchers", "test_scripted_smoke"),
  71. pytest.mark.xfail(reason=reason),
  72. condition=condition,
  73. )
  74. def xfail_jit_python_scalar_arg(name, *, reason=None):
  75. return xfail_jit(
  76. reason or f"Python scalar int or float for `{name}` is not supported when scripting",
  77. condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), (int, float)),
  78. )
  79. skip_dispatch_tv_tensor = TestMark(
  80. ("TestDispatchers", "test_dispatch_tv_tensor"),
  81. pytest.mark.skip(reason="Dispatcher doesn't support arbitrary tv_tensor dispatch."),
  82. )
  83. multi_crop_skips = [
  84. TestMark(
  85. ("TestDispatchers", test_name),
  86. pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."),
  87. )
  88. for test_name in ["test_pure_tensor_output_type", "test_pil_output_type", "test_tv_tensor_output_type"]
  89. ]
  90. multi_crop_skips.append(skip_dispatch_tv_tensor)
# Registry of dispatcher test configurations consumed by the v2 transforms
# test suite. Each entry wires a functional dispatcher to the per-type kernels
# it dispatches to, an optional PIL kernel, and any test marks (xfails/skips).
DISPATCHER_INFOS = [
    # Geometry dispatchers: operate on images, videos, bounding boxes, and masks.
    DispatcherInfo(
        F.resized_crop,
        kernels={
            tv_tensors.Image: F.resized_crop_image,
            tv_tensors.Video: F.resized_crop_video,
            tv_tensors.BoundingBoxes: F.resized_crop_bounding_boxes,
            tv_tensors.Mask: F.resized_crop_mask,
        },
        pil_kernel_info=PILKernelInfo(F._resized_crop_image_pil),
    ),
    DispatcherInfo(
        F.pad,
        kernels={
            tv_tensors.Image: F.pad_image,
            tv_tensors.Video: F.pad_video,
            tv_tensors.BoundingBoxes: F.pad_bounding_boxes,
            tv_tensors.Mask: F.pad_mask,
        },
        pil_kernel_info=PILKernelInfo(F._pad_image_pil, kernel_name="pad_image_pil"),
        test_marks=[
            xfail_jit("F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition),
            xfail_jit_python_scalar_arg("padding"),
        ],
    ),
    DispatcherInfo(
        F.perspective,
        kernels={
            tv_tensors.Image: F.perspective_image,
            tv_tensors.Video: F.perspective_video,
            tv_tensors.BoundingBoxes: F.perspective_bounding_boxes,
            tv_tensors.Mask: F.perspective_mask,
        },
        pil_kernel_info=PILKernelInfo(F._perspective_image_pil),
        test_marks=[
            xfail_jit_python_scalar_arg("fill"),
        ],
    ),
    DispatcherInfo(
        F.elastic,
        kernels={
            tv_tensors.Image: F.elastic_image,
            tv_tensors.Video: F.elastic_video,
            tv_tensors.BoundingBoxes: F.elastic_bounding_boxes,
            tv_tensors.Mask: F.elastic_mask,
        },
        pil_kernel_info=PILKernelInfo(F._elastic_image_pil),
        test_marks=[xfail_jit_python_scalar_arg("fill")],
    ),
    DispatcherInfo(
        F.center_crop,
        kernels={
            tv_tensors.Image: F.center_crop_image,
            tv_tensors.Video: F.center_crop_video,
            tv_tensors.BoundingBoxes: F.center_crop_bounding_boxes,
            tv_tensors.Mask: F.center_crop_mask,
        },
        pil_kernel_info=PILKernelInfo(F._center_crop_image_pil),
        test_marks=[
            xfail_jit_python_scalar_arg("output_size"),
        ],
    ),
    # Color/intensity dispatchers: only image and video kernels exist for these.
    DispatcherInfo(
        F.gaussian_blur,
        kernels={
            tv_tensors.Image: F.gaussian_blur_image,
            tv_tensors.Video: F.gaussian_blur_video,
        },
        pil_kernel_info=PILKernelInfo(F._gaussian_blur_image_pil),
        test_marks=[
            xfail_jit_python_scalar_arg("kernel_size"),
            xfail_jit_python_scalar_arg("sigma"),
        ],
    ),
    DispatcherInfo(
        F.equalize,
        kernels={
            tv_tensors.Image: F.equalize_image,
            tv_tensors.Video: F.equalize_video,
        },
        pil_kernel_info=PILKernelInfo(F._equalize_image_pil, kernel_name="equalize_image_pil"),
    ),
    DispatcherInfo(
        F.invert,
        kernels={
            tv_tensors.Image: F.invert_image,
            tv_tensors.Video: F.invert_video,
        },
        pil_kernel_info=PILKernelInfo(F._invert_image_pil, kernel_name="invert_image_pil"),
    ),
    DispatcherInfo(
        F.posterize,
        kernels={
            tv_tensors.Image: F.posterize_image,
            tv_tensors.Video: F.posterize_video,
        },
        pil_kernel_info=PILKernelInfo(F._posterize_image_pil, kernel_name="posterize_image_pil"),
    ),
    DispatcherInfo(
        F.solarize,
        kernels={
            tv_tensors.Image: F.solarize_image,
            tv_tensors.Video: F.solarize_video,
        },
        pil_kernel_info=PILKernelInfo(F._solarize_image_pil, kernel_name="solarize_image_pil"),
    ),
    DispatcherInfo(
        F.autocontrast,
        kernels={
            tv_tensors.Image: F.autocontrast_image,
            tv_tensors.Video: F.autocontrast_video,
        },
        pil_kernel_info=PILKernelInfo(F._autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_sharpness,
        kernels={
            tv_tensors.Image: F.adjust_sharpness_image,
            tv_tensors.Video: F.adjust_sharpness_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_contrast,
        kernels={
            tv_tensors.Image: F.adjust_contrast_image,
            tv_tensors.Video: F.adjust_contrast_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_gamma,
        kernels={
            tv_tensors.Image: F.adjust_gamma_image,
            tv_tensors.Video: F.adjust_gamma_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_hue,
        kernels={
            tv_tensors.Image: F.adjust_hue_image,
            tv_tensors.Video: F.adjust_hue_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_saturation,
        kernels={
            tv_tensors.Image: F.adjust_saturation_image,
            tv_tensors.Video: F.adjust_saturation_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
    ),
    # Multi-crop dispatchers return sequences, hence the extra skips.
    DispatcherInfo(
        F.five_crop,
        kernels={
            tv_tensors.Image: F.five_crop_image,
            tv_tensors.Video: F.five_crop_video,
        },
        pil_kernel_info=PILKernelInfo(F._five_crop_image_pil),
        test_marks=[
            xfail_jit_python_scalar_arg("size"),
            *multi_crop_skips,
        ],
    ),
    DispatcherInfo(
        F.ten_crop,
        kernels={
            tv_tensors.Image: F.ten_crop_image,
            tv_tensors.Video: F.ten_crop_video,
        },
        test_marks=[
            xfail_jit_python_scalar_arg("size"),
            *multi_crop_skips,
        ],
        pil_kernel_info=PILKernelInfo(F._ten_crop_image_pil),
    ),
    # Dispatchers without a PIL counterpart.
    DispatcherInfo(
        F.normalize,
        kernels={
            tv_tensors.Image: F.normalize_image,
            tv_tensors.Video: F.normalize_video,
        },
        test_marks=[
            xfail_jit_python_scalar_arg("mean"),
            xfail_jit_python_scalar_arg("std"),
        ],
    ),
    DispatcherInfo(
        F.uniform_temporal_subsample,
        kernels={
            tv_tensors.Video: F.uniform_temporal_subsample_video,
        },
        test_marks=[
            skip_dispatch_tv_tensor,
        ],
    ),
    # Bounding-box-only dispatchers: the dispatcher and the kernel are the same
    # function here, so arbitrary tv_tensor dispatch is skipped.
    DispatcherInfo(
        F.clamp_bounding_boxes,
        kernels={tv_tensors.BoundingBoxes: F.clamp_bounding_boxes},
        test_marks=[
            skip_dispatch_tv_tensor,
        ],
    ),
    DispatcherInfo(
        F.convert_bounding_box_format,
        kernels={tv_tensors.BoundingBoxes: F.convert_bounding_box_format},
        test_marks=[
            skip_dispatch_tv_tensor,
        ],
    ),
]