  1. """
  2. ==================================
  3. Getting started with transforms v2
  4. ==================================
  5. .. note::
  6. Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_transforms_getting_started.ipynb>`_
  7. or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_transforms_getting_started.py>` to download the full example code.
  8. This example illustrates all of what you need to know to get started with the
  9. new :mod:`torchvision.transforms.v2` API. We'll cover simple tasks like
  10. image classification, and more advanced ones like object detection /
  11. segmentation.
  12. """
# %%
# First, a bit of setup
from pathlib import Path

import torch
import matplotlib.pyplot as plt

plt.rcParams["savefig.bbox"] = 'tight'

from torchvision.transforms import v2
from torchvision.io import read_image

torch.manual_seed(1)

# If you're trying to run this on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
from helpers import plot

img = read_image(str(Path('../assets') / 'astronaut.jpg'))
print(f"{type(img) = }, {img.dtype = }, {img.shape = }")
# %%
# The basics
# ----------
#
# The Torchvision transforms behave like a regular :class:`torch.nn.Module` (in
# fact, most of them are): instantiate a transform, pass an input, get a
# transformed output:

transform = v2.RandomCrop(size=(224, 224))
out = transform(img)

plot([img, out])
# %%
# I just want to do image classification
# --------------------------------------
#
# If you just care about image classification, things are very simple. A basic
# classification pipeline may look like this:

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

out = transforms(img)

plot([img, out])
# %%
# Such a transformation pipeline is typically passed as the ``transform`` argument
# to the :ref:`Datasets <datasets>`, e.g. ``ImageNet(...,
# transform=transforms)``.
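#
# As a minimal sketch of that usage (the dataset root below is just a
# placeholder; since most built-in datasets return PIL images, this sketch
# prepends ``v2.ToImage()`` to the pipeline defined above so that the
# tensor-only steps receive tensors):
#
# .. code-block:: python
#
#     import torch
#     from torchvision import datasets
#     from torchvision.transforms import v2
#
#     transforms = v2.Compose([
#         v2.ToImage(),  # convert the PIL image returned by the dataset to a tensor
#         v2.RandomResizedCrop(size=(224, 224), antialias=True),
#         v2.RandomHorizontalFlip(p=0.5),
#         v2.ToDtype(torch.float32, scale=True),
#         v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#     ])
#
#     dataset = datasets.ImageNet("path/to/imagenet", split="train", transform=transforms)
#     image, label = dataset[0]  # image is a normalized float tensor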
  55. #
  56. # That's pretty much all there is. From there, read through our :ref:`main docs
  57. # <transforms>` to learn more about recommended practices and conventions, or
  58. # explore more :ref:`examples <transforms_gallery>` e.g. how to use augmentation
  59. # transforms like :ref:`CutMix and MixUp
  60. # <sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py>`.
  61. #
  62. # .. note::
  63. #
  64. # If you're already relying on the ``torchvision.transforms`` v1 API,
  65. # we recommend to :ref:`switch to the new v2 transforms<v1_or_v2>`. It's
  66. # very easy: the v2 transforms are fully compatible with the v1 API, so you
  67. # only need to change the import!
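#
#     For instance, a hedged sketch of that one-line change (the small pipeline
#     below is only an illustration; your existing v1-style code stays as-is):
#
#     .. code-block:: python
#
#         # Before (v1):
#         # from torchvision import transforms
#
#         # After (v2): same class names, same call sites, only the import changes
#         from torchvision.transforms import v2 as transforms
#
#         preprocess = transforms.Compose([
#             transforms.RandomResizedCrop(size=(224, 224), antialias=True),
#             transforms.RandomHorizontalFlip(p=0.5),
#         ])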
#
# Detection, Segmentation, Videos
# -------------------------------
#
# The new Torchvision transforms in the ``torchvision.transforms.v2`` namespace
# support tasks beyond image classification: they can also transform bounding
# boxes, segmentation / detection masks, or videos.
#
# Let's briefly look at a detection example with bounding boxes.

from torchvision import tv_tensors  # we'll describe this a bit later, bear with us

boxes = tv_tensors.BoundingBoxes(
    [
        [15, 10, 370, 510],
        [275, 340, 510, 510],
        [130, 345, 210, 425]
    ],
    format="XYXY", canvas_size=img.shape[-2:])

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomPhotometricDistort(p=1),
    v2.RandomHorizontalFlip(p=1),
])

out_img, out_boxes = transforms(img, boxes)
print(type(boxes), type(out_boxes))

plot([(img, boxes), (out_img, out_boxes)])
# %%
#
# The example above focuses on object detection. But if we had masks
# (:class:`torchvision.tv_tensors.Mask`) for object segmentation or semantic
# segmentation, or videos (:class:`torchvision.tv_tensors.Video`), we could have
# passed them to the transforms in exactly the same way.
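#
# For instance, a minimal sketch (the all-zero mask below is made up purely for
# illustration; a real mask would come from your annotations):
#
# .. code-block:: python
#
#     # One binary mask per box, with the same spatial size as the image
#     masks = tv_tensors.Mask(torch.zeros((3, *img.shape[-2:]), dtype=torch.uint8))
#     out_img, out_boxes, out_masks = transforms(img, boxes, masks)
#     print(type(out_masks), out_masks.shape)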
#
# By now you likely have a few questions: what are these TVTensors, how do we
# use them, and what is the expected input/output of those transforms? We'll
# answer these in the next sections.
# %%
#
# .. _what_are_tv_tensors:
#
# What are TVTensors?
# --------------------
#
# TVTensors are :class:`torch.Tensor` subclasses. The available TVTensors are
# :class:`~torchvision.tv_tensors.Image`,
# :class:`~torchvision.tv_tensors.BoundingBoxes`,
# :class:`~torchvision.tv_tensors.Mask`, and
# :class:`~torchvision.tv_tensors.Video`.
#
# TVTensors look and feel just like regular tensors - they **are** tensors.
# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()``
# or any ``torch.*`` operator will also work on a TVTensor:

img_dp = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))

print(f"{isinstance(img_dp, torch.Tensor) = }")
print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }")
# %%
# These TVTensor classes are at the core of the transforms: in order to
# transform a given input, the transforms first look at the **class** of the
# object, and dispatch to the appropriate implementation accordingly.
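#
# As a small illustration (the shapes and coordinates below are arbitrary), the
# same :class:`~torchvision.transforms.v2.RandomHorizontalFlip` flips pixels for
# an ``Image`` but recomputes coordinates for ``BoundingBoxes``, based solely on
# the class of each input:
#
# .. code-block:: python
#
#     flip = v2.RandomHorizontalFlip(p=1)
#     image = tv_tensors.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
#     bbox = tv_tensors.BoundingBoxes([[0, 0, 10, 10]], format="XYXY", canvas_size=(32, 32))
#     flipped_image, flipped_bbox = flip(image, bbox)
#     print(flipped_bbox)  # coordinates are mirrored: [[22, 0, 32, 10]]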
#
# You don't need to know much more about TVTensors at this point, but advanced
# users who want to learn more can refer to
# :ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
#
# What do I pass as input?
# ------------------------
#
# Above, we've seen two examples: one where we passed a single image as input
# i.e. ``out = transforms(img)``, and one where we passed both an image and
# bounding boxes, i.e. ``out_img, out_boxes = transforms(img, boxes)``.
#
# In fact, transforms support **arbitrary input structures**. The input can be a
# single image, a tuple, an arbitrarily nested dictionary... pretty much
# anything. The same structure will be returned as output. Below, we use the
# same detection transforms, but pass a tuple (image, target_dict) as input and
# we get the same structure as output:
target = {
    "boxes": boxes,
    "labels": torch.arange(boxes.shape[0]),
    "this_is_ignored": ("arbitrary", {"structure": "!"})
}

# Re-using the transforms and definitions from above.
out_img, out_target = transforms(img, target)

# sphinx_gallery_thumbnail_number = 4
plot([(img, target["boxes"]), (out_img, out_target["boxes"])])
print(f"{out_target['this_is_ignored']}")
# %%
# We passed a tuple so we get a tuple back, and the second element is the
# transformed target dict. Transforms don't really care about the structure of
# the input; as mentioned above, they only care about the **type** of the
# objects and transform them accordingly.
#
# *Foreign* objects like strings or ints are simply passed-through. This can be
# useful e.g. if you want to associate a path with every single sample when
# debugging!
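#
# For example, a small sketch (the ``"image_path"`` key is hypothetical): the
# string travels through the pipeline untouched, next to the transformed image
# and boxes.
#
# .. code-block:: python
#
#     sample = {"image": img, "boxes": boxes, "image_path": "../assets/astronaut.jpg"}
#     out = transforms(sample)
#     print(out["image_path"])  # -> '../assets/astronaut.jpg', unchanged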
#
# .. _passthrough_heuristic:
#
# .. note::
#
#     **Disclaimer** This note is slightly advanced and can be safely skipped on
#     a first read.
#
#     Pure :class:`torch.Tensor` objects are, in general, treated as images (or
#     as videos for video-specific transforms). Indeed, you may have noticed
#     that in the code above we haven't used the
#     :class:`~torchvision.tv_tensors.Image` class at all, and yet our images
#     got transformed properly. Transforms use the following logic to
#     determine whether a pure Tensor should be treated as an image (or video),
#     or just ignored:
#
#     * If there is an :class:`~torchvision.tv_tensors.Image`,
#       :class:`~torchvision.tv_tensors.Video`,
#       or :class:`PIL.Image.Image` instance in the input, all other pure
#       tensors are passed-through.
#     * If there is no :class:`~torchvision.tv_tensors.Image` or
#       :class:`~torchvision.tv_tensors.Video` instance, only the first pure
#       :class:`torch.Tensor` will be transformed as image or video, while all
#       others will be passed-through. Here "first" means "first in a depth-wise
#       traversal".
#
#     This is what happened in the detection example above: the first pure
#     tensor was the image so it got transformed properly, and all other pure
#     tensor instances like the ``labels`` were passed-through (although labels
#     can still be transformed by some transforms like
#     :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes`!).
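#
#     A minimal sketch of that heuristic (shapes chosen arbitrarily):
#
#     .. code-block:: python
#
#         crop = v2.RandomCrop(size=(32, 32))
#         plain = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
#
#         # Alone, the first (and only) pure tensor is treated as an image:
#         print(crop(plain).shape)  # -> torch.Size([3, 32, 32])
#
#         # Next to an Image instance, the pure tensor is passed-through unchanged:
#         out_image, out_plain = crop(tv_tensors.Image(plain), plain)
#         print(out_image.shape, out_plain.shape)  # -> (3, 32, 32) and (3, 64, 64)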
#
# .. _transforms_datasets_intercompatibility:
#
# Transforms and Datasets intercompatibility
# ------------------------------------------
#
# Roughly speaking, the output of the datasets must correspond to the input of
# the transforms. How to do that depends on whether you're using the torchvision
# :ref:`built-in datasets <datasets>`, or your own custom datasets.
#
# Using built-in datasets
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# If you're just doing image classification, you don't need to do anything. Just
# use the ``transform`` argument of the dataset, e.g. ``ImageNet(...,
# transform=transforms)``, and you're good to go.
#
# Torchvision also supports datasets for object detection or segmentation like
# :class:`torchvision.datasets.CocoDetection`. Those datasets predate
# the existence of the :mod:`torchvision.transforms.v2` module and of the
# TVTensors, so they don't return TVTensors out of the box.
#
# An easy way to force those datasets to return TVTensors and to make them
# compatible with v2 transforms is to use the
# :func:`torchvision.datasets.wrap_dataset_for_transforms_v2` function:
#
# .. code-block:: python
#
#     from torchvision.datasets import CocoDetection, wrap_dataset_for_transforms_v2
#
#     dataset = CocoDetection(..., transforms=my_transforms)
#     dataset = wrap_dataset_for_transforms_v2(dataset)
#     # Now the dataset returns TVTensors!
#
# Using your own datasets
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# If you have a custom dataset, then you'll need to convert your objects into
# the appropriate TVTensor classes. Creating TVTensor instances is very easy;
# refer to :ref:`tv_tensor_creation` for more details.
#
# There are two main places where you can implement that conversion logic:
#
# - At the end of the dataset's ``__getitem__`` method, before returning the
#   sample (or by sub-classing the dataset), as sketched below.
# - As the very first step of your transforms pipeline.
#
# Either way, the logic will depend on your specific dataset.
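#
# As a hedged sketch of the first option (the class, field names, and label
# handling below are hypothetical, not part of torchvision):
#
# .. code-block:: python
#
#     import torch
#     from torchvision import tv_tensors
#
#     class MyDetectionDataset(torch.utils.data.Dataset):
#         def __init__(self, samples, transforms=None):
#             # ``samples`` is assumed to be a list of (image_tensor, raw_boxes) pairs
#             self.samples = samples
#             self.transforms = transforms
#
#         def __len__(self):
#             return len(self.samples)
#
#         def __getitem__(self, idx):
#             image, raw_boxes = self.samples[idx]
#             # Wrap the raw data into TVTensors so the v2 transforms know
#             # how to treat each object
#             image = tv_tensors.Image(image)
#             boxes = tv_tensors.BoundingBoxes(
#                 raw_boxes, format="XYXY", canvas_size=image.shape[-2:]
#             )
#             target = {"boxes": boxes, "labels": torch.zeros(len(boxes), dtype=torch.int64)}
#             if self.transforms is not None:
#                 image, target = self.transforms(image, target)
#             return image, target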