# Ultralytics YOLO 🚀, AGPL-3.0 license

import glob
import math
import os
import random
from copy import deepcopy
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
import psutil
from torch.utils.data import Dataset
from tqdm import tqdm

from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT

from .utils import HELP_URL, IMG_FORMATS


class BaseDataset(Dataset):
    """
    Base dataset class for loading and processing image data.

    Args:
        img_path (str): Path to the folder containing images.
        imgsz (int, optional): Image size. Defaults to 640.
        cache (bool, optional): Cache images to RAM or disk during training. Defaults to False.
        augment (bool, optional): If True, data augmentation is applied. Defaults to True.
        hyp (dict, optional): Hyperparameters to apply data augmentation. Defaults to DEFAULT_CFG.
        prefix (str, optional): Prefix to print in log messages. Defaults to ''.
        rect (bool, optional): If True, rectangular training is used. Defaults to False.
        batch_size (int, optional): Size of batches. Defaults to 16.
        stride (int, optional): Stride. Defaults to 32.
        pad (float, optional): Padding. Defaults to 0.5.
        single_cls (bool, optional): If True, single class training is used. Defaults to False.
        classes (list, optional): List of included classes. Defaults to None (include all classes).
        fraction (float, optional): Fraction of dataset to utilize. Defaults to 1.0 (use all data).

    Attributes:
        im_files (list): List of image file paths.
        labels (list): List of label data dictionaries.
        ni (int): Number of images in the dataset.
        ims (list): List of loaded images.
        npy_files (list): List of numpy file paths.
        transforms (callable): Image transformation function.
    """

    def __init__(self,
                 img_path,
                 imgsz=640,
                 cache=False,
                 augment=True,
                 hyp=DEFAULT_CFG,
                 prefix='',
                 rect=False,
                 batch_size=16,
                 stride=32,
                 pad=0.5,
                 single_cls=False,
                 classes=None,
                 fraction=1.0):
        super().__init__()
        self.img_path = img_path
        self.imgsz = imgsz
        self.augment = augment
        self.single_cls = single_cls
        self.prefix = prefix
        self.fraction = fraction
        self.im_files = self.get_img_files(self.img_path)
        self.labels = self.get_labels()
        self.update_labels(include_class=classes)  # single_cls and include_class
        self.ni = len(self.labels)  # number of images
        self.rect = rect
        self.batch_size = batch_size
        self.stride = stride
        self.pad = pad
        if self.rect:
            assert self.batch_size is not None
            self.set_rectangle()

        # Buffer thread for mosaic images
        self.buffer = []  # buffer size = batch size
        self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0

        # Cache images into RAM or to disk (as *.npy files)
        if cache == 'ram' and not self.check_cache_ram():
            cache = False
        self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
        self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
        if cache:
            self.cache_images(cache)

        # Transforms
        self.transforms = self.build_transforms(hyp=hyp)

    def get_img_files(self, img_path):
        """Read image files."""
        try:
            f = []  # image files
            for p in img_path if isinstance(img_path, list) else [img_path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # F = list(p.rglob('*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p) as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                        # F += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
                    raise FileNotFoundError(f'{self.prefix}{p} does not exist')
            im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
            assert im_files, f'{self.prefix}No images found'
        except Exception as e:
            raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
        if self.fraction < 1:
            im_files = im_files[:round(len(im_files) * self.fraction)]
        return im_files

    def update_labels(self, include_class: Optional[list]):
        """Update labels to include only the specified classes (optional)."""
        include_class_array = np.array(include_class).reshape(1, -1)
        for i in range(len(self.labels)):
            if include_class is not None:
                cls = self.labels[i]['cls']
                bboxes = self.labels[i]['bboxes']
                segments = self.labels[i]['segments']
                keypoints = self.labels[i]['keypoints']
                j = (cls == include_class_array).any(1)
                self.labels[i]['cls'] = cls[j]
                self.labels[i]['bboxes'] = bboxes[j]
                if segments:
                    self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx]
                if keypoints is not None:
                    self.labels[i]['keypoints'] = keypoints[j]
            if self.single_cls:
                self.labels[i]['cls'][:, 0] = 0

    def load_image(self, i):
        """Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)."""
        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
        if im is None:  # not cached in RAM
            if fn.exists():  # load npy
                im = np.load(fn)
            else:  # read image
                im = cv2.imread(f)  # BGR
                if im is None:
                    raise FileNotFoundError(f'Image Not Found {f}')
            h0, w0 = im.shape[:2]  # orig hw
            r = self.imgsz / max(h0, w0)  # ratio
            if r != 1:  # if sizes are not equal
                interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
                im = cv2.resize(im, (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz)),
                                interpolation=interp)

            # Add to buffer if training with augmentations
            if self.augment:
                self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
                self.buffer.append(i)
                if len(self.buffer) >= self.max_buffer_length:
                    j = self.buffer.pop(0)
                    self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None

            return im, (h0, w0), im.shape[:2]

        return self.ims[i], self.im_hw0[i], self.im_hw[i]

    def cache_images(self, cache):
        """Cache images to memory or disk."""
        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabyte
        fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(fcn, range(self.ni))
            pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
            for i, x in pbar:
                if cache == 'disk':
                    b += self.npy_files[i].stat().st_size
                else:  # 'ram'
                    self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                    b += self.ims[i].nbytes
                pbar.desc = f'{self.prefix}Caching images ({b / gb:.1f}GB {cache})'
            pbar.close()

    def cache_images_to_disk(self, i):
        """Saves an image as an *.npy file for faster loading."""
        f = self.npy_files[i]
        if not f.exists():
            np.save(f.as_posix(), cv2.imread(self.im_files[i]))

    def check_cache_ram(self, safety_margin=0.5):
        """Check image caching requirements vs available memory."""
        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabyte
        n = min(self.ni, 30)  # extrapolate from 30 random images
        for _ in range(n):
            im = cv2.imread(random.choice(self.im_files))  # sample image
            ratio = self.imgsz / max(im.shape[0], im.shape[1])  # max(h, w)  # ratio
            b += im.nbytes * ratio ** 2
        mem_required = b * self.ni / n * (1 + safety_margin)  # bytes required to cache dataset into RAM
        mem = psutil.virtual_memory()
        cache = mem_required < mem.available  # to cache or not to cache, that is the question
        if not cache:
            LOGGER.info(f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images '
                        f'with {int(safety_margin * 100)}% safety margin but only '
                        f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, not caching images ⚠️')
        return cache

    def set_rectangle(self):
        """Sets per-batch image shapes for rectangular training, grouping images of similar aspect ratio."""
        bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches

        s = np.array([x.pop('shape') for x in self.labels])  # hw
        ar = s[:, 0] / s[:, 1]  # aspect ratio
        irect = ar.argsort()
        self.im_files = [self.im_files[i] for i in irect]
        self.labels = [self.labels[i] for i in irect]
        ar = ar[irect]

        # Set training image shapes
        shapes = [[1, 1]] * nb
        for i in range(nb):
            ari = ar[bi == i]
            mini, maxi = ari.min(), ari.max()
            if maxi < 1:
                shapes[i] = [maxi, 1]
            elif mini > 1:
                shapes[i] = [1, 1 / mini]

        self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
        self.batch = bi  # batch index of image

    def __getitem__(self, index):
        """Returns transformed label information for given index."""
        return self.transforms(self.get_image_and_label(index))

    def get_image_and_label(self, index):
        """Get and return label information from the dataset."""
        label = deepcopy(self.labels[index])  # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
        label.pop('shape', None)  # shape is for rect, remove it
        label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
        label['ratio_pad'] = (label['resized_shape'][0] / label['ori_shape'][0],
                              label['resized_shape'][1] / label['ori_shape'][1])  # for evaluation
        if self.rect:
            label['rect_shape'] = self.batch_shapes[self.batch[index]]
        return self.update_labels_info(label)

    def __len__(self):
        """Returns the length of the labels list for the dataset."""
        return len(self.labels)

    def update_labels_info(self, label):
        """Customize your label format here."""
        return label

    def build_transforms(self, hyp=None):
        """
        Users can customize augmentations here, for example:

            if self.augment:
                # Training transforms
                return Compose([])
            else:
                # Val transforms
                return Compose([])
        """
        raise NotImplementedError

    def get_labels(self):
        """
        Users can customize their own label format here.

        Make sure your output is a list, with each element a dict like the one below
        (a minimal illustrative subclass appears at the end of this file):

            dict(
                im_file=im_file,
                shape=shape,  # format: (height, width)
                cls=cls,
                bboxes=bboxes,  # xywh
                segments=segments,  # xy
                keypoints=keypoints,  # xy
                normalized=True,  # or False
                bbox_format="xyxy",  # or xywh, ltwh
            )
        """
        raise NotImplementedError
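

# The subclass below is an illustrative sketch, not part of the Ultralytics API: it shows one minimal way to
# implement the two abstract hooks, get_labels() and build_transforms(), described in the docstrings above.
# The class name `MinimalDataset`, the empty per-image label arrays and the placeholder image directory are
# hypothetical; a real subclass parses annotation files in get_labels() and returns composed augmentation
# transforms from build_transforms().
class MinimalDataset(BaseDataset):

    def get_labels(self):
        """Return one label dict per image, following the schema documented in BaseDataset.get_labels()."""
        labels = []
        for f in self.im_files:
            h, w = cv2.imread(f).shape[:2]  # read each image once to record its (height, width)
            labels.append(
                dict(
                    im_file=f,
                    shape=(h, w),
                    cls=np.zeros((0, 1), dtype=np.float32),  # no objects in this sketch
                    bboxes=np.zeros((0, 4), dtype=np.float32),
                    segments=[],
                    keypoints=None,
                    normalized=True,
                    bbox_format='xywh'))
        return labels

    def build_transforms(self, hyp=None):
        """Return a no-op transform so __getitem__ yields the raw label dict with the loaded image."""
        return lambda labels: labels


if __name__ == '__main__':
    # Hypothetical usage: point img_path at any folder of images
    dataset = MinimalDataset(img_path='path/to/images', imgsz=640, augment=False)
    print(f'{len(dataset)} images')
    sample = dataset[0]  # dict with 'img', 'ori_shape', 'resized_shape', 'ratio_pad', ...
    print(sample['img'].shape, sample['ori_shape'], sample['resized_shape'])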