helpers.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. import matplotlib.pyplot as plt
  2. import torch
  3. from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
  4. from torchvision import tv_tensors
  5. from torchvision.transforms.v2 import functional as F
  6. def plot(imgs, row_title=None, **imshow_kwargs):
  7. if not isinstance(imgs[0], list):
  8. # Make a 2d grid even if there's just 1 row
  9. imgs = [imgs]
  10. num_rows = len(imgs)
  11. num_cols = len(imgs[0])
  12. _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
  13. for row_idx, row in enumerate(imgs):
  14. for col_idx, img in enumerate(row):
  15. boxes = None
  16. masks = None
  17. if isinstance(img, tuple):
  18. img, target = img
  19. if isinstance(target, dict):
  20. boxes = target.get("boxes")
  21. masks = target.get("masks")
  22. elif isinstance(target, tv_tensors.BoundingBoxes):
  23. boxes = target
  24. else:
  25. raise ValueError(f"Unexpected target type: {type(target)}")
  26. img = F.to_image(img)
  27. if img.dtype.is_floating_point and img.min() < 0:
  28. # Poor man's re-normalization for the colors to be OK-ish. This
  29. # is useful for images coming out of Normalize()
  30. img -= img.min()
  31. img /= img.max()
  32. img = F.to_dtype(img, torch.uint8, scale=True)
  33. if boxes is not None:
  34. img = draw_bounding_boxes(img, boxes, colors="yellow", width=3)
  35. if masks is not None:
  36. img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)
  37. ax = axs[row_idx, col_idx]
  38. ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
  39. ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
  40. if row_title is not None:
  41. for row_idx in range(num_rows):
  42. axs[row_idx, 0].set(ylabel=row_title[row_idx])
  43. plt.tight_layout()