utils.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import contextlib
  3. import hashlib
  4. import json
  5. import os
  6. import random
  7. import subprocess
  8. import time
  9. import zipfile
  10. from multiprocessing.pool import ThreadPool
  11. from pathlib import Path
  12. from tarfile import is_tarfile
  13. import cv2
  14. import numpy as np
  15. from PIL import Image, ImageOps
  16. from tqdm import tqdm
  17. from ultralytics.nn.autobackend import check_class_names
  18. from ultralytics.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, clean_url, colorstr, emojis,
  19. yaml_load)
  20. from ultralytics.utils.checks import check_file, check_font, is_ascii
  21. from ultralytics.utils.downloads import download, safe_download, unzip_file
  22. from ultralytics.utils.ops import segments2boxes
  23. HELP_URL = 'See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance.'
  24. IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # image suffixes
  25. VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv', 'webm' # video suffixes
  26. PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
  27. def img2label_paths(img_paths):
  28. """Define label paths as a function of image paths."""
  29. sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings
  30. return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
  31. def get_hash(paths):
  32. """Returns a single hash value of a list of paths (files or dirs)."""
  33. size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
  34. h = hashlib.sha256(str(size).encode()) # hash sizes
  35. h.update(''.join(paths).encode()) # hash paths
  36. return h.hexdigest() # return hash
  37. def exif_size(img: Image.Image):
  38. """Returns exif-corrected PIL size."""
  39. s = img.size # (width, height)
  40. if img.format == 'JPEG': # only support JPEG images
  41. with contextlib.suppress(Exception):
  42. exif = img.getexif()
  43. if exif:
  44. rotation = exif.get(274, None) # the EXIF key for the orientation tag is 274
  45. if rotation in [6, 8]: # rotation 270 or 90
  46. s = s[1], s[0]
  47. return s
  48. def verify_image(args):
  49. """Verify one image."""
  50. (im_file, cls), prefix = args
  51. # Number (found, corrupt), message
  52. nf, nc, msg = 0, 0, ''
  53. try:
  54. im = Image.open(im_file)
  55. im.verify() # PIL verify
  56. shape = exif_size(im) # image size
  57. shape = (shape[1], shape[0]) # hw
  58. assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
  59. assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
  60. if im.format.lower() in ('jpg', 'jpeg'):
  61. with open(im_file, 'rb') as f:
  62. f.seek(-2, 2)
  63. if f.read() != b'\xff\xd9': # corrupt JPEG
  64. ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
  65. msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
  66. nf = 1
  67. except Exception as e:
  68. nc = 1
  69. msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
  70. return (im_file, cls), nf, nc, msg
  71. def verify_image_label(args):
  72. """Verify one image-label pair."""
  73. im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
  74. # Number (missing, found, empty, corrupt), message, segments, keypoints
  75. nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, '', [], None
  76. try:
  77. # Verify images
  78. im = Image.open(im_file)
  79. im.verify() # PIL verify
  80. shape = exif_size(im) # image size
  81. shape = (shape[1], shape[0]) # hw
  82. assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
  83. assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
  84. if im.format.lower() in ('jpg', 'jpeg'):
  85. with open(im_file, 'rb') as f:
  86. f.seek(-2, 2)
  87. if f.read() != b'\xff\xd9': # corrupt JPEG
  88. ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
  89. msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
  90. # Verify labels
  91. if os.path.isfile(lb_file):
  92. nf = 1 # label found
  93. with open(lb_file) as f:
  94. lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
  95. if any(len(x) > 6 for x in lb) and (not keypoint): # is segment
  96. classes = np.array([x[0] for x in lb], dtype=np.float32)
  97. segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
  98. lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
  99. lb = np.array(lb, dtype=np.float32)
  100. nl = len(lb)
  101. if nl:
  102. if keypoint:
  103. assert lb.shape[1] == (5 + nkpt * ndim), f'labels require {(5 + nkpt * ndim)} columns each'
  104. assert (lb[:, 5::ndim] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
  105. assert (lb[:, 6::ndim] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
  106. else:
  107. assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
  108. assert (lb[:, 1:] <= 1).all(), \
  109. f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
  110. assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
  111. # All labels
  112. max_cls = int(lb[:, 0].max()) # max label count
  113. assert max_cls <= num_cls, \
  114. f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \
  115. f'Possible class labels are 0-{num_cls - 1}'
  116. _, i = np.unique(lb, axis=0, return_index=True)
  117. if len(i) < nl: # duplicate row check
  118. lb = lb[i] # remove duplicates
  119. if segments:
  120. segments = [segments[x] for x in i]
  121. msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
  122. else:
  123. ne = 1 # label empty
  124. lb = np.zeros((0, (5 + nkpt * ndim)), dtype=np.float32) if keypoint else np.zeros(
  125. (0, 5), dtype=np.float32)
  126. else:
  127. nm = 1 # label missing
  128. lb = np.zeros((0, (5 + nkpt * ndim)), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
  129. if keypoint:
  130. keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
  131. if ndim == 2:
  132. kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32)
  133. keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1) # (nl, nkpt, 3)
  134. lb = lb[:, :5]
  135. return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
  136. except Exception as e:
  137. nc = 1
  138. msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
  139. return [None, None, None, None, None, nm, nf, ne, nc, msg]
  140. def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
  141. """
  142. Args:
  143. imgsz (tuple): The image size.
  144. polygons (list[np.ndarray]): [N, M], N is the number of polygons, M is the number of points(Be divided by 2).
  145. color (int): color
  146. downsample_ratio (int): downsample ratio
  147. """
  148. mask = np.zeros(imgsz, dtype=np.uint8)
  149. polygons = np.asarray(polygons, dtype=np.int32)
  150. polygons = polygons.reshape((polygons.shape[0], -1, 2))
  151. cv2.fillPoly(mask, polygons, color=color)
  152. nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
  153. # NOTE: fillPoly first then resize is trying to keep the same way of loss calculation when mask-ratio=1.
  154. return cv2.resize(mask, (nw, nh))
  155. def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
  156. """
  157. Args:
  158. imgsz (tuple): The image size.
  159. polygons (list[np.ndarray]): each polygon is [N, M], N is number of polygons, M is number of points (M % 2 = 0)
  160. color (int): color
  161. downsample_ratio (int): downsample ratio
  162. """
  163. return np.array([polygon2mask(imgsz, [x.reshape(-1)], color, downsample_ratio) for x in polygons])
  164. def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
  165. """Return a (640, 640) overlap mask."""
  166. masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
  167. dtype=np.int32 if len(segments) > 255 else np.uint8)
  168. areas = []
  169. ms = []
  170. for si in range(len(segments)):
  171. mask = polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1)
  172. ms.append(mask)
  173. areas.append(mask.sum())
  174. areas = np.asarray(areas)
  175. index = np.argsort(-areas)
  176. ms = np.array(ms)[index]
  177. for i in range(len(segments)):
  178. mask = ms[i] * (i + 1)
  179. masks = masks + mask
  180. masks = np.clip(masks, a_min=0, a_max=i + 1)
  181. return masks, index
def check_det_dataset(dataset, autodownload=True):
    """
    Download, verify, and/or unzip a dataset if not found locally.

    This function checks the availability of a specified dataset, and if not found, it has the option to download and
    unzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also
    resolves paths related to the dataset.

    SECURITY: when the val images are missing and ``autodownload`` is enabled, the YAML ``download`` field is acted on.
    A value starting with 'bash ' is run through ``os.system``. Any value that is not an http(s) .zip URL or a bash
    command is passed to ``exec`` as Python code. Only use dataset YAMLs from trusted sources.

    Args:
        dataset (str): Path to the dataset or dataset descriptor (like a YAML file).
        autodownload (bool, optional): Whether to automatically download the dataset if not found. Defaults to True.

    Returns:
        (dict): Parsed dataset information and paths. 'path' is the dataset root; 'train'/'val'/'test' are resolved
            to absolute path strings (or lists of them); 'nc' and 'names' are always present.

    Raises:
        SyntaxError: If 'train' or 'val' (or 'validation') is missing, if neither 'names' nor 'nc' is given, or if
            their lengths disagree.
        FileNotFoundError: If val paths are missing and no download is configured or autodownload is disabled.
    """
    data = check_file(dataset)

    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and (zipfile.is_zipfile(data) or is_tarfile(data)):
        new_dir = safe_download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False)
        # NOTE(review): next() raises StopIteration if the archive contains no *.yaml file.
        data = next((DATASETS_DIR / new_dir).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False  # archive already extracted; never download again

    # Read YAML (optional)
    if isinstance(data, (str, Path)):
        data = yaml_load(data, append_filename=True)  # dictionary

    # Checks
    for k in 'train', 'val':
        if k not in data:
            if k == 'val' and 'validation' in data:
                LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.")
                data['val'] = data.pop('validation')  # replace 'validation' key with 'val' key
            else:
                raise SyntaxError(
                    emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs."))
    if 'names' not in data and 'nc' not in data:
        raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
    if 'names' in data and 'nc' in data and len(data['names']) != data['nc']:
        raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
    if 'names' not in data:
        data['names'] = [f'class_{i}' for i in range(data['nc'])]  # synthesize placeholder names from nc
    else:
        data['nc'] = len(data['names'])  # names is authoritative for the class count

    data['names'] = check_class_names(data['names'])

    # Resolve paths
    # Root precedence: extracted archive dir, then the YAML 'path' key, then the YAML file's own directory
    path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent)  # dataset root

    if not path.is_absolute():
        path = (DATASETS_DIR / path).resolve()
    data['path'] = path  # download scripts
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            if isinstance(data[k], str):
                x = (path / data[k]).resolve()
                # Fallback for legacy YAMLs whose split paths start with '../'
                if not x.exists() and data[k].startswith('../'):
                    x = (path / data[k][3:]).resolve()
                data[k] = str(x)
            else:
                data[k] = [str((path / x).resolve()) for x in data[k]]

    # Parse YAML
    train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            name = clean_url(dataset)  # dataset name with URL auth stripped
            m = f"\nDataset '{name}' images not found ⚠️, missing path '{[x for x in val if not x.exists()][0]}'"
            if s and autodownload:
                LOGGER.warning(m)
            else:
                m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_YAML}'"
                raise FileNotFoundError(m)
            t = time.time()
            r = None  # success
            if s.startswith('http') and s.endswith('.zip'):  # URL
                safe_download(url=s, dir=DATASETS_DIR, delete=True)
            elif s.startswith('bash '):  # bash script
                LOGGER.info(f'Running {s} ...')
                # SECURITY: runs a shell command taken from the dataset YAML
                r = os.system(s)
            else:  # python script
                # SECURITY: executes arbitrary Python taken from the dataset YAML; the parsed dict is exposed as `yaml`
                exec(s, {'yaml': data})
            dt = f'({round(time.time() - t, 1)}s)'
            # r stays None for URL/exec downloads; os.system returns 0 on success
            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌'
            LOGGER.info(f'Dataset download {s}\n')
    check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf')  # download fonts

    return data  # dictionary
  262. def check_cls_dataset(dataset: str, split=''):
  263. """
  264. Checks a classification dataset such as Imagenet.
  265. This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
  266. If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.
  267. Args:
  268. dataset (str): The name of the dataset.
  269. split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.
  270. Returns:
  271. (dict): A dictionary containing the following keys:
  272. - 'train' (Path): The directory path containing the training set of the dataset.
  273. - 'val' (Path): The directory path containing the validation set of the dataset.
  274. - 'test' (Path): The directory path containing the test set of the dataset.
  275. - 'nc' (int): The number of classes in the dataset.
  276. - 'names' (dict): A dictionary of class names in the dataset.
  277. """
  278. dataset = Path(dataset)
  279. data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
  280. if not data_dir.is_dir():
  281. LOGGER.warning(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
  282. t = time.time()
  283. if str(dataset) == 'imagenet':
  284. subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
  285. else:
  286. url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
  287. download(url, dir=data_dir.parent)
  288. s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
  289. LOGGER.info(s)
  290. train_set = data_dir / 'train'
  291. val_set = data_dir / 'val' if (data_dir / 'val').exists() else data_dir / 'validation' if (
  292. data_dir / 'validation').exists() else None # data/test or data/val
  293. test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test
  294. if split == 'val' and not val_set:
  295. LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
  296. elif split == 'test' and not test_set:
  297. LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
  298. nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
  299. names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
  300. names = dict(enumerate(sorted(names)))
  301. # Print to console
  302. for k, v in {'train': train_set, 'val': val_set, 'test': test_set}.items():
  303. prefix = f'{colorstr(f"{k}:")} {v}...'
  304. if v is None:
  305. LOGGER.info(prefix)
  306. else:
  307. files = [path for path in v.rglob('*.*') if path.suffix[1:].lower() in IMG_FORMATS]
  308. nf = len(files) # number of files
  309. nd = len({file.parent for file in files}) # number of directories
  310. if nf == 0:
  311. if k == 'train':
  312. raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
  313. else:
  314. LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found')
  315. elif nd != nc:
  316. LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}')
  317. else:
  318. LOGGER.info(f'{prefix} found {nf} images in {nd} classes ✅ ')
  319. return {'train': train_set, 'val': val_set or test_set, 'test': test_set or val_set, 'nc': nc, 'names': names}
  320. class HUBDatasetStats:
  321. """
  322. A class for generating HUB dataset JSON and `-hub` dataset directory.
  323. Args:
  324. path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco128.yaml'.
  325. task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'.
  326. autodownload (bool): Attempt to download dataset if not found locally. Default is False.
  327. Example:
  328. Download *.zip files from i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip.
  329. ```python
  330. from ultralytics.data.utils import HUBDatasetStats
  331. stats = HUBDatasetStats('path/to/coco8.zip', task='detect') # detect dataset
  332. stats = HUBDatasetStats('path/to/coco8-seg.zip', task='segment') # segment dataset
  333. stats = HUBDatasetStats('path/to/coco8-pose.zip', task='pose') # pose dataset
  334. stats.get_json(save=False)
  335. stats.process_images()
  336. ```
  337. """
  338. def __init__(self, path='coco128.yaml', task='detect', autodownload=False):
  339. """Initialize class."""
  340. path = Path(path).resolve()
  341. LOGGER.info(f'Starting HUB dataset checks for {path}....')
  342. zipped, data_dir, yaml_path = self._unzip(path)
  343. try:
  344. # data = yaml_load(check_yaml(yaml_path)) # data dict
  345. data = check_det_dataset(yaml_path, autodownload) # data dict
  346. if zipped:
  347. data['path'] = data_dir
  348. except Exception as e:
  349. raise Exception('error/HUB/dataset_stats/yaml_load') from e
  350. self.hub_dir = Path(str(data['path']) + '-hub')
  351. self.im_dir = self.hub_dir / 'images'
  352. self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
  353. self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())} # statistics dictionary
  354. self.data = data
  355. self.task = task # detect, segment, pose, classify
  356. @staticmethod
  357. def _find_yaml(dir):
  358. """Return data.yaml file."""
  359. files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
  360. assert files, f"No *.yaml file found in '{dir.resolve()}'"
  361. if len(files) > 1:
  362. files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name
  363. assert len(files) == 1, f"Expected 1 *.yaml file in '{dir.resolve()}', but found {len(files)}.\n{files}"
  364. return files[0]
  365. def _unzip(self, path):
  366. """Unzip data.zip."""
  367. if not str(path).endswith('.zip'): # path is data.yaml
  368. return False, None, path
  369. unzip_dir = unzip_file(path, path=path.parent)
  370. assert unzip_dir.is_dir(), f'Error unzipping {path}, {unzip_dir} not found. ' \
  371. f'path/to/abc.zip MUST unzip to path/to/abc/'
  372. return True, str(unzip_dir), self._find_yaml(unzip_dir) # zipped, data_dir, yaml_path
  373. def _hub_ops(self, f):
  374. """Saves a compressed image for HUB previews."""
  375. compress_one_image(f, self.im_dir / Path(f).name) # save to dataset-hub
  376. def get_json(self, save=False, verbose=False):
  377. """Return dataset JSON for Ultralytics HUB."""
  378. from ultralytics.data import YOLODataset # ClassificationDataset
  379. def _round(labels):
  380. """Update labels to integer class and 4 decimal place floats."""
  381. if self.task == 'detect':
  382. coordinates = labels['bboxes']
  383. elif self.task == 'segment':
  384. coordinates = [x.flatten() for x in labels['segments']]
  385. elif self.task == 'pose':
  386. n = labels['keypoints'].shape[0]
  387. coordinates = np.concatenate((labels['bboxes'], labels['keypoints'].reshape(n, -1)), 1)
  388. else:
  389. raise ValueError('Undefined dataset task.')
  390. zipped = zip(labels['cls'], coordinates)
  391. return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]
  392. for split in 'train', 'val', 'test':
  393. if self.data.get(split) is None:
  394. self.stats[split] = None # i.e. no test set
  395. continue
  396. dataset = YOLODataset(img_path=self.data[split],
  397. data=self.data,
  398. use_segments=self.task == 'segment',
  399. use_keypoints=self.task == 'pose')
  400. x = np.array([
  401. np.bincount(label['cls'].astype(int).flatten(), minlength=self.data['nc'])
  402. for label in tqdm(dataset.labels, total=len(dataset), desc='Statistics')]) # shape(128x80)
  403. self.stats[split] = {
  404. 'instance_stats': {
  405. 'total': int(x.sum()),
  406. 'per_class': x.sum(0).tolist()},
  407. 'image_stats': {
  408. 'total': len(dataset),
  409. 'unlabelled': int(np.all(x == 0, 1).sum()),
  410. 'per_class': (x > 0).sum(0).tolist()},
  411. 'labels': [{
  412. Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)]}
  413. # Save, print and return
  414. if save:
  415. stats_path = self.hub_dir / 'stats.json'
  416. LOGGER.info(f'Saving {stats_path.resolve()}...')
  417. with open(stats_path, 'w') as f:
  418. json.dump(self.stats, f) # save stats.json
  419. if verbose:
  420. LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
  421. return self.stats
  422. def process_images(self):
  423. """Compress images for Ultralytics HUB."""
  424. from ultralytics.data import YOLODataset # ClassificationDataset
  425. for split in 'train', 'val', 'test':
  426. if self.data.get(split) is None:
  427. continue
  428. dataset = YOLODataset(img_path=self.data[split], data=self.data)
  429. with ThreadPool(NUM_THREADS) as pool:
  430. for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'):
  431. pass
  432. LOGGER.info(f'Done. All images saved to {self.im_dir}')
  433. return self.im_dir
  434. def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
  435. """
  436. Compresses a single image file to reduced size while preserving its aspect ratio and quality using either the
  437. Python Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will
  438. not be resized.
  439. Args:
  440. f (str): The path to the input image file.
  441. f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten.
  442. max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
  443. quality (int, optional): The image compression quality as a percentage. Default is 50%.
  444. Example:
  445. ```python
  446. from pathlib import Path
  447. from ultralytics.data.utils import compress_one_image
  448. for f in Path('path/to/dataset').rglob('*.jpg'):
  449. compress_one_image(f)
  450. ```
  451. """
  452. try: # use PIL
  453. im = Image.open(f)
  454. r = max_dim / max(im.height, im.width) # ratio
  455. if r < 1.0: # image too large
  456. im = im.resize((int(im.width * r), int(im.height * r)))
  457. im.save(f_new or f, 'JPEG', quality=quality, optimize=True) # save
  458. except Exception as e: # use OpenCV
  459. LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
  460. im = cv2.imread(f)
  461. im_height, im_width = im.shape[:2]
  462. r = max_dim / max(im_height, im_width) # ratio
  463. if r < 1.0: # image too large
  464. im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
  465. cv2.imwrite(str(f_new or f), im)
  466. def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
  467. """
  468. Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.
  469. Args:
  470. path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco8/images'.
  471. weights (list | tuple, optional): Train, validation, and test split fractions. Defaults to (0.9, 0.1, 0.0).
  472. annotated_only (bool, optional): If True, only images with an associated txt file are used. Defaults to False.
  473. Example:
  474. ```python
  475. from ultralytics.data.utils import autosplit
  476. autosplit()
  477. ```
  478. """
  479. path = Path(path) # images dir
  480. files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only
  481. n = len(files) # number of files
  482. random.seed(0) # for reproducibility
  483. indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
  484. txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
  485. for x in txt:
  486. if (path.parent / x).exists():
  487. (path.parent / x).unlink() # remove existing
  488. LOGGER.info(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
  489. for i, img in tqdm(zip(indices, files), total=n):
  490. if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
  491. with open(path.parent / txt[i], 'a') as f:
  492. f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n') # add image to txt file