
  1. """
  2. ===============================================================
  3. Transforms v2: End-to-end object detection/segmentation example
  4. ===============================================================
  5. .. note::
  6. Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_transforms_e2e.ipynb>`_
  7. or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_transforms_e2e.py>` to download the full example code.
  8. Object detection and segmentation tasks are natively supported:
  9. ``torchvision.transforms.v2`` enables jointly transforming images, videos,
  10. bounding boxes, and masks.
  11. This example showcases an end-to-end instance segmentation training case using
  12. Torchvision utils from ``torchvision.datasets``, ``torchvision.models`` and
  13. ``torchvision.transforms.v2``. Everything covered here can be applied similarly
  14. to object detection or semantic segmentation tasks.
  15. """
# %%
import pathlib

import torch
import torch.utils.data

from torchvision import models, datasets, tv_tensors
from torchvision.transforms import v2

torch.manual_seed(0)

# This loads fake data for illustration purposes. In practice, you'll have
# to replace this with the proper data.
# If you're trying to run this on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
ROOT = pathlib.Path("../assets") / "coco"
IMAGES_PATH = str(ROOT / "images")
ANNOTATIONS_PATH = str(ROOT / "instances.json")

from helpers import plot
# %%
# Dataset preparation
# -------------------
#
# We start off by loading the :class:`~torchvision.datasets.CocoDetection` dataset to have a look at what it currently
# returns.

dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH)

sample = dataset[0]
img, target = sample
print(f"{type(img) = }\n{type(target) = }\n{type(target[0]) = }\n{target[0].keys() = }")
# %%
# Torchvision datasets preserve the data structure and types as they were
# intended by the dataset authors. So by default, the output structure may not
# always be compatible with the models or the transforms.
#
# To overcome that, we can use the
# :func:`~torchvision.datasets.wrap_dataset_for_transforms_v2` function. For
# :class:`~torchvision.datasets.CocoDetection`, this changes the target
# structure to a single dictionary:

dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=("boxes", "labels", "masks"))

sample = dataset[0]
img, target = sample
print(f"{type(img) = }\n{type(target) = }\n{target.keys() = }")
print(f"{type(target['boxes']) = }\n{type(target['labels']) = }\n{type(target['masks']) = }")
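
# %%
# As a quick, optional sanity check (a small sketch on top of the sample
# above): the wrapped values are TVTensors, so they carry metadata such as the
# bounding box format and canvas size alongside the raw tensor data.

print(f"{target['boxes'].format = }")
print(f"{target['boxes'].canvas_size = }")
print(f"{target['boxes'].shape = }, {target['masks'].shape = }")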
# %%
# We used the ``target_keys`` parameter to specify the kind of output we're
# interested in. Our dataset now returns a target which is a dict where the
# values are :ref:`TVTensors <what_are_tv_tensors>` (all are
# :class:`torch.Tensor` subclasses). We've dropped all unnecessary keys from
# the previous output, but if you need any of the original keys, e.g.
# "image_id", you can still ask for it.
#
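# For instance, asking for ``"image_id"`` as well is just a matter of listing
# it in ``target_keys``. The snippet below is an illustrative sketch that
# builds a separately wrapped dataset for that purpose (``dataset_with_ids`` is
# only used here and not needed for the rest of the example):

dataset_with_ids = datasets.wrap_dataset_for_transforms_v2(
    datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH),
    target_keys=("boxes", "labels", "masks", "image_id"),
)
print(f"{dataset_with_ids[0][1].keys() = }")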
# %%
# .. note::
#
#    If you just want to do detection, you don't need and shouldn't pass
#    "masks" in ``target_keys``: if masks are present in the sample, they will
#    be transformed, slowing down your transformations unnecessarily.
#
# As a baseline, let's have a look at a sample without transformations:

plot([dataset[0], dataset[1]])
# %%
# Transforms
# ----------
#
# Let's now define our pre-processing transforms. All the transforms know how
# to handle images, bounding boxes and masks when relevant.
#
# Transforms are typically passed as the ``transforms`` parameter of the
# dataset so that they can leverage multi-processing from the
# :class:`torch.utils.data.DataLoader`.

transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.RandomPhotometricDistort(p=1),
        v2.RandomZoomOut(fill={tv_tensors.Image: (123, 117, 104), "others": 0}),
        v2.RandomIoUCrop(),
        v2.RandomHorizontalFlip(p=1),
        v2.SanitizeBoundingBoxes(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH, transforms=transforms)
dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=["boxes", "labels", "masks"])
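
# %%
# The same ``transforms`` object can also be called directly on an
# ``(image, target)`` pair outside of a dataset. The sketch below re-uses the
# untransformed ``img`` and ``target`` from earlier to show that the image, the
# bounding boxes, and the masks are all transformed jointly:

out_img, out_target = transforms(img, target)
print(f"{type(out_img) = }, {out_img.dtype = }")
print(f"{out_target['boxes'].shape = }, {out_target['masks'].shape = }")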
# %%
# A few things are worth noting here:
#
# - We're converting the PIL image into a
#   :class:`~torchvision.tv_tensors.Image` object. This isn't strictly
#   necessary, but relying on Tensors (here: a Tensor subclass) will
#   :ref:`generally be faster <transforms_perf>`.
# - We are calling :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` to
#   make sure we remove degenerate bounding boxes, as well as their
#   corresponding labels and masks.
#   :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` should be placed
#   at least once at the end of a detection pipeline; it is particularly
#   critical if :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
#   A small toy example of what it removes is sketched below.
#
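# The next snippet is a minimal toy sketch (with made-up values, independent of
# the dataset above) of what :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes`
# does: the second box below has zero width, so it is dropped together with its
# label.

toy_sample = {
    "image": tv_tensors.Image(torch.randint(0, 256, (3, 100, 100), dtype=torch.uint8)),
    "boxes": tv_tensors.BoundingBoxes(
        [[10, 10, 50, 50], [20, 20, 20, 60]],  # second box is degenerate (zero width)
        format="XYXY",
        canvas_size=(100, 100),
    ),
    "labels": torch.tensor([1, 2]),
}
toy_out = v2.SanitizeBoundingBoxes()(toy_sample)
print(f"{toy_out['boxes'].shape = }\n{toy_out['labels'] = }")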
# %%
# Let's look at how a sample looks with our augmentation pipeline in place:

# sphinx_gallery_thumbnail_number = 2
plot([dataset[0], dataset[1]])
# %%
# We can see that the colors of the images were distorted, and that the images
# were zoomed in or out and flipped. The bounding boxes and the masks were
# transformed accordingly. And without any further ado, we can start training.
#
# Data loading and training loop
# ------------------------------
#
# Below we're using Mask-RCNN, which is an instance segmentation model, but
# everything we've covered in this tutorial also applies to object detection and
# semantic segmentation tasks.

data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    # We need a custom collation function here, since the object detection
    # models expect a sequence of images and target dictionaries. The default
    # collation function tries to torch.stack() the individual elements,
    # which fails in general for object detection, because the number of bounding
    # boxes varies between the images of the same batch.
    collate_fn=lambda batch: tuple(zip(*batch)),
)

model = models.get_model("maskrcnn_resnet50_fpn_v2", weights=None, weights_backbone=None).train()

for imgs, targets in data_loader:
    loss_dict = model(imgs, targets)
    # Put your training logic here

    print(f"{[img.shape for img in imgs] = }")
    print(f"{[type(target) for target in targets] = }")
    for name, loss_val in loss_dict.items():
        print(f"{name:<20}{loss_val:.3f}")
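
# %%
# As a minimal sketch of what "your training logic" above typically amounts to
# for these models (assuming a plain SGD optimizer; the hyper-parameters below
# are illustrative, not tuned): sum the losses returned by the model and
# back-propagate through them.

optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)

losses = sum(loss_dict.values())  # loss_dict from the last batch of the loop above
optimizer.zero_grad()
losses.backward()
optimizer.step()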
# %%
# Training References
# -------------------
#
# From there, you can check out the `torchvision references
# <https://github.com/pytorch/vision/tree/main/references>`_ where you'll find
# the actual training scripts we use to train our models.
#
# **Disclaimer:** The code in our references is more complex than what you'll
# need for your own use-cases: this is because we're supporting different
# backends (PIL, tensors, TVTensors) and different transforms namespaces (v1 and
# v2). So don't be afraid to simplify and only keep what you need.