_utils.py

import functools
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

import torch
from torchvision import tv_tensors

_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]


def is_pure_tensor(inpt: Any) -> bool:
    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, tv_tensors.TVTensor)


# {functional: {input_type: type_specific_kernel}}
_KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}


def _kernel_tv_tensor_wrapper(kernel):
    @functools.wraps(kernel)
    def wrapper(inpt, *args, **kwargs):
        # If you're wondering whether we could / should get rid of this wrapper,
        # the answer is no: we want to pass pure Tensors to avoid the overhead
        # of the __torch_function__ machinery. Note that this is always valid,
        # regardless of whether we override __torch_function__ in our base class
        # or not.
        # Also, even if we didn't call `as_subclass` here, we would still need
        # this wrapper to call wrap(), because the TVTensor type would be
        # lost after the first operation due to our own __torch_function__
        # logic.
        output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
        return tv_tensors.wrap(output, like=inpt)

    return wrapper
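

# Illustrative sketch (not part of the original module, never called): shows the behavior the
# wrapper above relies on. Under the default tv_tensors return type, an ordinary operation on a
# TVTensor yields a plain torch.Tensor, so the subclass type has to be restored with
# tv_tensors.wrap().
def _demo_wrap_roundtrip():
    img = tv_tensors.Image(torch.rand(3, 8, 8))
    out = img + 1  # plain torch.Tensor: the Image type is lost by __torch_function__
    assert type(out) is torch.Tensor
    restored = tv_tensors.wrap(out, like=img)  # re-attach the Image type (and metadata, if any)
    assert type(restored) is tv_tensors.Image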


def _register_kernel_internal(functional, input_type, *, tv_tensor_wrapper=True):
    registry = _KERNEL_REGISTRY.setdefault(functional, {})
    if input_type in registry:
        raise ValueError(f"Functional {functional} already has a kernel registered for type {input_type}.")

    def decorator(kernel):
        registry[input_type] = (
            _kernel_tv_tensor_wrapper(kernel)
            if issubclass(input_type, tv_tensors.TVTensor) and tv_tensor_wrapper
            else kernel
        )
        return kernel

    return decorator
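

# Illustrative sketch (not part of the original module, never called; `my_functional` and the
# kernels below are hypothetical): shows the internal registration pattern, where a dispatcher
# looks up its type-specific kernel through _get_kernel and the decorator above ties the two
# together.
def _demo_internal_registration():
    def my_functional(inpt):  # hypothetical dispatcher
        kernel = _get_kernel(my_functional, type(inpt))
        return kernel(inpt)

    @_register_kernel_internal(my_functional, torch.Tensor)
    def my_functional_tensor(inpt):  # kernel for pure tensors
        return inpt.flip(-1)

    @_register_kernel_internal(my_functional, tv_tensors.Image)
    def my_functional_image(inpt):
        # The wrapper installed by the decorator unwraps the Image to a pure Tensor before
        # calling this kernel and re-wraps the result as an Image afterwards.
        return my_functional_tensor(inpt)

    img = tv_tensors.Image(torch.rand(3, 4, 4))
    assert type(my_functional(img)) is tv_tensors.Image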


def _name_to_functional(name):
    import torchvision.transforms.v2.functional  # noqa

    try:
        return getattr(torchvision.transforms.v2.functional, name)
    except AttributeError:
        raise ValueError(
            f"Could not find functional with name '{name}' in torchvision.transforms.v2.functional."
        ) from None


_BUILTIN_DATAPOINT_TYPES = {
    obj for obj in tv_tensors.__dict__.values() if isinstance(obj, type) and issubclass(obj, tv_tensors.TVTensor)
}


def register_kernel(functional, tv_tensor_cls):
    """[BETA] Decorate a kernel to register it for a functional and a (custom) tv_tensor type.

    See :ref:`sphx_glr_auto_examples_transforms_plot_custom_tv_tensors.py` for usage
    details.
    """
    if isinstance(functional, str):
        functional = _name_to_functional(name=functional)
    elif not (
        callable(functional)
        and getattr(functional, "__module__", "").startswith("torchvision.transforms.v2.functional")
    ):
        raise ValueError(
            f"Kernels can only be registered on functionals from the torchvision.transforms.v2.functional namespace, "
            f"but got {functional}."
        )

    if not (isinstance(tv_tensor_cls, type) and issubclass(tv_tensor_cls, tv_tensors.TVTensor)):
        raise ValueError(
            f"Kernels can only be registered for subclasses of torchvision.tv_tensors.TVTensor, "
            f"but got {tv_tensor_cls}."
        )

    if tv_tensor_cls in _BUILTIN_DATAPOINT_TYPES:
        raise ValueError(f"Kernels cannot be registered for the builtin tv_tensor classes, but got {tv_tensor_cls}")

    return _register_kernel_internal(functional, tv_tensor_cls, tv_tensor_wrapper=False)
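

# Illustrative sketch (not part of the original module, never called; `MyTVTensor` and the kernel
# below are hypothetical): shows the intended usage of register_kernel with a custom TVTensor
# subclass. Note that, unlike the built-in kernels, the registered kernel is not auto-wrapped
# (tv_tensor_wrapper=False above), so it is responsible for re-wrapping its output if desired.
def _demo_register_custom_kernel():
    from torchvision.transforms.v2 import functional as F  # local import to avoid a circular import

    class MyTVTensor(tv_tensors.TVTensor):  # hypothetical custom type
        pass

    @register_kernel(F.horizontal_flip, MyTVTensor)  # equivalently: register_kernel("horizontal_flip", MyTVTensor)
    def horizontal_flip_my_tv_tensor(inpt, *args, **kwargs):
        return tv_tensors.wrap(inpt.flip(-1), like=inpt)

    my_inpt = torch.rand(3, 4, 4).as_subclass(MyTVTensor)
    assert type(F.horizontal_flip(my_inpt)) is MyTVTensor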


def _get_kernel(functional, input_type, *, allow_passthrough=False):
    registry = _KERNEL_REGISTRY.get(functional)
    if not registry:
        raise ValueError(f"No kernel registered for functional {functional.__name__}.")

    for cls in input_type.__mro__:
        if cls in registry:
            return registry[cls]
        elif cls is tv_tensors.TVTensor:
            # We don't want user-defined tv_tensors to dispatch to the pure Tensor kernels, so we explicitly stop
            # the MRO traversal before hitting torch.Tensor. We can even stop at tv_tensors.TVTensor, since we
            # don't allow kernels to be registered for tv_tensors.TVTensor anyway.
            break

    if allow_passthrough:
        return lambda inpt, *args, **kwargs: inpt

    raise TypeError(
        f"Functional F.{functional.__name__} supports inputs of type {registry.keys()}, "
        f"but got {input_type} instead."
    )
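

# Illustrative sketch (not part of the original module, never called; the subclasses below are
# hypothetical): shows how the MRO walk in _get_kernel resolves kernels. A subclass of a built-in
# tv_tensor falls back to its parent's kernel, while an unknown TVTensor subclass never reaches
# the pure-Tensor kernel and instead passes through (or raises) depending on allow_passthrough.
def _demo_kernel_dispatch():
    from torchvision.transforms.v2 import functional as F  # local import to avoid a circular import

    class MyImage(tv_tensors.Image):  # hypothetical subclass of a built-in type
        pass

    # Resolves to the kernel registered for tv_tensors.Image via the MRO.
    assert _get_kernel(F.horizontal_flip, MyImage) is _get_kernel(F.horizontal_flip, tv_tensors.Image)

    class MyTVTensor(tv_tensors.TVTensor):  # hypothetical type with no registered kernel
        pass

    passthrough = _get_kernel(F.horizontal_flip, MyTVTensor, allow_passthrough=True)
    inpt = torch.rand(3)
    assert passthrough(inpt) is inpt  # identity: the input is returned unchanged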


# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop.
# We could get rid of this by letting _register_kernel_internal take an arbitrary wrapper rather than the
# tv_tensor_wrapper: bool flag.
def _register_five_ten_crop_kernel_internal(functional, input_type):
    registry = _KERNEL_REGISTRY.setdefault(functional, {})
    if input_type in registry:
        raise TypeError(f"Functional '{functional}' already has a kernel registered for type '{input_type}'.")

    def wrap(kernel):
        @functools.wraps(kernel)
        def wrapper(inpt, *args, **kwargs):
            output = kernel(inpt, *args, **kwargs)
            container_type = type(output)
            return container_type(tv_tensors.wrap(o, like=inpt) for o in output)

        return wrapper

    def decorator(kernel):
        registry[input_type] = wrap(kernel) if issubclass(input_type, tv_tensors.TVTensor) else kernel
        return kernel

    return decorator
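

# Illustrative sketch (not part of the original module, never called): shows why the five_crop /
# ten_crop registration needs its own wrapper. Those kernels return a *container* of tensors, so
# each element is re-wrapped to the input's TVTensor type individually while the container type
# (a tuple) is preserved, instead of wrapping the output as a whole.
def _demo_five_ten_crop_wrapping():
    img = tv_tensors.Image(torch.rand(3, 8, 8))
    crops = (torch.rand(3, 4, 4), torch.rand(3, 4, 4))  # stand-in for the 5 or 10 crops
    wrapped = type(crops)(tv_tensors.wrap(c, like=img) for c in crops)
    assert isinstance(wrapped, tuple) and all(type(c) is tv_tensors.Image for c in wrapped)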