  1. """
  2. ====================================
  3. How to write your own TVTensor class
  4. ====================================
  5. .. note::
  6. Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_tv_tensors.ipynb>`_
  7. or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_tv_tensors.py>` to download the full example code.
  8. This guide is intended for advanced users and downstream library maintainers. We explain how to
  9. write your own TVTensor class, and how to make it compatible with the built-in
  10. Torchvision v2 transforms. Before continuing, make sure you have read
  11. :ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
  12. """
# %%
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
# %%
# We will create a very simple class that just inherits from the base
# :class:`~torchvision.tv_tensors.TVTensor` class. It will be enough to cover
# what you need to know to implement your own, more elaborate use cases. If you
# need to create a class that carries meta-data, take a look at how the
# :class:`~torchvision.tv_tensors.BoundingBoxes` class is `implemented
# <https://github.com/pytorch/vision/blob/main/torchvision/tv_tensors/_bounding_box.py>`_
# (a small sketch of that pattern is shown after the basic example below).
class MyTVTensor(tv_tensors.TVTensor):
    pass


my_dp = MyTVTensor([1, 2, 3])
my_dp
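# %%
# As mentioned above, a class that carries meta-data needs a bit more
# machinery. Purely for illustration, here is a minimal, hypothetical sketch of
# that pattern (the class name and its ``timestamp`` attribute are made up,
# and unlike :class:`~torchvision.tv_tensors.BoundingBoxes`, nothing here
# propagates the meta-data through operations):
class TimestampedTVTensor(tv_tensors.TVTensor):
    def __new__(cls, data, *, timestamp=0.0):
        # Build a tensor from the data, view it as our subclass, and attach
        # the meta-data attribute.
        out = torch.as_tensor(data).as_subclass(cls)
        out.timestamp = timestamp
        return out


ts_dp = TimestampedTVTensor([1, 2, 3], timestamp=42.0)
print(ts_dp.timestamp)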
# %%
# Now that we have defined our custom TVTensor class, we want it to be
# compatible with the built-in torchvision transforms, and the functional API.
# For that, we need to implement a kernel which performs the core of the
# transformation, and then "hook" it to the functional that we want to support
# via :func:`~torchvision.transforms.v2.functional.register_kernel`.
#
# We illustrate this process below: we create a kernel for the "horizontal flip"
# operation of our MyTVTensor class, and register it to the functional API.
from torchvision.transforms.v2 import functional as F


@F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
def hflip_my_tv_tensor(my_dp, *args, **kwargs):
    print("Flipping!")
    out = my_dp.flip(-1)
    return tv_tensors.wrap(out, like=my_dp)
# %%
# To understand why :func:`~torchvision.tv_tensors.wrap` is used, see
# :ref:`tv_tensor_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now;
# we will explain them below in :ref:`param_forwarding`.
#
# .. note::
#
#     In our call to ``register_kernel`` above we used a string
#     ``functional="hflip"`` to refer to the functional we want to hook into. We
#     could also have used the functional *itself*, i.e.
#     ``@register_kernel(functional=F.hflip, ...)``.
#
# Now that we have registered our kernel, we can call the functional API on a
# ``MyTVTensor`` instance:
my_dp = MyTVTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)
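# %%
# The output is still a ``MyTVTensor``, because our kernel re-wraps its result
# with :func:`~torchvision.tv_tensors.wrap`:
out = F.hflip(my_dp)
print(type(out))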
# %%
# And we can also use the
# :class:`~torchvision.transforms.v2.RandomHorizontalFlip` transform, since it
# relies on :func:`~torchvision.transforms.v2.functional.hflip` internally:
t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
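# %%
# The registered kernel is also picked up when the transform runs as part of a
# larger pipeline, for instance inside a
# :class:`~torchvision.transforms.v2.Compose`:
pipeline = v2.Compose([v2.RandomHorizontalFlip(p=1)])
_ = pipeline(my_dp)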
# %%
# .. note::
#
#     We cannot register a kernel for a transform class; we can only register a
#     kernel for a **functional**. The reason is that one transform may
#     internally rely on more than one functional, so in general we can't
#     register a single kernel for a given transform class.
#
# .. _param_forwarding:
#
# Parameter forwarding, and ensuring future compatibility of your kernels
# ------------------------------------------------------------------------
#
# The functional API that you're hooking into is public and therefore
# **backward** compatible: we guarantee that the parameters of these functionals
# won't be removed or renamed without a proper deprecation cycle. However, we
# don't guarantee **forward** compatibility, and we may add new parameters in
# the future.
#
# Imagine that in a future version, Torchvision adds a new ``inplace`` parameter
# to its :func:`~torchvision.transforms.v2.functional.hflip` functional. If you
# already defined and registered your own kernel as
def hflip_my_tv_tensor(my_dp):  # noqa
    print("Flipping!")
    out = my_dp.flip(-1)
    return tv_tensors.wrap(out, like=my_dp)
# %%
# then calling ``F.hflip(my_dp)`` will **fail**, because ``hflip`` will try to
# pass the new ``inplace`` parameter to your kernel, but your kernel doesn't
# accept it.
#
# For this reason, we recommend always defining your kernels with
# ``*args, **kwargs`` in their signature, as done above. This way, your kernel
# will be able to accept any new parameter that we may add in the future.
# (Technically, adding only ``**kwargs`` should be enough.)
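# %%
# If your kernel needs to consume specific parameters, you can still name them
# explicitly and forward the rest. As a minimal sketch, assuming you also want
# to support ``resize`` (the kernel name below is made up for illustration):


@F.register_kernel(functional="resize", tv_tensor_cls=MyTVTensor)
def resize_my_tv_tensor(my_dp, size, *args, **kwargs):
    # Unwrap to a plain tensor so that F.resize dispatches to its default
    # tensor kernel instead of recursing into this one.
    out = F.resize(my_dp.as_subclass(torch.Tensor), size, *args, **kwargs)
    return tv_tensors.wrap(out, like=my_dp)


_ = F.resize(my_dp, size=(224, 224))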