123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325 |
import itertools

import pytest

import torchvision.transforms.v2.functional as F
from torchvision import tv_tensors
from transforms_v2_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition
from transforms_v2_legacy_utils import InfoBase, TestMark
- __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
- class PILKernelInfo(InfoBase):
- def __init__(
- self,
- kernel,
- *,
- # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
- # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
- kernel_name=None,
- ):
- super().__init__(id=kernel_name or kernel.__name__)
- self.kernel = kernel
- class DispatcherInfo(InfoBase):
- _KERNEL_INFO_MAP = {info.kernel: info for info in KERNEL_INFOS}
- def __init__(
- self,
- dispatcher,
- *,
- # Dictionary of types that map to the kernel the dispatcher dispatches to.
- kernels,
- # If omitted, no PIL dispatch test will be performed.
- pil_kernel_info=None,
- # See InfoBase
- test_marks=None,
- # See InfoBase
- closeness_kwargs=None,
- ):
- super().__init__(id=dispatcher.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
- self.dispatcher = dispatcher
- self.kernels = kernels
- self.pil_kernel_info = pil_kernel_info
- kernel_infos = {}
- for tv_tensor_type, kernel in self.kernels.items():
- kernel_info = self._KERNEL_INFO_MAP.get(kernel)
- if not kernel_info:
- raise pytest.UsageError(
- f"Can't register {kernel.__name__} for type {tv_tensor_type} since there is no `KernelInfo` for it. "
- f"Please add a `KernelInfo` for it in `transforms_v2_kernel_infos.py`."
- )
- kernel_infos[tv_tensor_type] = kernel_info
- self.kernel_infos = kernel_infos
- def sample_inputs(self, *tv_tensor_types, filter_metadata=True):
- for tv_tensor_type in tv_tensor_types or self.kernel_infos.keys():
- kernel_info = self.kernel_infos.get(tv_tensor_type)
- if not kernel_info:
- raise pytest.UsageError(f"There is no kernel registered for type {type.__name__}")
- sample_inputs = kernel_info.sample_inputs_fn()
- if not filter_metadata:
- yield from sample_inputs
- return
- import itertools
- for args_kwargs in sample_inputs:
- if hasattr(tv_tensor_type, "__annotations__"):
- for name in itertools.chain(
- tv_tensor_type.__annotations__.keys(),
- # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
- # per-dispatcher level. However, so far there is no option for that.
- (f"old_{name}" for name in tv_tensor_type.__annotations__.keys()),
- ):
- if name in args_kwargs.kwargs:
- del args_kwargs.kwargs[name]
- yield args_kwargs
- def xfail_jit(reason, *, condition=None):
- return TestMark(
- ("TestDispatchers", "test_scripted_smoke"),
- pytest.mark.xfail(reason=reason),
- condition=condition,
- )
- def xfail_jit_python_scalar_arg(name, *, reason=None):
- return xfail_jit(
- reason or f"Python scalar int or float for `{name}` is not supported when scripting",
- condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), (int, float)),
- )
- skip_dispatch_tv_tensor = TestMark(
- ("TestDispatchers", "test_dispatch_tv_tensor"),
- pytest.mark.skip(reason="Dispatcher doesn't support arbitrary tv_tensor dispatch."),
- )
- multi_crop_skips = [
- TestMark(
- ("TestDispatchers", test_name),
- pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."),
- )
- for test_name in ["test_pure_tensor_output_type", "test_pil_output_type", "test_tv_tensor_output_type"]
- ]
- multi_crop_skips.append(skip_dispatch_tv_tensor)
- DISPATCHER_INFOS = [
- DispatcherInfo(
- F.resized_crop,
- kernels={
- tv_tensors.Image: F.resized_crop_image,
- tv_tensors.Video: F.resized_crop_video,
- tv_tensors.BoundingBoxes: F.resized_crop_bounding_boxes,
- tv_tensors.Mask: F.resized_crop_mask,
- },
- pil_kernel_info=PILKernelInfo(F._resized_crop_image_pil),
- ),
- DispatcherInfo(
- F.pad,
- kernels={
- tv_tensors.Image: F.pad_image,
- tv_tensors.Video: F.pad_video,
- tv_tensors.BoundingBoxes: F.pad_bounding_boxes,
- tv_tensors.Mask: F.pad_mask,
- },
- pil_kernel_info=PILKernelInfo(F._pad_image_pil, kernel_name="pad_image_pil"),
- test_marks=[
- xfail_jit("F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition),
- xfail_jit_python_scalar_arg("padding"),
- ],
- ),
- DispatcherInfo(
- F.perspective,
- kernels={
- tv_tensors.Image: F.perspective_image,
- tv_tensors.Video: F.perspective_video,
- tv_tensors.BoundingBoxes: F.perspective_bounding_boxes,
- tv_tensors.Mask: F.perspective_mask,
- },
- pil_kernel_info=PILKernelInfo(F._perspective_image_pil),
- test_marks=[
- xfail_jit_python_scalar_arg("fill"),
- ],
- ),
- DispatcherInfo(
- F.elastic,
- kernels={
- tv_tensors.Image: F.elastic_image,
- tv_tensors.Video: F.elastic_video,
- tv_tensors.BoundingBoxes: F.elastic_bounding_boxes,
- tv_tensors.Mask: F.elastic_mask,
- },
- pil_kernel_info=PILKernelInfo(F._elastic_image_pil),
- test_marks=[xfail_jit_python_scalar_arg("fill")],
- ),
- DispatcherInfo(
- F.center_crop,
- kernels={
- tv_tensors.Image: F.center_crop_image,
- tv_tensors.Video: F.center_crop_video,
- tv_tensors.BoundingBoxes: F.center_crop_bounding_boxes,
- tv_tensors.Mask: F.center_crop_mask,
- },
- pil_kernel_info=PILKernelInfo(F._center_crop_image_pil),
- test_marks=[
- xfail_jit_python_scalar_arg("output_size"),
- ],
- ),
- DispatcherInfo(
- F.gaussian_blur,
- kernels={
- tv_tensors.Image: F.gaussian_blur_image,
- tv_tensors.Video: F.gaussian_blur_video,
- },
- pil_kernel_info=PILKernelInfo(F._gaussian_blur_image_pil),
- test_marks=[
- xfail_jit_python_scalar_arg("kernel_size"),
- xfail_jit_python_scalar_arg("sigma"),
- ],
- ),
- DispatcherInfo(
- F.equalize,
- kernels={
- tv_tensors.Image: F.equalize_image,
- tv_tensors.Video: F.equalize_video,
- },
- pil_kernel_info=PILKernelInfo(F._equalize_image_pil, kernel_name="equalize_image_pil"),
- ),
- DispatcherInfo(
- F.invert,
- kernels={
- tv_tensors.Image: F.invert_image,
- tv_tensors.Video: F.invert_video,
- },
- pil_kernel_info=PILKernelInfo(F._invert_image_pil, kernel_name="invert_image_pil"),
- ),
- DispatcherInfo(
- F.posterize,
- kernels={
- tv_tensors.Image: F.posterize_image,
- tv_tensors.Video: F.posterize_video,
- },
- pil_kernel_info=PILKernelInfo(F._posterize_image_pil, kernel_name="posterize_image_pil"),
- ),
- DispatcherInfo(
- F.solarize,
- kernels={
- tv_tensors.Image: F.solarize_image,
- tv_tensors.Video: F.solarize_video,
- },
- pil_kernel_info=PILKernelInfo(F._solarize_image_pil, kernel_name="solarize_image_pil"),
- ),
- DispatcherInfo(
- F.autocontrast,
- kernels={
- tv_tensors.Image: F.autocontrast_image,
- tv_tensors.Video: F.autocontrast_video,
- },
- pil_kernel_info=PILKernelInfo(F._autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
- ),
- DispatcherInfo(
- F.adjust_sharpness,
- kernels={
- tv_tensors.Image: F.adjust_sharpness_image,
- tv_tensors.Video: F.adjust_sharpness_video,
- },
- pil_kernel_info=PILKernelInfo(F._adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
- ),
- DispatcherInfo(
- F.adjust_contrast,
- kernels={
- tv_tensors.Image: F.adjust_contrast_image,
- tv_tensors.Video: F.adjust_contrast_video,
- },
- pil_kernel_info=PILKernelInfo(F._adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
- ),
- DispatcherInfo(
- F.adjust_gamma,
- kernels={
- tv_tensors.Image: F.adjust_gamma_image,
- tv_tensors.Video: F.adjust_gamma_video,
- },
- pil_kernel_info=PILKernelInfo(F._adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
- ),
- DispatcherInfo(
- F.adjust_hue,
- kernels={
- tv_tensors.Image: F.adjust_hue_image,
- tv_tensors.Video: F.adjust_hue_video,
- },
- pil_kernel_info=PILKernelInfo(F._adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
- ),
- DispatcherInfo(
- F.adjust_saturation,
- kernels={
- tv_tensors.Image: F.adjust_saturation_image,
- tv_tensors.Video: F.adjust_saturation_video,
- },
- pil_kernel_info=PILKernelInfo(F._adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
- ),
- DispatcherInfo(
- F.five_crop,
- kernels={
- tv_tensors.Image: F.five_crop_image,
- tv_tensors.Video: F.five_crop_video,
- },
- pil_kernel_info=PILKernelInfo(F._five_crop_image_pil),
- test_marks=[
- xfail_jit_python_scalar_arg("size"),
- *multi_crop_skips,
- ],
- ),
- DispatcherInfo(
- F.ten_crop,
- kernels={
- tv_tensors.Image: F.ten_crop_image,
- tv_tensors.Video: F.ten_crop_video,
- },
- test_marks=[
- xfail_jit_python_scalar_arg("size"),
- *multi_crop_skips,
- ],
- pil_kernel_info=PILKernelInfo(F._ten_crop_image_pil),
- ),
- DispatcherInfo(
- F.normalize,
- kernels={
- tv_tensors.Image: F.normalize_image,
- tv_tensors.Video: F.normalize_video,
- },
- test_marks=[
- xfail_jit_python_scalar_arg("mean"),
- xfail_jit_python_scalar_arg("std"),
- ],
- ),
- DispatcherInfo(
- F.uniform_temporal_subsample,
- kernels={
- tv_tensors.Video: F.uniform_temporal_subsample_video,
- },
- test_marks=[
- skip_dispatch_tv_tensor,
- ],
- ),
- DispatcherInfo(
- F.clamp_bounding_boxes,
- kernels={tv_tensors.BoundingBoxes: F.clamp_bounding_boxes},
- test_marks=[
- skip_dispatch_tv_tensor,
- ],
- ),
- DispatcherInfo(
- F.convert_bounding_box_format,
- kernels={tv_tensors.BoundingBoxes: F.convert_bounding_box_format},
- test_marks=[
- skip_dispatch_tv_tensor,
- ],
- ),
- ]
|