# YOLOv5 general utils

import glob
import logging
import math
import os
import platform
import random
import re
import subprocess
import time
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import pkg_resources as pkg
import torch
import torchvision
import yaml

from utils.google_utils import gsutil_getsize
from utils.metrics import fitness
from utils.torch_utils import init_torch_seeds

# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads


def set_logging(rank=-1, verbose=True):
    logging.basicConfig(
        format="%(message)s",
        level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)


def init_seeds(seed=0):
    # Initialize random number generator (RNG) seeds
    random.seed(seed)
    np.random.seed(seed)
    init_torch_seeds(seed)


def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''


def is_docker():
    # Is environment a Docker container?
    return Path('/workspace').exists()  # or Path('/.dockerenv').exists()


def is_colab():
    # Is environment a Google Colab instance?
    try:
        import google.colab  # noqa: F401
        return True
    except Exception:
        return False


def emojis(str=''):
    # Return platform-dependent emoji-safe version of string
    return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str


def file_size(file):
    # Return file size in MB
    return Path(file).stat().st_size / 1e6


def check_online():
    # Check internet connectivity
    import socket
    try:
        socket.create_connection(("1.1.1.1", 443), timeout=5)  # check host accessibility
        return True
    except OSError:
        return False


def check_git_status():
    # Recommend 'git pull' if code is out of date
    print(colorstr('github: '), end='')
    try:
        assert Path('.git').exists(), 'skipping check (not a git repository)'
        assert not is_docker(), 'skipping check (Docker image)'
        assert check_online(), 'skipping check (offline)'

        cmd = 'git fetch && git config --get remote.origin.url'
        url = subprocess.check_output(cmd, shell=True).decode().strip()  # github repo url
        url = url[:-4] if url.endswith('.git') else url  # strip '.git' suffix (rstrip would strip characters, not the suffix)
        branch = subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
        n = int(subprocess.check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
        if n > 0:
            s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
                f"Use 'git pull' to update or 'git clone {url}' to download latest."
        else:
            s = f'up to date with {url} ✅'
        print(emojis(s))  # emoji-safe
    except Exception as e:
        print(e)


def check_python(minimum='3.7.0', required=True):
    # Check current Python version vs. required minimum version
    current = platform.python_version()
    result = pkg.parse_version(current) >= pkg.parse_version(minimum)
    if required:
        assert result, f'Python {minimum} required by YOLOv5, but Python {current} is currently installed'
    return result


def check_requirements(requirements='requirements.txt', exclude=()):
    # Check installed dependencies meet requirements (pass *.txt file or list of packages)
    prefix = colorstr('red', 'bold', 'requirements:')
    check_python()  # check python version
    if isinstance(requirements, (str, Path)):  # requirements.txt file
        file = Path(requirements)
        if not file.exists():
            print(f"{prefix} {file.resolve()} not found, check failed.")
            return
        requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
    else:  # list or tuple of packages
        requirements = [x for x in requirements if x not in exclude]

    n = 0  # number of packages updated
    for r in requirements:
        try:
            pkg.require(r)
        except Exception:  # DistributionNotFound or VersionConflict if requirements not met
            n += 1
            print(f"{prefix} {r} not found and is required by YOLOv5, attempting auto-update...")
            try:
                print(subprocess.check_output(f"pip install '{r}'", shell=True).decode())
            except Exception as e:
                print(f'{prefix} {e}')

    if n:  # if packages updated
        source = file.resolve() if 'file' in locals() else requirements
        s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        print(emojis(s))  # emoji-safe


def check_img_size(img_size, s=32):
    # Verify img_size is a multiple of stride s
    new_size = make_divisible(img_size, int(s))  # ceil gs-multiple
    if new_size != img_size:
        print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
    return new_size


def check_imshow():
    # Check if environment supports image displays
    try:
        assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
        assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False


def check_file(file):
    # Search for file if not found
    if Path(file).is_file() or file == '':
        return file
    else:
        files = glob.glob('./**/' + file, recursive=True)  # find file
        assert len(files), f'File Not Found: {file}'  # assert file was found
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
        return files[0]  # return file


def check_dataset(dict):
    # Download dataset if not found locally
    val, s = dict.get('val'), dict.get('download')
    if val and len(val):
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and len(s):  # download script
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    print(f'Downloading {s} ...')
                    torch.hub.download_url_to_file(s, f)
                    r = os.system(f'unzip -q {f} -d ../ && rm {f}')  # unzip
                elif s.startswith('bash '):  # bash script
                    print(f'Running {s} ...')
                    r = os.system(s)
                else:  # python script
                    r = exec(s)  # returns None
                print('Dataset autodownload %s\n' % ('success' if r in (0, None) else 'failure'))  # print result
            else:
                raise Exception('Dataset not found.')


def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
    # Multi-threaded file download and unzip function
    def download_one(url, dir):
        # Download 1 file
        f = dir / Path(url).name  # filename
        if not f.exists():
            print(f'Downloading {url} to {f}...')
            if curl:
                os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -")  # curl download, retry and resume on fail
            else:
                torch.hub.download_url_to_file(url, f, progress=True)  # torch download
        if unzip and f.suffix in ('.zip', '.gz'):
            print(f'Unzipping {f}...')
            if f.suffix == '.zip':
                s = f'unzip -qo {f} -d {dir}'  # unzip -quiet -overwrite
            elif f.suffix == '.gz':
                s = f'tar xfz {f} --directory {f.parent}'  # untar
            if delete:  # delete archive after unzip
                s += f' && rm {f}'
            os.system(s)

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        pool = ThreadPool(threads)
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multi-threaded
        pool.close()
        pool.join()
    else:
        for u in [url] if isinstance(url, (str, Path)) else url:  # tuple(url) would iterate a string's characters
            download_one(u, dir)
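
# Usage sketch (hypothetical URLs): accepts a single URL string or an iterable of URLs:
# >>> download('https://example.com/data.zip', dir='datasets')  # serial
# >>> download([f'https://example.com/part{i}.zip' for i in range(4)], dir='datasets', threads=4)  # parallel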


def make_divisible(x, divisor):
    # Returns x evenly divisible by divisor (rounding up)
    return math.ceil(x / divisor) * divisor
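
# Worked example: make_divisible(100, 32) = math.ceil(100 / 32) * 32 = 4 * 32 = 128;
# check_img_size() above uses this to snap an arbitrary --img-size up to a stride multiple.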


def clean_str(s):
    # Cleans a string by replacing special characters with underscore _
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


def one_cycle(y1=0.0, y2=1.0, steps=100):
    # lambda function for sinusoidal ramp from y1 to y2
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
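
# Worked example: f = one_cycle(1.0, 0.2, 100) gives f(0) = 1.0, f(50) = 0.6, f(100) = 0.2
# (a half-cosine ramp). One typical use is as an epoch -> LR-multiplier schedule
# ('optimizer' and 'epochs' are hypothetical here):
# >>> lf = one_cycle(1, 0.2, epochs)
# >>> scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)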


def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {'black': '\033[30m',  # basic colors
              'red': '\033[31m',
              'green': '\033[32m',
              'yellow': '\033[33m',
              'blue': '\033[34m',
              'magenta': '\033[35m',
              'cyan': '\033[36m',
              'white': '\033[37m',
              'bright_black': '\033[90m',  # bright colors
              'bright_red': '\033[91m',
              'bright_green': '\033[92m',
              'bright_yellow': '\033[93m',
              'bright_blue': '\033[94m',
              'bright_magenta': '\033[95m',
              'bright_cyan': '\033[96m',
              'bright_white': '\033[97m',
              'end': '\033[0m',  # misc
              'bold': '\033[1m',
              'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
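
# Usage: leading arguments select ANSI styles, the last argument is the string:
# >>> colorstr('red', 'bold', 'requirements:')  # '\033[31m\033[1mrequirements:\033[0m'
# >>> colorstr('hello')  # a single argument defaults to blue + bold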


def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(int)  # labels = [class xywh]
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)
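
# Worked example with nc=2: three class-0 labels and one class-1 label give
# bincount [3, 1] -> inverse [1/3, 1] -> normalized weights [0.25, 0.75],
# i.e. the rarer class receives three times the weight of the common one.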


def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class_weights and image contents
    class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights


def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
         64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
    return x


def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y
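
# Worked round trip: xywh2xyxy([[10, 10, 4, 6]]) -> [[8, 7, 12, 13]] (center 10,10,
# width 4, height 6), and xyxy2xywh([[8, 7, 12, 13]]) -> [[10, 10, 4, 6]] recovers it.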


def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # bottom right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom right y
    return y


def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    # Convert normalized segments into pixel segments, shape (n,2)
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * x[:, 0] + padw  # pixel x
    y[:, 1] = h * x[:, 1] + padh  # pixel y
    return y


def segment2box(segment, width=640, height=640):
    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4))  # xyxy


def segments2boxes(segments):
    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
    return xyxy2xywh(np.array(boxes))  # cls, xywh


def resample_segments(segments, n=1000):
    # Up-sample an (n,2) segment
    for i, s in enumerate(segments):
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T  # segment xy
    return segments


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords
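
# Typical usage after inference (variable names are assumptions from a detect-style script):
# >>> det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
# maps boxes predicted on the letterboxed tensor 'img' back onto the original image 'im0'.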


def clip_coords(boxes, img_shape):
    # Clip xyxy bounding boxes to image shape (height, width)
    boxes[:, 0].clamp_(0, img_shape[1])  # x1
    boxes[:, 1].clamp_(0, img_shape[0])  # y1
    boxes[:, 2].clamp_(0, img_shape[1])  # x2
    boxes[:, 3].clamp_(0, img_shape[0])  # y2


def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
    box2 = box2.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + eps  # convex area
            return iou - (c_area - union) / c_area  # GIoU
    else:
        return iou  # IoU
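
# Worked example: box1 = [0, 0, 2, 2] and box2 = [[1, 1, 3, 3]] (xyxy) intersect in a
# 1x1 square, so IoU = 1 / (4 + 4 - 1) ≈ 0.143:
# >>> bbox_iou(torch.tensor([0., 0., 2., 2.]), torch.tensor([[1., 1., 3., 3.]]))  # tensor([0.1429])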


def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """

    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])

    area1 = box_area(box1.T)
    area2 = box_area(box2.T)

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)
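
# Unlike bbox_iou() (one box vs. n boxes), box_iou() is all-pairs: (N,4) x (M,4) -> (N,M),
# so box_iou(a, b)[i, j] is the IoU of a[i] with b[j]; merge-NMS below relies on this.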


def wh_iou(wh1, wh2):
    # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
    wh1 = wh1[:, None]  # [N,1,2]
    wh2 = wh2[None]  # [1,M,2]
    inter = torch.min(wh1, wh2).prod(2)  # [N,M]
    return inter / (wh1.prod(2) + wh2.prod(2) - inter)  # iou = inter / (area1 + area2 - inter)


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
        list of detections, one (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
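
# Usage sketch (hypothetical names from a detect-style script):
# >>> pred = model(img)[0]  # raw (batch, anchors, 5 + nc) predictions
# >>> pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
# >>> for det in pred:  # one (n,6) tensor per image
# ...     boxes, confs, cls = det[:, :4], det[:, 4], det[:, 5]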


def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")


def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
    # Print mutation results to evolve.txt (for use with train.py --evolve)
    a = '%10s' * len(hyp) % tuple(hyp.keys())  # hyperparam keys
    b = '%10.3g' * len(hyp) % tuple(hyp.values())  # hyperparam values
    c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))

    if bucket:
        url = 'gs://%s/evolve.txt' % bucket
        if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
            os.system('gsutil cp %s .' % url)  # download evolve.txt if larger than local

    with open('evolve.txt', 'a') as f:  # append result
        f.write(c + b + '\n')
    x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
    x = x[np.argsort(-fitness(x))]  # sort
    np.savetxt('evolve.txt', x, '%10.3g')  # save sorted by fitness

    # Save yaml
    for i, k in enumerate(hyp.keys()):
        hyp[k] = float(x[0, i + 7])
    with open(yaml_file, 'w') as f:
        results = tuple(x[0, :7])
        c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
        f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
        yaml.safe_dump(hyp, f, sort_keys=False)

    if bucket:
        os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket))  # upload


def apply_classifier(x, model, img, im0):
    # Apply a second stage classifier to yolo outputs
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('test%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x224x224
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255.0  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x


def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
    # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
    xyxy = torch.tensor(xyxy).view(-1, 4)
    b = xyxy2xywh(xyxy)  # boxes
    if square:
        b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
    b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
    xyxy = xywh2xyxy(b).long()
    clip_coords(xyxy, im.shape)
    crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
    if save:
        cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)
    return crop
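
# Usage sketch (hypothetical 'det' and 'im0' from a detect-style loop over detections):
# >>> for *xyxy, conf, cls in det:
# ...     save_one_box(xyxy, im0, file='crops/thing.jpg', BGR=True)  # im0 is a BGR numpy image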


def increment_path(path, exist_ok=False, sep='', mkdir=False):
    # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        suffix = path.suffix
        path = path.with_suffix('')
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(i) + 1 if i else 2  # increment number
        path = Path(f"{path}{sep}{n}{suffix}")  # update path
    dir = path if path.suffix == '' else path.parent  # directory
    if not dir.exists() and mkdir:
        dir.mkdir(parents=True, exist_ok=True)  # make directory
    return path
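
# e.g. increment_path('runs/exp') returns 'runs/exp' untouched if it does not exist yet,
# 'runs/exp2' if 'runs/exp' exists, then 'runs/exp3', and so on (numbering starts at 2).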