plotting.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import contextlib
  3. import math
  4. import warnings
  5. from pathlib import Path
  6. import cv2
  7. import matplotlib.pyplot as plt
  8. import numpy as np
  9. import torch
  10. from PIL import Image, ImageDraw, ImageFont
  11. from PIL import __version__ as pil_version
  12. from ultralytics.utils import LOGGER, TryExcept, plt_settings, threaded
  13. from .checks import check_font, check_version, is_ascii
  14. from .files import increment_path
  15. from .ops import clip_boxes, scale_image, xywh2xyxy, xyxy2xywh
  16. class Colors:
  17. """
  18. Ultralytics default color palette https://ultralytics.com/.
  19. This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
  20. RGB values.
  21. Attributes:
  22. palette (list of tuple): List of RGB color values.
  23. n (int): The number of colors in the palette.
  24. pose_palette (np.array): A specific color palette array with dtype np.uint8.
  25. """
  26. def __init__(self):
  27. """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
  28. hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
  29. '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
  30. self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
  31. self.n = len(self.palette)
  32. self.pose_palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], [230, 230, 0], [255, 153, 255],
  33. [153, 204, 255], [255, 102, 255], [255, 51, 255], [102, 178, 255], [51, 153, 255],
  34. [255, 153, 153], [255, 102, 102], [255, 51, 51], [153, 255, 153], [102, 255, 102],
  35. [51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0], [255, 255, 255]],
  36. dtype=np.uint8)
  37. def __call__(self, i, bgr=False):
  38. """Converts hex color codes to RGB values."""
  39. c = self.palette[int(i) % self.n]
  40. return (c[2], c[1], c[0]) if bgr else c
  41. @staticmethod
  42. def hex2rgb(h):
  43. """Converts hex color codes to RGB values (i.e. default PIL order)."""
  44. return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
  45. colors = Colors() # create instance for 'from utils.plots import colors'
  46. class Annotator:
  47. """
  48. Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
  49. Attributes:
  50. im (Image.Image or numpy array): The image to annotate.
  51. pil (bool): Whether to use PIL or cv2 for drawing annotations.
  52. font (ImageFont.truetype or ImageFont.load_default): Font used for text annotations.
  53. lw (float): Line width for drawing.
  54. skeleton (List[List[int]]): Skeleton structure for keypoints.
  55. limb_color (List[int]): Color palette for limbs.
  56. kpt_color (List[int]): Color palette for keypoints.
  57. """
  58. def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
  59. """Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
  60. assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
  61. non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
  62. self.pil = pil or non_ascii
  63. if self.pil: # use PIL
  64. self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
  65. self.draw = ImageDraw.Draw(self.im)
  66. try:
  67. font = check_font('Arial.Unicode.ttf' if non_ascii else font)
  68. size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
  69. self.font = ImageFont.truetype(str(font), size)
  70. except Exception:
  71. self.font = ImageFont.load_default()
  72. # Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string)
  73. if check_version(pil_version, '9.2.0'):
  74. self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height
  75. else: # use cv2
  76. self.im = im
  77. self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
  78. # Pose
  79. self.skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], [7, 9],
  80. [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]
  81. self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
  82. self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
  83. def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
  84. """Add one xyxy box to image with label."""
  85. if isinstance(box, torch.Tensor):
  86. box = box.tolist()
  87. if self.pil or not is_ascii(label):
  88. self.draw.rectangle(box, width=self.lw, outline=color) # box
  89. if label:
  90. w, h = self.font.getsize(label) # text width, height
  91. outside = box[1] - h >= 0 # label fits outside box
  92. self.draw.rectangle(
  93. (box[0], box[1] - h if outside else box[1], box[0] + w + 1,
  94. box[1] + 1 if outside else box[1] + h + 1),
  95. fill=color,
  96. )
  97. # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
  98. self.draw.text((box[0], box[1] - h if outside else box[1]), label, fill=txt_color, font=self.font)
  99. else: # cv2
  100. p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
  101. cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
  102. if label:
  103. tf = max(self.lw - 1, 1) # font thickness
  104. w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
  105. outside = p1[1] - h >= 3
  106. p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
  107. cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
  108. cv2.putText(self.im,
  109. label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
  110. 0,
  111. self.lw / 3,
  112. txt_color,
  113. thickness=tf,
  114. lineType=cv2.LINE_AA)
  115. def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
  116. """
  117. Plot masks on image.
  118. Args:
  119. masks (tensor): Predicted masks on cuda, shape: [n, h, w]
  120. colors (List[List[Int]]): Colors for predicted masks, [[r, g, b] * n]
  121. im_gpu (tensor): Image is in cuda, shape: [3, h, w], range: [0, 1]
  122. alpha (float): Mask transparency: 0.0 fully transparent, 1.0 opaque
  123. retina_masks (bool): Whether to use high resolution masks or not. Defaults to False.
  124. """
  125. if self.pil:
  126. # Convert to numpy first
  127. self.im = np.asarray(self.im).copy()
  128. if len(masks) == 0:
  129. self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
  130. if im_gpu.device != masks.device:
  131. im_gpu = im_gpu.to(masks.device)
  132. colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)
  133. colors = colors[:, None, None] # shape(n,1,1,3)
  134. masks = masks.unsqueeze(3) # shape(n,h,w,1)
  135. masks_color = masks * (colors * alpha) # shape(n,h,w,3)
  136. inv_alph_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
  137. mcs = masks_color.max(dim=0).values # shape(n,h,w,3)
  138. im_gpu = im_gpu.flip(dims=[0]) # flip channel
  139. im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
  140. im_gpu = im_gpu * inv_alph_masks[-1] + mcs
  141. im_mask = (im_gpu * 255)
  142. im_mask_np = im_mask.byte().cpu().numpy()
  143. self.im[:] = im_mask_np if retina_masks else scale_image(im_mask_np, self.im.shape)
  144. if self.pil:
  145. # Convert im back to PIL and update draw
  146. self.fromarray(self.im)
  147. def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True):
  148. """
  149. Plot keypoints on the image.
  150. Args:
  151. kpts (tensor): Predicted keypoints with shape [17, 3]. Each keypoint has (x, y, confidence).
  152. shape (tuple): Image shape as a tuple (h, w), where h is the height and w is the width.
  153. radius (int, optional): Radius of the drawn keypoints. Default is 5.
  154. kpt_line (bool, optional): If True, the function will draw lines connecting keypoints
  155. for human pose. Default is True.
  156. Note: `kpt_line=True` currently only supports human pose plotting.
  157. """
  158. if self.pil:
  159. # Convert to numpy first
  160. self.im = np.asarray(self.im).copy()
  161. nkpt, ndim = kpts.shape
  162. is_pose = nkpt == 17 and ndim == 3
  163. kpt_line &= is_pose # `kpt_line=True` for now only supports human pose plotting
  164. for i, k in enumerate(kpts):
  165. color_k = [int(x) for x in self.kpt_color[i]] if is_pose else colors(i)
  166. x_coord, y_coord = k[0], k[1]
  167. if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:
  168. if len(k) == 3:
  169. conf = k[2]
  170. if conf < 0.5:
  171. continue
  172. cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA)
  173. if kpt_line:
  174. ndim = kpts.shape[-1]
  175. for i, sk in enumerate(self.skeleton):
  176. pos1 = (int(kpts[(sk[0] - 1), 0]), int(kpts[(sk[0] - 1), 1]))
  177. pos2 = (int(kpts[(sk[1] - 1), 0]), int(kpts[(sk[1] - 1), 1]))
  178. if ndim == 3:
  179. conf1 = kpts[(sk[0] - 1), 2]
  180. conf2 = kpts[(sk[1] - 1), 2]
  181. if conf1 < 0.5 or conf2 < 0.5:
  182. continue
  183. if pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0:
  184. continue
  185. if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0:
  186. continue
  187. cv2.line(self.im, pos1, pos2, [int(x) for x in self.limb_color[i]], thickness=2, lineType=cv2.LINE_AA)
  188. if self.pil:
  189. # Convert im back to PIL and update draw
  190. self.fromarray(self.im)
  191. def rectangle(self, xy, fill=None, outline=None, width=1):
  192. """Add rectangle to image (PIL-only)."""
  193. self.draw.rectangle(xy, fill, outline, width)
  194. def text(self, xy, text, txt_color=(255, 255, 255), anchor='top', box_style=False):
  195. """Adds text to an image using PIL or cv2."""
  196. if anchor == 'bottom': # start y from font bottom
  197. w, h = self.font.getsize(text) # text width, height
  198. xy[1] += 1 - h
  199. if self.pil:
  200. if box_style:
  201. w, h = self.font.getsize(text)
  202. self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
  203. # Using `txt_color` for background and draw fg with white color
  204. txt_color = (255, 255, 255)
  205. if '\n' in text:
  206. lines = text.split('\n')
  207. _, h = self.font.getsize(text)
  208. for line in lines:
  209. self.draw.text(xy, line, fill=txt_color, font=self.font)
  210. xy[1] += h
  211. else:
  212. self.draw.text(xy, text, fill=txt_color, font=self.font)
  213. else:
  214. if box_style:
  215. tf = max(self.lw - 1, 1) # font thickness
  216. w, h = cv2.getTextSize(text, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
  217. outside = xy[1] - h >= 3
  218. p2 = xy[0] + w, xy[1] - h - 3 if outside else xy[1] + h + 3
  219. cv2.rectangle(self.im, xy, p2, txt_color, -1, cv2.LINE_AA) # filled
  220. # Using `txt_color` for background and draw fg with white color
  221. txt_color = (255, 255, 255)
  222. tf = max(self.lw - 1, 1) # font thickness
  223. cv2.putText(self.im, text, xy, 0, self.lw / 3, txt_color, thickness=tf, lineType=cv2.LINE_AA)
  224. def fromarray(self, im):
  225. """Update self.im from a numpy array."""
  226. self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
  227. self.draw = ImageDraw.Draw(self.im)
  228. def result(self):
  229. """Return annotated image as array."""
  230. return np.asarray(self.im)
  231. @TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
  232. @plt_settings()
  233. def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
  234. """Plot training labels including class histograms and box statistics."""
  235. import pandas as pd
  236. import seaborn as sn
  237. # Filter matplotlib>=3.7.2 warning
  238. warnings.filterwarnings('ignore', category=UserWarning, message='The figure layout has changed to tight')
  239. # Plot dataset labels
  240. LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
  241. nc = int(cls.max() + 1) # number of classes
  242. boxes = boxes[:1000000] # limit to 1M boxes
  243. x = pd.DataFrame(boxes, columns=['x', 'y', 'width', 'height'])
  244. # Seaborn correlogram
  245. sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
  246. plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
  247. plt.close()
  248. # Matplotlib labels
  249. ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
  250. y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
  251. with contextlib.suppress(Exception): # color histogram bars by class
  252. [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195
  253. ax[0].set_ylabel('instances')
  254. if 0 < len(names) < 30:
  255. ax[0].set_xticks(range(len(names)))
  256. ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
  257. else:
  258. ax[0].set_xlabel('classes')
  259. sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
  260. sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
  261. # Rectangles
  262. boxes[:, 0:2] = 0.5 # center
  263. boxes = xywh2xyxy(boxes) * 1000
  264. img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
  265. for cls, box in zip(cls[:500], boxes[:500]):
  266. ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
  267. ax[1].imshow(img)
  268. ax[1].axis('off')
  269. for a in [0, 1, 2, 3]:
  270. for s in ['top', 'right', 'left', 'bottom']:
  271. ax[a].spines[s].set_visible(False)
  272. fname = save_dir / 'labels.jpg'
  273. plt.savefig(fname, dpi=200)
  274. plt.close()
  275. if on_plot:
  276. on_plot(fname)
  277. def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
  278. """Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
  279. This function takes a bounding box and an image, and then saves a cropped portion of the image according
  280. to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding
  281. adjustments to the bounding box.
  282. Args:
  283. xyxy (torch.Tensor or list): A tensor or list representing the bounding box in xyxy format.
  284. im (numpy.ndarray): The input image.
  285. file (Path, optional): The path where the cropped image will be saved. Defaults to 'im.jpg'.
  286. gain (float, optional): A multiplicative factor to increase the size of the bounding box. Defaults to 1.02.
  287. pad (int, optional): The number of pixels to add to the width and height of the bounding box. Defaults to 10.
  288. square (bool, optional): If True, the bounding box will be transformed into a square. Defaults to False.
  289. BGR (bool, optional): If True, the image will be saved in BGR format, otherwise in RGB. Defaults to False.
  290. save (bool, optional): If True, the cropped image will be saved to disk. Defaults to True.
  291. Returns:
  292. (numpy.ndarray): The cropped image.
  293. Example:
  294. ```python
  295. from ultralytics.utils.plotting import save_one_box
  296. xyxy = [50, 50, 150, 150]
  297. im = cv2.imread('image.jpg')
  298. cropped_im = save_one_box(xyxy, im, file='cropped.jpg', square=True)
  299. ```
  300. """
  301. if not isinstance(xyxy, torch.Tensor): # may be list
  302. xyxy = torch.stack(xyxy)
  303. b = xyxy2xywh(xyxy.view(-1, 4)) # boxes
  304. if square:
  305. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
  306. b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
  307. xyxy = xywh2xyxy(b).long()
  308. clip_boxes(xyxy, im.shape)
  309. crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
  310. if save:
  311. file.parent.mkdir(parents=True, exist_ok=True) # make directory
  312. f = str(increment_path(file).with_suffix('.jpg'))
  313. # cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
  314. Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
  315. return crop
  316. @threaded
  317. def plot_images(images,
  318. batch_idx,
  319. cls,
  320. bboxes=np.zeros(0, dtype=np.float32),
  321. masks=np.zeros(0, dtype=np.uint8),
  322. kpts=np.zeros((0, 51), dtype=np.float32),
  323. paths=None,
  324. fname='images.jpg',
  325. names=None,
  326. on_plot=None):
  327. """Plot image grid with labels."""
  328. if isinstance(images, torch.Tensor):
  329. images = images.cpu().float().numpy()
  330. if isinstance(cls, torch.Tensor):
  331. cls = cls.cpu().numpy()
  332. if isinstance(bboxes, torch.Tensor):
  333. bboxes = bboxes.cpu().numpy()
  334. if isinstance(masks, torch.Tensor):
  335. masks = masks.cpu().numpy().astype(int)
  336. if isinstance(kpts, torch.Tensor):
  337. kpts = kpts.cpu().numpy()
  338. if isinstance(batch_idx, torch.Tensor):
  339. batch_idx = batch_idx.cpu().numpy()
  340. max_size = 1920 # max image size
  341. max_subplots = 16 # max image subplots, i.e. 4x4
  342. bs, _, h, w = images.shape # batch size, _, height, width
  343. bs = min(bs, max_subplots) # limit plot images
  344. ns = np.ceil(bs ** 0.5) # number of subplots (square)
  345. if np.max(images[0]) <= 1:
  346. images *= 255 # de-normalise (optional)
  347. # Build Image
  348. mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
  349. for i, im in enumerate(images):
  350. if i == max_subplots: # if last batch has fewer images than we expect
  351. break
  352. x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
  353. im = im.transpose(1, 2, 0)
  354. mosaic[y:y + h, x:x + w, :] = im
  355. # Resize (optional)
  356. scale = max_size / ns / max(h, w)
  357. if scale < 1:
  358. h = math.ceil(scale * h)
  359. w = math.ceil(scale * w)
  360. mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
  361. # Annotate
  362. fs = int((h + w) * ns * 0.01) # font size
  363. annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
  364. for i in range(i + 1):
  365. x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
  366. annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
  367. if paths:
  368. annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
  369. if len(cls) > 0:
  370. idx = batch_idx == i
  371. classes = cls[idx].astype('int')
  372. if len(bboxes):
  373. boxes = xywh2xyxy(bboxes[idx, :4]).T
  374. labels = bboxes.shape[1] == 4 # labels if no conf column
  375. conf = None if labels else bboxes[idx, 4] # check for confidence presence (label vs pred)
  376. if boxes.shape[1]:
  377. if boxes.max() <= 1.01: # if normalized with tolerance 0.01
  378. boxes[[0, 2]] *= w # scale to pixels
  379. boxes[[1, 3]] *= h
  380. elif scale < 1: # absolute coords need scale if image scales
  381. boxes *= scale
  382. boxes[[0, 2]] += x
  383. boxes[[1, 3]] += y
  384. for j, box in enumerate(boxes.T.tolist()):
  385. c = classes[j]
  386. color = colors(c)
  387. c = names.get(c, c) if names else c
  388. if labels or conf[j] > 0.25: # 0.25 conf thresh
  389. label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
  390. annotator.box_label(box, label, color=color)
  391. elif len(classes):
  392. for c in classes:
  393. color = colors(c)
  394. c = names.get(c, c) if names else c
  395. annotator.text((x, y), f'{c}', txt_color=color, box_style=True)
  396. # Plot keypoints
  397. if len(kpts):
  398. kpts_ = kpts[idx].copy()
  399. if len(kpts_):
  400. if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01: # if normalized with tolerance .01
  401. kpts_[..., 0] *= w # scale to pixels
  402. kpts_[..., 1] *= h
  403. elif scale < 1: # absolute coords need scale if image scales
  404. kpts_ *= scale
  405. kpts_[..., 0] += x
  406. kpts_[..., 1] += y
  407. for j in range(len(kpts_)):
  408. if labels or conf[j] > 0.25: # 0.25 conf thresh
  409. annotator.kpts(kpts_[j])
  410. # Plot masks
  411. if len(masks):
  412. if idx.shape[0] == masks.shape[0]: # overlap_masks=False
  413. image_masks = masks[idx]
  414. else: # overlap_masks=True
  415. image_masks = masks[[i]] # (1, 640, 640)
  416. nl = idx.sum()
  417. index = np.arange(nl).reshape((nl, 1, 1)) + 1
  418. image_masks = np.repeat(image_masks, nl, axis=0)
  419. image_masks = np.where(image_masks == index, 1.0, 0.0)
  420. im = np.asarray(annotator.im).copy()
  421. for j, box in enumerate(boxes.T.tolist()):
  422. if labels or conf[j] > 0.25: # 0.25 conf thresh
  423. color = colors(classes[j])
  424. mh, mw = image_masks[j].shape
  425. if mh != h or mw != w:
  426. mask = image_masks[j].astype(np.uint8)
  427. mask = cv2.resize(mask, (w, h))
  428. mask = mask.astype(bool)
  429. else:
  430. mask = image_masks[j].astype(bool)
  431. with contextlib.suppress(Exception):
  432. im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
  433. annotator.fromarray(im)
  434. annotator.im.save(fname) # save
  435. if on_plot:
  436. on_plot(fname)
  437. @plt_settings()
  438. def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None):
  439. """
  440. Plot training results from results CSV file.
  441. Example:
  442. ```python
  443. from ultralytics.utils.plotting import plot_results
  444. plot_results('path/to/results.csv')
  445. ```
  446. """
  447. import pandas as pd
  448. from scipy.ndimage import gaussian_filter1d
  449. save_dir = Path(file).parent if file else Path(dir)
  450. if classify:
  451. fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
  452. index = [1, 4, 2, 3]
  453. elif segment:
  454. fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
  455. index = [1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]
  456. elif pose:
  457. fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)
  458. index = [1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 18, 8, 9, 12, 13]
  459. else:
  460. fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
  461. index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]
  462. ax = ax.ravel()
  463. files = list(save_dir.glob('results*.csv'))
  464. assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
  465. for f in files:
  466. try:
  467. data = pd.read_csv(f)
  468. s = [x.strip() for x in data.columns]
  469. x = data.values[:, 0]
  470. for i, j in enumerate(index):
  471. y = data.values[:, j].astype('float')
  472. # y[y == 0] = np.nan # don't show zero values
  473. ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8) # actual results
  474. ax[i].plot(x, gaussian_filter1d(y, sigma=3), ':', label='smooth', linewidth=2) # smoothing line
  475. ax[i].set_title(s[j], fontsize=12)
  476. # if j in [8, 9, 10]: # share train and val loss y axes
  477. # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
  478. except Exception as e:
  479. LOGGER.warning(f'WARNING: Plotting error for {f}: {e}')
  480. ax[1].legend()
  481. fname = save_dir / 'results.png'
  482. fig.savefig(fname, dpi=200)
  483. plt.close()
  484. if on_plot:
  485. on_plot(fname)
  486. def output_to_target(output, max_det=300):
  487. """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
  488. targets = []
  489. for i, o in enumerate(output):
  490. box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
  491. j = torch.full((conf.shape[0], 1), i)
  492. targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))
  493. targets = torch.cat(targets, 0).numpy()
  494. return targets[:, 0], targets[:, 1], targets[:, 2:]
  495. def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
  496. """
  497. Visualize feature maps of a given model module during inference.
  498. Args:
  499. x (torch.Tensor): Features to be visualized.
  500. module_type (str): Module type.
  501. stage (int): Module stage within the model.
  502. n (int, optional): Maximum number of feature maps to plot. Defaults to 32.
  503. save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp').
  504. """
  505. for m in ['Detect', 'Pose', 'Segment']:
  506. if m in module_type:
  507. return
  508. batch, channels, height, width = x.shape # batch, channels, height, width
  509. if height > 1 and width > 1:
  510. f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
  511. blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
  512. n = min(n, channels) # number of plots
  513. fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
  514. ax = ax.ravel()
  515. plt.subplots_adjust(wspace=0.05, hspace=0.05)
  516. for i in range(n):
  517. ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
  518. ax[i].axis('off')
  519. LOGGER.info(f'Saving {f}... ({n}/{channels})')
  520. plt.savefig(f, dpi=300, bbox_inches='tight')
  521. plt.close()
  522. np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save