plot_tv_tensors.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
"""
=============
TVTensors FAQ
=============

.. note::
    Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_tv_tensors.ipynb>`_
    or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_tv_tensors.py>` to download the full example code.

TVTensors are Tensor subclasses introduced together with
``torchvision.transforms.v2``. This example showcases what these TVTensors are
and how they behave.

.. warning::

    **Intended Audience** Unless you're writing your own transforms or your own TVTensors, you
    probably do not need to read this guide. This is a fairly low-level topic
    that most users will not need to worry about: you do not need to understand
    the internals of TVTensors to efficiently rely on
    ``torchvision.transforms.v2``. It may however be useful for advanced users
    trying to implement their own datasets, transforms, or work directly with
    the TVTensors.
"""

# %%
  21. import PIL.Image
  22. import torch
  23. from torchvision import tv_tensors
# %%
# What are TVTensors?
# -------------------
#
# TVTensors are zero-copy tensor subclasses:
  29. tensor = torch.rand(3, 256, 256)
  30. image = tv_tensors.Image(tensor)
  31. assert isinstance(image, torch.Tensor)
  32. assert image.data_ptr() == tensor.data_ptr()
# %%
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data.
#
# :mod:`torchvision.tv_tensors` supports four types of TVTensors:
#
# * :class:`~torchvision.tv_tensors.Image`
# * :class:`~torchvision.tv_tensors.Video`
# * :class:`~torchvision.tv_tensors.BoundingBoxes`
# * :class:`~torchvision.tv_tensors.Mask`
#
# What can I do with a TVTensor?
# ------------------------------
#
# TVTensors look and feel just like regular tensors - they **are** tensors.
# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
# any ``torch.*`` operator will also work on TVTensors. See
# :ref:`tv_tensor_unwrapping_behaviour` for a few gotchas.

# %%
# .. _tv_tensor_creation:
#
# How do I construct a TVTensor?
# ------------------------------
#
# Using the constructor
# ^^^^^^^^^^^^^^^^^^^^^
#
# Each TVTensor class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`:
  61. image = tv_tensors.Image([[[[0, 1], [1, 0]]]])
  62. print(image)
# %%
# Similar to other PyTorch creation ops, the constructor also takes the ``dtype``, ``device``, and ``requires_grad``
# parameters.
  66. float_image = tv_tensors.Image([[[0, 1], [1, 0]]], dtype=torch.float32, requires_grad=True)
  67. print(float_image)
# %%
# In addition, :class:`~torchvision.tv_tensors.Image` and :class:`~torchvision.tv_tensors.Mask` can take a
# :class:`PIL.Image.Image` directly:
  71. image = tv_tensors.Image(PIL.Image.open("../assets/astronaut.jpg"))
  72. print(image.shape, image.dtype)
# %%
# Some TVTensors require additional metadata to be passed in order to be constructed. For example,
# :class:`~torchvision.tv_tensors.BoundingBoxes` requires the coordinate format as well as the size of the
# corresponding image (``canvas_size``) alongside the actual values. These
# metadata are required to properly transform the bounding boxes.
  78. bboxes = tv_tensors.BoundingBoxes(
  79. [[17, 16, 344, 495], [0, 10, 0, 10]],
  80. format=tv_tensors.BoundingBoxFormat.XYXY,
  81. canvas_size=image.shape[-2:]
  82. )
  83. print(bboxes)
# %%
# Using ``tv_tensors.wrap()``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can also use the :func:`~torchvision.tv_tensors.wrap` function to wrap a tensor object
# into a TVTensor. This is useful when you already have an object of the
# desired type, which typically happens when writing transforms: you just want
# to wrap the output like the input.
  92. new_bboxes = torch.tensor([0, 20, 30, 40])
  93. new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)
  94. assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
  95. assert new_bboxes.canvas_size == bboxes.canvas_size
# %%
# The metadata of ``new_bboxes`` is the same as ``bboxes``, but you can override
# it by passing new values as keyword arguments to
# :func:`~torchvision.tv_tensors.wrap` (e.g. ``canvas_size=...``).
#
# .. _tv_tensor_unwrapping_behaviour:
#
# I had a TVTensor but now I have a Tensor. Help!
# -----------------------------------------------
#
# By default, operations on :class:`~torchvision.tv_tensors.TVTensor` objects
# will return a pure Tensor:
  107. assert isinstance(bboxes, tv_tensors.BoundingBoxes)
  108. # Shift bboxes by 3 pixels in both H and W
  109. new_bboxes = bboxes + 3
  110. assert isinstance(new_bboxes, torch.Tensor)
  111. assert not isinstance(new_bboxes, tv_tensors.BoundingBoxes)
# %%
# .. note::
#
#    This behavior only affects native ``torch`` operations. If you are using
#    the built-in ``torchvision`` transforms or functionals, you will always get
#    as output the same type that you passed as input (pure ``Tensor`` or
#    ``TVTensor``).

# %%
# But I want a TVTensor back!
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can re-wrap a pure tensor into a TVTensor by just calling the TVTensor
# constructor, or by using the :func:`~torchvision.tv_tensors.wrap` function
# (see more details above in :ref:`tv_tensor_creation`):
  126. new_bboxes = bboxes + 3
  127. new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)
  128. assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
# %%
# Alternatively, you can use :func:`~torchvision.tv_tensors.set_return_type`
# either as a global config setting for the whole program or as a context manager
# (read its docs to learn more about caveats):
  133. with tv_tensors.set_return_type("TVTensor"):
  134. new_bboxes = bboxes + 3
  135. assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
# %%
# Why is this happening?
# ^^^^^^^^^^^^^^^^^^^^^^
#
# **For performance reasons**. :class:`~torchvision.tv_tensors.TVTensor`
# classes are Tensor subclasses, so any operation involving a
# :class:`~torchvision.tv_tensors.TVTensor` object will go through the
# `__torch_function__
# <https://pytorch.org/docs/stable/notes/extending.html#extending-torch>`_
# protocol. This induces a small overhead, which we want to avoid when possible.
# This doesn't matter for built-in ``torchvision`` transforms because we can
# avoid the overhead there, but it could be a problem in your model's
# ``forward``.
#
# **The alternative isn't much better anyway.** For every operation where
# preserving the :class:`~torchvision.tv_tensors.TVTensor` type makes
# sense, there are just as many operations where returning a pure Tensor is
# preferable: for example, is ``img.sum()`` still an :class:`~torchvision.tv_tensors.Image`?
# If we were to preserve :class:`~torchvision.tv_tensors.TVTensor` types all
# the way, even the model's logits or the output of the loss function would end up
# being of type :class:`~torchvision.tv_tensors.Image`, and surely that's not
# desirable.
#
# .. note::
#
#    This behavior is something we're actively seeking feedback on. If you find this surprising or if you
#    have any suggestions on how to better support your use-cases, please reach out to us via this issue:
#    https://github.com/pytorch/vision/issues/7319
#
# Exceptions
# ^^^^^^^^^^
#
# There are a few exceptions to this "unwrapping" rule:
# :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
# :meth:`~torch.Tensor.detach`, and :meth:`~torch.Tensor.requires_grad_` retain
# the TVTensor type.
#
# In-place operations on TVTensors like ``obj.add_()`` will preserve the type of
# ``obj``. However, the **returned** value of in-place operations will be a pure
# tensor:
  176. image = tv_tensors.Image([[[0, 1], [1, 0]]])
  177. new_image = image.add_(1).mul_(2)
  178. # image got transformed in-place and is still a TVTensor Image, but new_image
  179. # is a Tensor. They share the same underlying data and they're equal, just
  180. # different classes.
  181. assert isinstance(image, tv_tensors.Image)
  182. print(image)
  183. assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, tv_tensors.Image)
  184. assert (new_image == image).all()
  185. assert new_image.data_ptr() == image.data_ptr()