  1. """
  2. =======================
  3. Visualization utilities
  4. =======================
  5. .. note::
  6. Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_visualization_utils.ipynb>`_
  7. or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_visualization_utils.py>` to download the full example code.
  8. This example illustrates some of the utilities that torchvision offers for
  9. visualizing images, bounding boxes, segmentation masks and keypoints.
  10. """
  11. # sphinx_gallery_thumbnail_path = "../../gallery/assets/visualization_utils_thumbnail2.png"
  12. import torch
  13. import numpy as np
  14. import matplotlib.pyplot as plt
  15. import torchvision.transforms.functional as F
  16. plt.rcParams["savefig.bbox"] = 'tight'
  17. def show(imgs):
  18. if not isinstance(imgs, list):
  19. imgs = [imgs]
  20. fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
  21. for i, img in enumerate(imgs):
  22. img = img.detach()
  23. img = F.to_pil_image(img)
  24. axs[0, i].imshow(np.asarray(img))
  25. axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
  26. # %%
  27. # Visualizing a grid of images
  28. # ----------------------------
  29. # The :func:`~torchvision.utils.make_grid` function can be used to create a
  30. # tensor that represents multiple images in a grid. This util requires a single
  31. # image of dtype ``uint8`` as input.
  32. from torchvision.utils import make_grid
  33. from torchvision.io import read_image
  34. from pathlib import Path
  35. dog1_int = read_image(str(Path('../assets') / 'dog1.jpg'))
  36. dog2_int = read_image(str(Path('../assets') / 'dog2.jpg'))
  37. dog_list = [dog1_int, dog2_int]
  38. grid = make_grid(dog_list)
  39. show(grid)
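# %%
# ``make_grid`` also exposes layout options such as ``nrow`` and ``padding``.
# As a minimal sketch (the parameter values here are arbitrary), we can stack
# the two images vertically instead of side by side:

grid_vertical = make_grid(dog_list, nrow=1, padding=10)
show(grid_vertical)
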
# %%
# Visualizing bounding boxes
# --------------------------
# We can use :func:`~torchvision.utils.draw_bounding_boxes` to draw boxes on an
# image. We can set the colors, labels, width, as well as the font and font size.
# The boxes are in ``(xmin, ymin, xmax, ymax)`` format.
from torchvision.utils import draw_bounding_boxes

boxes = torch.tensor([[50, 50, 100, 200], [210, 150, 350, 430]], dtype=torch.float)
colors = ["blue", "yellow"]
result = draw_bounding_boxes(dog1_int, boxes, colors=colors, width=5)
show(result)

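# %%
# The snippet above does not use the ``labels``, ``font`` or ``font_size``
# parameters. Here is a minimal sketch that also attaches (arbitrary) labels to
# the boxes; passing ``font`` (a path to a TTF font available on your system)
# together with ``font_size`` additionally controls how the labels are rendered.

labeled_result = draw_bounding_boxes(dog1_int, boxes, labels=["box 1", "box 2"], colors=colors, width=5)
show(labeled_result)
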
# %%
# Naturally, we can also plot bounding boxes produced by torchvision detection
# models. Here is a demo with a Faster R-CNN model loaded from
# :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`. For more
# details on the output of such models, you may refer to
# :ref:`instance_seg_output`.
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights

weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()

images = [transforms(d) for d in dog_list]

model = fasterrcnn_resnet50_fpn(weights=weights, progress=False)
model = model.eval()

outputs = model(images)
print(outputs)

# %%
# Let's plot the boxes detected by our model. We will only plot the boxes with a
# score greater than a given threshold.
score_threshold = .8
dogs_with_boxes = [
    draw_bounding_boxes(dog_int, boxes=output['boxes'][output['scores'] > score_threshold], width=4)
    for dog_int, output in zip(dog_list, outputs)
]
show(dogs_with_boxes)

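# %%
# The detection output also contains per-instance labels. As a small sketch, we
# can map them to human-readable category names through
# ``weights.meta["categories"]`` and pass them to the ``labels`` parameter:

dogs_with_labeled_boxes = [
    draw_bounding_boxes(
        dog_int,
        boxes=output['boxes'][output['scores'] > score_threshold],
        labels=[weights.meta["categories"][label]
                for label, score in zip(output['labels'], output['scores'])
                if score > score_threshold],
        width=4,
    )
    for dog_int, output in zip(dog_list, outputs)
]
show(dogs_with_labeled_boxes)
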
# %%
# Visualizing segmentation masks
# ------------------------------
# The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to
# draw segmentation masks on images. Semantic segmentation and instance
# segmentation models have different outputs, so we will treat each
# independently.
#
# .. _semantic_seg_output:
#
# Semantic segmentation models
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We will see how to use it with torchvision's FCN ResNet-50, loaded with
# :func:`~torchvision.models.segmentation.fcn_resnet50`. Let's start by looking
# at the output of the model.
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights

weights = FCN_ResNet50_Weights.DEFAULT
transforms = weights.transforms(resize_size=None)

model = fcn_resnet50(weights=weights, progress=False)
model = model.eval()

batch = torch.stack([transforms(d) for d in dog_list])
output = model(batch)['out']
print(output.shape, output.min().item(), output.max().item())

# %%
# As we can see above, the output of the segmentation model is a tensor of shape
# ``(batch_size, num_classes, H, W)``. Each value is a non-normalized score, and
# we can normalize them into ``[0, 1]`` by using a softmax. After the softmax,
# we can interpret each value as a probability indicating how likely a given
# pixel is to belong to a given class.
#
# Let's plot the masks that have been detected for the dog class and for the
# boat class:
sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}

normalized_masks = torch.nn.functional.softmax(output, dim=1)

dog_and_boat_masks = [
    normalized_masks[img_idx, sem_class_to_idx[cls]]
    for img_idx in range(len(dog_list))
    for cls in ('dog', 'boat')
]

show(dog_and_boat_masks)

# %%
# As expected, the model is confident about the dog class, but not so much for
# the boat class.
#
# The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to
# plot those masks on top of the original image. This function expects the
# masks to be boolean masks, but our masks above contain probabilities in ``[0,
# 1]``. To get boolean masks, we can do the following:
class_dim = 1
boolean_dog_masks = (normalized_masks.argmax(class_dim) == sem_class_to_idx['dog'])
print(f"shape = {boolean_dog_masks.shape}, dtype = {boolean_dog_masks.dtype}")
show([m.float() for m in boolean_dog_masks])

# %%
# The line above where we define ``boolean_dog_masks`` is a bit cryptic, but you
# can read it as the following query: "For which pixels is 'dog' the most likely
# class?"
#
# .. note::
#     While we're using the ``normalized_masks`` here, we would have
#     gotten the same result by using the non-normalized scores of the model
#     directly (as the softmax operation preserves the order).
#
# Now that we have boolean masks, we can use them with
# :func:`~torchvision.utils.draw_segmentation_masks` to plot them on top of the
# original images:
from torchvision.utils import draw_segmentation_masks

dogs_with_masks = [
    draw_segmentation_masks(img, masks=mask, alpha=0.7)
    for img, mask in zip(dog_list, boolean_dog_masks)
]
show(dogs_with_masks)

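# %%
# The color of the overlaid masks can also be set explicitly through the
# ``colors`` parameter of :func:`~torchvision.utils.draw_segmentation_masks`.
# A minimal sketch, with an arbitrary color choice:

dogs_with_green_masks = [
    draw_segmentation_masks(img, masks=mask, alpha=0.7, colors="green")
    for img, mask in zip(dog_list, boolean_dog_masks)
]
show(dogs_with_green_masks)
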
# %%
# We can plot more than one mask per image! Remember that the model returned as
# many masks as there are classes. Let's ask the same query as above, but this
# time for *all* classes, not just the dog class: "For each pixel and each class
# C, is class C the most likely class?"
#
# This one is a bit more involved, so we'll first show how to do it with a
# single image, and then we'll generalize to the batch.
num_classes = normalized_masks.shape[1]
dog1_masks = normalized_masks[0]
class_dim = 0
dog1_all_classes_masks = dog1_masks.argmax(class_dim) == torch.arange(num_classes)[:, None, None]

print(f"dog1_masks shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}")
print(f"dog1_all_classes_masks = {dog1_all_classes_masks.shape}, dtype = {dog1_all_classes_masks.dtype}")

dog_with_all_masks = draw_segmentation_masks(dog1_int, masks=dog1_all_classes_masks, alpha=.6)
show(dog_with_all_masks)

# %%
# We can see in the image above that only 2 masks were drawn: the mask for the
# background and the mask for the dog. This is because the model thinks that
# only these 2 classes are the most likely ones across all the pixels. If the
# model had detected another class as the most likely among other pixels, we
# would have seen its mask above.
#
# Removing the background mask is as simple as passing
# ``masks=dog1_all_classes_masks[1:]``, because the background class is the
# class with index 0.
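# For instance (a minimal sketch reusing the masks computed above):

show(draw_segmentation_masks(dog1_int, masks=dog1_all_classes_masks[1:], alpha=.6))
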
# %%
# Let's now do the same but for an entire batch of images. The code is similar
# but involves a bit more juggling with the dimensions.
class_dim = 1
all_classes_masks = normalized_masks.argmax(class_dim) == torch.arange(num_classes)[:, None, None, None]
print(f"shape = {all_classes_masks.shape}, dtype = {all_classes_masks.dtype}")
# The first dimension is the classes now, so we need to swap it
all_classes_masks = all_classes_masks.swapaxes(0, 1)

dogs_with_masks = [
    draw_segmentation_masks(img, masks=mask, alpha=.6)
    for img, mask in zip(dog_list, all_classes_masks)
]
show(dogs_with_masks)

# %%
# .. _instance_seg_output:
#
# Instance segmentation models
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Instance segmentation models have a significantly different output from the
# semantic segmentation models. We will see here how to plot the masks for such
# models. Let's start by analyzing the output of a Mask R-CNN model. Note that
# these models don't require the images to be normalized, so we don't need to
# use the normalized batch.
#
# .. note::
#
#     We will here describe the output of a Mask R-CNN model. The models in
#     :ref:`object_det_inst_seg_pers_keypoint_det` all have a similar output
#     format, but some of them may have extra info like keypoints for
#     :func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn`, and some
#     of them may not have masks, like
#     :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`.
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights

weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()

images = [transforms(d) for d in dog_list]

model = maskrcnn_resnet50_fpn(weights=weights, progress=False)
model = model.eval()

output = model(images)
print(output)

# %%
# Let's break this down. For each image in the batch, the model outputs some
# detections (or instances). The number of detections varies for each input
# image. Each instance is described by its bounding box, its label, its score
# and its mask.
#
# The way the output is organized is as follows: the output is a list of length
# ``batch_size``. Each entry in the list corresponds to an input image, and it
# is a dict with keys 'boxes', 'labels', 'scores', and 'masks'. Each value
# associated to those keys has ``num_instances`` elements in it. In our case
# above there are 3 instances detected in the first image, and 2 instances in
# the second one.
#
# The boxes can be plotted with :func:`~torchvision.utils.draw_bounding_boxes`
# as above, but here we're more interested in the masks. These masks are quite
# different from the masks that we saw above for the semantic segmentation
# models.
dog1_output = output[0]
dog1_masks = dog1_output['masks']
print(f"shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}, "
      f"min = {dog1_masks.min()}, max = {dog1_masks.max()}")

# %%
# Here the masks correspond to probabilities indicating, for each pixel, how
# likely it is to belong to the predicted label of that instance. Those
# predicted labels correspond to the 'labels' element in the same output dict.
# Let's see which labels were predicted for the instances of the first image.
print("For the first dog, the following instances were detected:")
print([weights.meta["categories"][label] for label in dog1_output['labels']])

# %%
# Interestingly, the model detects two persons in the image. Let's go ahead and
# plot those masks. Since :func:`~torchvision.utils.draw_segmentation_masks`
# expects boolean masks, we need to convert those probabilities into boolean
# values. Remember that the semantics of those masks is "How likely is this
# pixel to belong to the predicted class?". As a result, a natural way of
# converting those masks into boolean values is to threshold them with the 0.5
# probability (one could also choose a different threshold).
proba_threshold = 0.5
dog1_bool_masks = dog1_output['masks'] > proba_threshold
print(f"shape = {dog1_bool_masks.shape}, dtype = {dog1_bool_masks.dtype}")

# There's an extra dimension (1) to the masks. We need to remove it
dog1_bool_masks = dog1_bool_masks.squeeze(1)

show(draw_segmentation_masks(dog1_int, dog1_bool_masks, alpha=0.9))

# %%
# The model seems to have properly detected the dog, but it also confused trees
# with people. Looking more closely at the scores will help us plot more
# relevant masks:
print(dog1_output['scores'])

# %%
# Clearly the model is more confident about the dog detection than it is about
# the people detections. That's good news. When plotting the masks, we can ask
# for only those that have a good score. Let's use a score threshold of .75
# here, and also plot the masks of the second dog.
score_threshold = .75

boolean_masks = [
    out['masks'][out['scores'] > score_threshold] > proba_threshold
    for out in output
]

dogs_with_masks = [
    draw_segmentation_masks(img, mask.squeeze(1))
    for img, mask in zip(dog_list, boolean_masks)
]
show(dogs_with_masks)

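# %%
# Since these utilities return regular ``uint8`` image tensors, they can also be
# chained. As a minimal sketch (reusing the thresholds from above), we can
# overlay both the kept masks and the corresponding boxes on the first image:

dog1_kept = output[0]['scores'] > score_threshold
dog1_combined = draw_segmentation_masks(dog1_int, boolean_masks[0].squeeze(1), alpha=0.7)
dog1_combined = draw_bounding_boxes(dog1_combined, output[0]['boxes'][dog1_kept], width=4)
show(dog1_combined)
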
# %%
# The two 'people' masks in the first image were not selected because they have
# a lower score than the score threshold. Similarly, in the second image, the
# instance with class 15 (which corresponds to 'bench') was not selected.

# %%
# .. _keypoint_output:
#
# Visualizing keypoints
# ---------------------
# The :func:`~torchvision.utils.draw_keypoints` function can be used to
# draw keypoints on images. We will see how to use it with
# torchvision's KeypointRCNN loaded with :func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn`.
# We will first have a look at the output of the model.
#
from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
from torchvision.io import read_image

person_int = read_image(str(Path("../assets") / "person1.jpg"))

weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()

person_float = transforms(person_int)

model = keypointrcnn_resnet50_fpn(weights=weights, progress=False)
model = model.eval()

outputs = model([person_float])
print(outputs)

# %%
# As we see, the output contains a list of dictionaries.
# The output list is of length ``batch_size``.
# We currently have just a single image, so the length of the list is 1.
# Each entry in the list corresponds to an input image,
# and it is a dict with keys ``boxes``, ``labels``, ``scores``, ``keypoints`` and ``keypoint_scores``.
# Each value associated to those keys has ``num_instances`` elements in it.
# In our case above there are 2 instances detected in the image.
kpts = outputs[0]['keypoints']
scores = outputs[0]['scores']

print(kpts)
print(scores)

# %%
# The KeypointRCNN model detects two instances in the image.
# If you plot the boxes with :func:`~torchvision.utils.draw_bounding_boxes`,
# you will recognize them as the person and the surfboard.
# If we look at the scores, we will realize that the model is much more confident about the person than about the surfboard.
# We could now set a confidence threshold and only plot the instances we are confident enough about.
# Let us set a threshold of 0.75 and keep the keypoints corresponding to the person.
detect_threshold = 0.75
idx = torch.where(scores > detect_threshold)
keypoints = kpts[idx]

print(keypoints)

# %%
# Great, now we have the keypoints corresponding to the person.
# Each keypoint is represented by x, y coordinates and the visibility.
# We can now use the :func:`~torchvision.utils.draw_keypoints` function to draw keypoints.
# Note that the utility expects uint8 images.
from torchvision.utils import draw_keypoints

res = draw_keypoints(person_int, keypoints, colors="blue", radius=3)
show(res)

# %%
# As we see, the keypoints appear as colored circles over the image.
# The COCO keypoints for a person are ordered and represent the following list.
coco_keypoints = [
    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "left_hip", "right_hip",
    "left_knee", "right_knee", "left_ankle", "right_ankle",
]

# %%
# What if we are interested in joining the keypoints?
# This is especially useful for applications such as pose detection or action recognition.
# We can join the keypoints easily using the ``connectivity`` parameter.
# A close observation would reveal that we need to join the points in the
# following order to construct the human skeleton.
#
# nose -> left_eye -> left_ear. (0, 1), (1, 3)
#
# nose -> right_eye -> right_ear. (0, 2), (2, 4)
#
# nose -> left_shoulder -> left_elbow -> left_wrist. (0, 5), (5, 7), (7, 9)
#
# nose -> right_shoulder -> right_elbow -> right_wrist. (0, 6), (6, 8), (8, 10)
#
# left_shoulder -> left_hip -> left_knee -> left_ankle. (5, 11), (11, 13), (13, 15)
#
# right_shoulder -> right_hip -> right_knee -> right_ankle. (6, 12), (12, 14), (14, 16)
#
# We will create a list containing these keypoint ids to be connected.
connect_skeleton = [
    (0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (0, 6), (5, 7), (6, 8),
    (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15), (14, 16)
]

# %%
# We pass the above list to the ``connectivity`` parameter to connect the keypoints.
#
res = draw_keypoints(person_int, keypoints, connectivity=connect_skeleton, colors="blue", radius=4, width=3)
show(res)