plots.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. # Plotting utils
  2. import glob
  3. import math
  4. import os
  5. import random
  6. from copy import copy
  7. from pathlib import Path
  8. import cv2
  9. import matplotlib
  10. import matplotlib.pyplot as plt
  11. import numpy as np
  12. import pandas as pd
  13. import seaborn as sns
  14. import torch
  15. import yaml
  16. from PIL import Image, ImageDraw, ImageFont
  17. from utils.general import xywh2xyxy, xyxy2xywh
  18. from utils.metrics import fitness
  19. # Settings
  20. matplotlib.rc('font', **{'size': 11})
  21. matplotlib.use('Agg') # for writing to files only
  22. class Colors:
  23. # Ultralytics color palette https://ultralytics.com/
  24. def __init__(self):
  25. # hex = matplotlib.colors.TABLEAU_COLORS.values()
  26. hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
  27. '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
  28. self.palette = [self.hex2rgb('#' + c) for c in hex]
  29. self.n = len(self.palette)
  30. def __call__(self, i, bgr=False):
  31. c = self.palette[int(i) % self.n]
  32. return (c[2], c[1], c[0]) if bgr else c
  33. @staticmethod
  34. def hex2rgb(h): # rgb order (PIL)
  35. return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
  36. colors = Colors() # create instance for 'from utils.plots import colors'
  37. def hist2d(x, y, n=100):
  38. # 2d histogram used in labels.png and evolve.png
  39. xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
  40. hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
  41. xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
  42. yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
  43. return np.log(hist[xidx, yidx])
  44. def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
  45. from scipy.signal import butter, filtfilt
  46. # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
  47. def butter_lowpass(cutoff, fs, order):
  48. nyq = 0.5 * fs
  49. normal_cutoff = cutoff / nyq
  50. return butter(order, normal_cutoff, btype='low', analog=False)
  51. b, a = butter_lowpass(cutoff, fs, order=order)
  52. return filtfilt(b, a, data) # forward-backward filter
  53. def plot_one_box(x, im, color=(128, 128, 128), label=None, line_thickness=3):
  54. # Plots one bounding box on image 'im' using OpenCV
  55. assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
  56. tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1 # line/font thickness
  57. c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
  58. cv2.rectangle(im, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
  59. if label:
  60. tf = max(tl - 1, 1) # font thickness
  61. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  62. c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
  63. cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA) # filled
  64. cv2.putText(im, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
  65. def plot_one_box_PIL(box, im, color=(128, 128, 128), label=None, line_thickness=None):
  66. # Plots one bounding box on image 'im' using PIL
  67. im = Image.fromarray(im)
  68. draw = ImageDraw.Draw(im)
  69. line_thickness = line_thickness or max(int(min(im.size) / 200), 2)
  70. draw.rectangle(box, width=line_thickness, outline=color) # plot
  71. if label:
  72. font = ImageFont.truetype("Arial.ttf", size=max(round(max(im.size) / 40), 12))
  73. txt_width, txt_height = font.getsize(label)
  74. draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=color)
  75. draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font)
  76. return np.asarray(im)
  77. def plot_wh_methods(): # from utils.plots import *; plot_wh_methods()
  78. # Compares the two methods for width-height anchor multiplication
  79. # https://github.com/ultralytics/yolov3/issues/168
  80. x = np.arange(-4.0, 4.0, .1)
  81. ya = np.exp(x)
  82. yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
  83. fig = plt.figure(figsize=(6, 3), tight_layout=True)
  84. plt.plot(x, ya, '.-', label='YOLOv3')
  85. plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2')
  86. plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6')
  87. plt.xlim(left=-4, right=4)
  88. plt.ylim(bottom=0, top=6)
  89. plt.xlabel('input')
  90. plt.ylabel('output')
  91. plt.grid()
  92. plt.legend()
  93. fig.savefig('comparison.png', dpi=200)
  94. def output_to_target(output):
  95. # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
  96. targets = []
  97. for i, o in enumerate(output):
  98. for *box, conf, cls in o.cpu().numpy():
  99. targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
  100. return np.array(targets)
  101. def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
  102. # Plot image grid with labels
  103. if isinstance(images, torch.Tensor):
  104. images = images.cpu().float().numpy()
  105. if isinstance(targets, torch.Tensor):
  106. targets = targets.cpu().numpy()
  107. # un-normalise
  108. if np.max(images[0]) <= 1:
  109. images *= 255
  110. tl = 3 # line thickness
  111. tf = max(tl - 1, 1) # font thickness
  112. bs, _, h, w = images.shape # batch size, _, height, width
  113. bs = min(bs, max_subplots) # limit plot images
  114. ns = np.ceil(bs ** 0.5) # number of subplots (square)
  115. # Check if we should resize
  116. scale_factor = max_size / max(h, w)
  117. if scale_factor < 1:
  118. h = math.ceil(scale_factor * h)
  119. w = math.ceil(scale_factor * w)
  120. mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
  121. for i, img in enumerate(images):
  122. if i == max_subplots: # if last batch has fewer images than we expect
  123. break
  124. block_x = int(w * (i // ns))
  125. block_y = int(h * (i % ns))
  126. img = img.transpose(1, 2, 0)
  127. if scale_factor < 1:
  128. img = cv2.resize(img, (w, h))
  129. mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
  130. if len(targets) > 0:
  131. image_targets = targets[targets[:, 0] == i]
  132. boxes = xywh2xyxy(image_targets[:, 2:6]).T
  133. classes = image_targets[:, 1].astype('int')
  134. labels = image_targets.shape[1] == 6 # labels if no conf column
  135. conf = None if labels else image_targets[:, 6] # check for confidence presence (label vs pred)
  136. if boxes.shape[1]:
  137. if boxes.max() <= 1.01: # if normalized with tolerance 0.01
  138. boxes[[0, 2]] *= w # scale to pixels
  139. boxes[[1, 3]] *= h
  140. elif scale_factor < 1: # absolute coords need scale if image scales
  141. boxes *= scale_factor
  142. boxes[[0, 2]] += block_x
  143. boxes[[1, 3]] += block_y
  144. for j, box in enumerate(boxes.T):
  145. cls = int(classes[j])
  146. color = colors(cls)
  147. cls = names[cls] if names else cls
  148. if labels or conf[j] > 0.25: # 0.25 conf thresh
  149. label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
  150. plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)
  151. # Draw image filename labels
  152. if paths:
  153. label = Path(paths[i]).name[:40] # trim to 40 char
  154. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  155. cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
  156. lineType=cv2.LINE_AA)
  157. # Image border
  158. cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)
  159. if fname:
  160. r = min(1280. / max(h, w) / ns, 1.0) # ratio to limit image size
  161. mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA)
  162. # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save
  163. Image.fromarray(mosaic).save(fname) # PIL save
  164. return mosaic
  165. def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
  166. # Plot LR simulating training for full epochs
  167. optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
  168. y = []
  169. for _ in range(epochs):
  170. scheduler.step()
  171. y.append(optimizer.param_groups[0]['lr'])
  172. plt.plot(y, '.-', label='LR')
  173. plt.xlabel('epoch')
  174. plt.ylabel('LR')
  175. plt.grid()
  176. plt.xlim(0, epochs)
  177. plt.ylim(0)
  178. plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
  179. plt.close()
  180. def plot_test_txt(): # from utils.plots import *; plot_test()
  181. # Plot test.txt histograms
  182. x = np.loadtxt('test.txt', dtype=np.float32)
  183. box = xyxy2xywh(x[:, :4])
  184. cx, cy = box[:, 0], box[:, 1]
  185. fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
  186. ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
  187. ax.set_aspect('equal')
  188. plt.savefig('hist2d.png', dpi=300)
  189. fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
  190. ax[0].hist(cx, bins=600)
  191. ax[1].hist(cy, bins=600)
  192. plt.savefig('hist1d.png', dpi=200)
  193. def plot_targets_txt(): # from utils.plots import *; plot_targets_txt()
  194. # Plot targets.txt histograms
  195. x = np.loadtxt('targets.txt', dtype=np.float32).T
  196. s = ['x targets', 'y targets', 'width targets', 'height targets']
  197. fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
  198. ax = ax.ravel()
  199. for i in range(4):
  200. ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
  201. ax[i].legend()
  202. ax[i].set_title(s[i])
  203. plt.savefig('targets.jpg', dpi=200)
  204. def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_txt()
  205. # Plot study.txt generated by test.py
  206. fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
  207. # ax = ax.ravel()
  208. fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
  209. # for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
  210. for f in sorted(Path(path).glob('study*.txt')):
  211. y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
  212. x = np.arange(y.shape[1]) if x is None else np.array(x)
  213. s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
  214. # for i in range(7):
  215. # ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
  216. # ax[i].set_title(s[i])
  217. j = y[3].argmax() + 1
  218. ax2.plot(y[6, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8,
  219. label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
  220. ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
  221. 'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
  222. ax2.grid(alpha=0.2)
  223. ax2.set_yticks(np.arange(20, 60, 5))
  224. ax2.set_xlim(0, 57)
  225. ax2.set_ylim(30, 55)
  226. ax2.set_xlabel('GPU Speed (ms/img)')
  227. ax2.set_ylabel('COCO AP val')
  228. ax2.legend(loc='lower right')
  229. plt.savefig(str(Path(path).name) + '.png', dpi=300)
  230. def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
  231. # plot dataset labels
  232. print('Plotting labels... ')
  233. c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
  234. nc = int(c.max() + 1) # number of classes
  235. x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
  236. # seaborn correlogram
  237. sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
  238. plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
  239. plt.close()
  240. # matplotlib labels
  241. matplotlib.use('svg') # faster
  242. ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
  243. y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
  244. # [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195
  245. ax[0].set_ylabel('instances')
  246. if 0 < len(names) < 30:
  247. ax[0].set_xticks(range(len(names)))
  248. ax[0].set_xticklabels(names, rotation=90, fontsize=10)
  249. else:
  250. ax[0].set_xlabel('classes')
  251. sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
  252. sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
  253. # rectangles
  254. labels[:, 1:3] = 0.5 # center
  255. labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
  256. img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
  257. for cls, *box in labels[:1000]:
  258. ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
  259. ax[1].imshow(img)
  260. ax[1].axis('off')
  261. for a in [0, 1, 2, 3]:
  262. for s in ['top', 'right', 'left', 'bottom']:
  263. ax[a].spines[s].set_visible(False)
  264. plt.savefig(save_dir / 'labels.jpg', dpi=200)
  265. matplotlib.use('Agg')
  266. plt.close()
  267. # loggers
  268. for k, v in loggers.items() or {}:
  269. if k == 'wandb' and v:
  270. v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]}, commit=False)
  271. def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
  272. # Plot hyperparameter evolution results in evolve.txt
  273. with open(yaml_file) as f:
  274. hyp = yaml.safe_load(f)
  275. x = np.loadtxt('evolve.txt', ndmin=2)
  276. f = fitness(x)
  277. # weights = (f - f.min()) ** 2 # for weighted results
  278. plt.figure(figsize=(10, 12), tight_layout=True)
  279. matplotlib.rc('font', **{'size': 8})
  280. for i, (k, v) in enumerate(hyp.items()):
  281. y = x[:, i + 7]
  282. # mu = (y * weights).sum() / weights.sum() # best weighted result
  283. mu = y[f.argmax()] # best single result
  284. plt.subplot(6, 5, i + 1)
  285. plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
  286. plt.plot(mu, f.max(), 'k+', markersize=15)
  287. plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
  288. if i % 5 != 0:
  289. plt.yticks([])
  290. print('%15s: %.3g' % (k, mu))
  291. plt.savefig('evolve.png', dpi=200)
  292. print('\nPlot saved as evolve.png')
  293. def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
  294. # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
  295. ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
  296. s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
  297. files = list(Path(save_dir).glob('frames*.txt'))
  298. for fi, f in enumerate(files):
  299. try:
  300. results = np.loadtxt(f, ndmin=2).T[:, 90:-30] # clip first and last rows
  301. n = results.shape[1] # number of rows
  302. x = np.arange(start, min(stop, n) if stop else n)
  303. results = results[:, x]
  304. t = (results[0] - results[0].min()) # set t0=0s
  305. results[0] = x
  306. for i, a in enumerate(ax):
  307. if i < len(results):
  308. label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
  309. a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
  310. a.set_title(s[i])
  311. a.set_xlabel('time (s)')
  312. # if fi == len(files) - 1:
  313. # a.set_ylim(bottom=0)
  314. for side in ['top', 'right']:
  315. a.spines[side].set_visible(False)
  316. else:
  317. a.remove()
  318. except Exception as e:
  319. print('Warning: Plotting error for %s; %s' % (f, e))
  320. ax[1].legend()
  321. plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
  322. def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_results_overlay()
  323. # Plot training 'results*.txt', overlaying train and val losses
  324. s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95'] # legends
  325. t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles
  326. for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
  327. results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
  328. n = results.shape[1] # number of rows
  329. x = range(start, min(stop, n) if stop else n)
  330. fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
  331. ax = ax.ravel()
  332. for i in range(5):
  333. for j in [i, i + 5]:
  334. y = results[j, x]
  335. ax[i].plot(x, y, marker='.', label=s[j])
  336. # y_smooth = butter_lowpass_filtfilt(y)
  337. # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])
  338. ax[i].set_title(t[i])
  339. ax[i].legend()
  340. ax[i].set_ylabel(f) if i == 0 else None # add filename
  341. fig.savefig(f.replace('.txt', '.png'), dpi=200)
  342. def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
  343. # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
  344. fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
  345. ax = ax.ravel()
  346. s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
  347. 'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
  348. if bucket:
  349. # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
  350. files = ['results%g.txt' % x for x in id]
  351. c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
  352. os.system(c)
  353. else:
  354. files = list(Path(save_dir).glob('results*.txt'))
  355. assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
  356. for fi, f in enumerate(files):
  357. try:
  358. results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
  359. n = results.shape[1] # number of rows
  360. x = range(start, min(stop, n) if stop else n)
  361. for i in range(10):
  362. y = results[i, x]
  363. if i in [0, 1, 2, 5, 6, 7]:
  364. y[y == 0] = np.nan # don't show zero loss values
  365. # y /= y[0] # normalize
  366. label = labels[fi] if len(labels) else f.stem
  367. ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
  368. ax[i].set_title(s[i])
  369. # if i in [5, 6, 7]: # share train and val loss y axes
  370. # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
  371. except Exception as e:
  372. print('Warning: Plotting error for %s; %s' % (f, e))
  373. ax[1].legend()
  374. fig.savefig(Path(save_dir) / 'results.png', dpi=200)