1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950 |
- import matplotlib.pyplot as plt
- import torch
- from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
- from torchvision import tv_tensors
- from torchvision.transforms.v2 import functional as F
- def plot(imgs, row_title=None, **imshow_kwargs):
- if not isinstance(imgs[0], list):
- # Make a 2d grid even if there's just 1 row
- imgs = [imgs]
- num_rows = len(imgs)
- num_cols = len(imgs[0])
- _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
- for row_idx, row in enumerate(imgs):
- for col_idx, img in enumerate(row):
- boxes = None
- masks = None
- if isinstance(img, tuple):
- img, target = img
- if isinstance(target, dict):
- boxes = target.get("boxes")
- masks = target.get("masks")
- elif isinstance(target, tv_tensors.BoundingBoxes):
- boxes = target
- else:
- raise ValueError(f"Unexpected target type: {type(target)}")
- img = F.to_image(img)
- if img.dtype.is_floating_point and img.min() < 0:
- # Poor man's re-normalization for the colors to be OK-ish. This
- # is useful for images coming out of Normalize()
- img -= img.min()
- img /= img.max()
- img = F.to_dtype(img, torch.uint8, scale=True)
- if boxes is not None:
- img = draw_bounding_boxes(img, boxes, colors="yellow", width=3)
- if masks is not None:
- img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)
- ax = axs[row_idx, col_idx]
- ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
- ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
- if row_title is not None:
- for row_idx in range(num_rows):
- axs[row_idx, 0].set(ylabel=row_title[row_idx])
- plt.tight_layout()
|