
  1. """
  2. =====================================
  3. Repurposing masks into bounding boxes
  4. =====================================
  5. .. note::
  6. Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_repurposing_annotations.ipynb>`_
  7. or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_repurposing_annotations.py>` to download the full example code.
  8. The following example illustrates the operations available
  9. the :ref:`torchvision.ops <ops>` module for repurposing
  10. segmentation masks into object localization annotations for different tasks
  11. (e.g. transforming masks used by instance and panoptic segmentation
  12. methods into bounding boxes used by object detection methods).
  13. """
# sphinx_gallery_thumbnail_path = "../../gallery/assets/repurposing_annotations_thumbnail.png"
import os

import numpy as np
import torch
import matplotlib.pyplot as plt

import torchvision.transforms.functional as F


ASSETS_DIRECTORY = "../assets"

plt.rcParams["savefig.bbox"] = "tight"


def show(imgs):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = F.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
# %%
# Masks
# -----
# In tasks like instance and panoptic segmentation, masks are commonly defined, and are defined by this package,
# as a multi-dimensional array (e.g. a NumPy array or a PyTorch tensor) with the following shape:
#
#       (num_objects, height, width)
#
# Where num_objects is the number of annotated objects in the image. Each (height, width) mask corresponds to exactly
# one object. For example, if your input image has the dimensions 224 x 224 and has four annotated objects, then
# your masks annotation has the following shape:
#
#       (4, 224, 224).
#
# A nice property of masks is that they can be easily repurposed to be used in methods to solve a variety of object
# localization tasks.
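# %%
# As a quick, made-up illustration (not taken from the dataset used below), a masks
# annotation for four objects in a 224 x 224 image is simply a boolean tensor of
# shape ``(4, 224, 224)``:

example_masks = torch.zeros((4, 224, 224), dtype=torch.bool)
example_masks[0, 50:100, 50:100] = True  # mark a square region as belonging to the first object
print(example_masks.shape)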
# %%
# Converting Masks to Bounding Boxes
# ----------------------------------
# For example, the :func:`~torchvision.ops.masks_to_boxes` operation can be used to
# transform masks into bounding boxes that can be
# used as input to detection models such as FasterRCNN and RetinaNet.
# We will take images and masks from the `PennFudan Dataset <https://www.cis.upenn.edu/~jshi/ped_html/>`_.
from torchvision.io import read_image

img_path = os.path.join(ASSETS_DIRECTORY, "FudanPed00054.png")
mask_path = os.path.join(ASSETS_DIRECTORY, "FudanPed00054_mask.png")
img = read_image(img_path)
mask = read_image(mask_path)
# %%
# Here the masks are represented as a PNG image, where each pixel value encodes an
# object id (a different "color" per object), with 0 being the background.
# Notice that the spatial dimensions of image and mask match.
print(mask.size())
print(img.size())
print(mask)
# %%
# We get the unique colors, as these would be the object ids.
obj_ids = torch.unique(mask)

# first id is the background, so remove it.
obj_ids = obj_ids[1:]

# split the color-encoded mask into a set of boolean masks.
# Note that this snippet would work as well if the masks were float values instead of ints.
masks = mask == obj_ids[:, None, None]
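# %%
# To see why the comparison above produces per-object masks, here is a tiny, made-up
# example: comparing a ``(1, height, width)`` mask against ids reshaped to
# ``(num_objects, 1, 1)`` broadcasts to a ``(num_objects, height, width)`` boolean
# tensor, one channel per object id.

toy_mask = torch.tensor([[[0, 1], [2, 2]]])
toy_ids = torch.tensor([1, 2])
print(toy_mask == toy_ids[:, None, None])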
# %%
# Now the masks are a boolean tensor.
# The first dimension in this case is 3, denoting the number of instances: there are 3 people in the image.
# The other two dimensions are height and width, which are equal to the dimensions of the image.
# For each instance, the boolean tensor indicates whether a particular pixel
# belongs to the segmentation mask of that instance.
print(masks.size())
print(masks)
# %%
# Let us visualize an image and plot its corresponding segmentation masks.
# We will use the :func:`~torchvision.utils.draw_segmentation_masks` to draw the segmentation masks.
from torchvision.utils import draw_segmentation_masks

drawn_masks = []
for mask in masks:
    drawn_masks.append(draw_segmentation_masks(img, mask, alpha=0.8, colors="blue"))

show(drawn_masks)
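# %%
# :func:`~torchvision.utils.draw_segmentation_masks` also accepts all masks at once;
# as an optional variation, the three instances can be drawn on a single image
# (colors are picked automatically when none are passed):

show(draw_segmentation_masks(img, masks, alpha=0.8))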
# %%
# To convert the boolean masks into bounding boxes,
# we will use the :func:`~torchvision.ops.masks_to_boxes` operation from the torchvision.ops module.
# It returns the boxes in ``(xmin, ymin, xmax, ymax)`` format.
from torchvision.ops import masks_to_boxes

boxes = masks_to_boxes(masks)
print(boxes.size())
print(boxes)
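# %%
# If a different box parametrization is needed (for example ``(x, y, w, h)``), the
# boxes can be converted with :func:`~torchvision.ops.box_convert`. This is an
# optional extra step and is not required by the rest of this example.

from torchvision.ops import box_convert

print(box_convert(boxes, in_fmt="xyxy", out_fmt="xywh"))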
# %%
# As the shape denotes, there are 3 boxes, in ``(xmin, ymin, xmax, ymax)`` format.
# These can be visualized very easily with the :func:`~torchvision.utils.draw_bounding_boxes` utility
# provided in :ref:`torchvision.utils <utils>`.
from torchvision.utils import draw_bounding_boxes

drawn_boxes = draw_bounding_boxes(img, boxes, colors="red")
show(drawn_boxes)
# %%
# These boxes can now directly be used by detection models in torchvision.
# Here is a demo with a Faster R-CNN model loaded from
# :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`.
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights

weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn(weights=weights, progress=False)
print(img.size())

transforms = weights.transforms()
img = transforms(img)
target = {}
target["boxes"] = boxes
target["labels"] = torch.ones((masks.size(0),), dtype=torch.int64)
detection_outputs = model(img.unsqueeze(0), [target])
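# %%
# In training mode the model returns a dictionary of losses computed against the
# targets we just built. As an optional extra check, the same model can be switched
# to evaluation mode to obtain predictions instead:

model = model.eval()
with torch.no_grad():
    predictions = model(img.unsqueeze(0))
print(predictions[0]["boxes"].shape)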
# %%
# Converting Segmentation Dataset to Detection Dataset
# ----------------------------------------------------
#
# With this utility it becomes very simple to convert a segmentation dataset to a detection dataset.
# We can now use a segmentation dataset to train a detection model.
# One can similarly convert a panoptic dataset to a detection dataset.
# Here is an example where we re-purpose the dataset from the
# `PennFudan Detection Tutorial <https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html>`_.
class SegmentationToDetectionDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to
        # ensure that they are aligned
        self.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))
        self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))

    def __getitem__(self, idx):
        # load images and masks
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        img = read_image(img_path)
        mask = read_image(mask_path)

        img = F.convert_image_dtype(img, dtype=torch.float)
        mask = F.convert_image_dtype(mask, dtype=torch.float)

        # We get the unique colors, as these would be the object ids.
        obj_ids = torch.unique(mask)

        # first id is the background, so remove it.
        obj_ids = obj_ids[1:]

        # split the color-encoded mask into a set of boolean masks.
        masks = mask == obj_ids[:, None, None]

        boxes = masks_to_boxes(masks)

        # there is only one class
        labels = torch.ones((masks.shape[0],), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target
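# %%
# A minimal usage sketch, assuming the PennFudan data has been downloaded and
# extracted to a (hypothetical) ``PennFudanPed/`` directory; it is left commented
# out so that running this file does not require the download:

# dataset = SegmentationToDetectionDataset("PennFudanPed/", transforms=None)
# img, target = dataset[0]
# print(target["boxes"])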