# Dataset utils and dataloaders

import glob
import logging
import math
import os
import random
import shutil
import time
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from threading import Thread

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.general import check_requirements, xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, segment2box, segments2boxes, \
    resample_segments, clean_str
from utils.torch_utils import torch_distributed_zero_first

# Parameters
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']  # acceptable image suffixes
vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes
logger = logging.getLogger(__name__)

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break


def get_hash(files):
    # Returns a single hash value of a list of files
    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))


def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        if rotation == 6:  # rotation 270
            s = (s[1], s[0])
        elif rotation == 8:  # rotation 90
            s = (s[1], s[0])
    except Exception:  # image has no (or unreadable) EXIF data
        pass

    return s


def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
                      rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
    # Make sure only the first process in DDP scans the dataset first, so the other processes can use its cache
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=opt.single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)

    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
    # Use torch.utils.data.DataLoader() if dataset properties will update during training, else InfiniteDataLoader()
    dataloader = loader(dataset,
                        batch_size=batch_size,
                        num_workers=nw,
                        sampler=sampler,
                        pin_memory=True,
                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
    return dataloader, dataset
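
# Example usage (illustrative; `opt` stands for the argparse.Namespace that train.py passes in, which
# must provide a `single_cls` attribute, and `hyp` for the augmentation-hyperparameter dict loaded
# from a hyp*.yaml file):
#   dataloader, dataset = create_dataloader('../coco128/images/train2017', 640, 16, 32, opt,
#                                           hyp=hyp, augment=True, prefix='train: ')
#   for imgs, targets, paths, shapes in dataloader:
#       pass  # imgs: uint8 tensor (batch, 3, h, w); targets: (n, 6) rows of (image_idx, cls, x, y, w, h)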


class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
    """ Dataloader that reuses workers

    Uses same syntax as vanilla DataLoader
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)
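
# Design note: wrapping the batch sampler in _RepeatSampler keeps DataLoader worker processes alive
# across epochs instead of respawning them each epoch; __len__/__iter__ above still yield one epoch's
# worth of batches per pass. A minimal sketch of drop-in use (hypothetical `dataset` variable):
#   loader = InfiniteDataLoader(dataset, batch_size=16, num_workers=4,
#                               collate_fn=LoadImagesAndLabels.collate_fn)
#   for imgs, targets, paths, shapes in loader:
#       pass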


class LoadImages:  # for inference
    def __init__(self, path, img_size=640, stride=32):
        p = str(Path(path).absolute())  # os-agnostic absolute path
        if '*' in p:
            files = sorted(glob.glob(p, recursive=True))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception(f'ERROR: {p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in img_formats]
        videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        if any(videos):
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: ', end='')

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            print(f'image {self.count}/{self.nf} {path}: ', end='')

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = np.ascontiguousarray(img)

        return path, img, img0, self.cap

    def new_video(self, path):
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nf  # number of files
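
# Example inference loop (illustrative; 'data/images' is a hypothetical source directory):
#   dataset = LoadImages('data/images', img_size=640)
#   for path, img, im0s, vid_cap in dataset:
#       img = torch.from_numpy(img).float() / 255.0  # CHW RGB uint8 -> normalized float tensor
#       # run the model on img[None] and draw results on the original BGR frame im0s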


class LoadWebcam:  # for inference
    def __init__(self, pipe='0', img_size=640, stride=32):
        self.img_size = img_size
        self.stride = stride

        if pipe.isnumeric():
            pipe = int(pipe)  # local camera index
        # pipe = 'rtsp://192.168.1.64/1'  # IP camera
        # pipe = 'rtsp://username:password@192.168.1.64/1'  # IP camera with login
        # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg'  # IP golf camera

        self.pipe = pipe
        self.cap = cv2.VideoCapture(pipe)  # video capture object
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if cv2.waitKey(1) == ord('q'):  # q to quit
            self.cap.release()
            cv2.destroyAllWindows()
            raise StopIteration

        # Read frame
        if self.pipe == 0:  # local camera
            ret_val, img0 = self.cap.read()
            img0 = cv2.flip(img0, 1)  # flip left-right
        else:  # IP camera
            n = 0
            while True:
                n += 1
                self.cap.grab()
                if n % 30 == 0:  # skip frames
                    ret_val, img0 = self.cap.retrieve()
                    if ret_val:
                        break

        # Print
        assert ret_val, f'Camera Error {self.pipe}'
        img_path = 'webcam.jpg'
        print(f'webcam {self.count}: ', end='')

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = np.ascontiguousarray(img)

        return img_path, img, img0, None

    def __len__(self):
        return 0


class LoadStreams:  # multiple IP or RTSP cameras
    def __init__(self, sources='streams.txt', img_size=640, stride=32):
        self.mode = 'stream'
        self.img_size = img_size
        self.stride = stride

        if os.path.isfile(sources):
            with open(sources, 'r') as f:
                sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
        else:
            sources = [sources]

        n = len(sources)
        self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        for i, s in enumerate(sources):  # index, source
            # Start thread to read frames from video stream
            print(f'{i + 1}/{n}: {s}... ', end='')
            if 'youtube.com/' in s or 'youtu.be/' in s:  # if source is YouTube video
                check_requirements(('pafy', 'youtube_dl'))
                import pafy
                s = pafy.new(s).getbest(preftype="mp4").url  # YouTube URL
            s = int(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
            cap = cv2.VideoCapture(s)
            assert cap.isOpened(), f'Failed to open {s}'
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            self.fps[i] = max(cap.get(cv2.CAP_PROP_FPS) % 100, 0) or 30.0  # 30 FPS fallback
            self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf')  # infinite stream fallback

            _, self.imgs[i] = cap.read()  # guarantee first frame
            self.threads[i] = Thread(target=self.update, args=(i, cap), daemon=True)
            print(f" success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
            self.threads[i].start()
        print('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0)  # shapes
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

    def update(self, i, cap):
        # Read stream `i` frames in daemon thread
        n, f = 0, self.frames[i]
        while cap.isOpened() and n < f:
            n += 1
            # _, self.imgs[index] = cap.read()
            cap.grab()
            if n % 4 == 0:  # read every 4th frame
                success, im = cap.retrieve()
                self.imgs[i] = im if success else self.imgs[i] * 0
            time.sleep(1 / self.fps[i])  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        # Letterbox
        img0 = self.imgs.copy()
        img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]

        # Stack
        img = np.stack(img, 0)

        # Convert
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, BHWC to BCHW
        img = np.ascontiguousarray(img)

        return self.sources, img, img0, None

    def __len__(self):
        return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years


def img2label_paths(img_paths):
    # Define label paths as a function of image paths
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
    return ['txt'.join(x.replace(sa, sb, 1).rsplit(x.split('.')[-1], 1)) for x in img_paths]
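
# Example (illustrative, POSIX paths):
#   img2label_paths(['../coco128/images/train2017/000000000009.jpg'])
#   # -> ['../coco128/labels/train2017/000000000009.txt']
# i.e. the first /images/ directory becomes /labels/ and the image suffix becomes .txt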


class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        self.path = path

        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('**/*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p, 'r') as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
                    raise Exception(f'{prefix}{p} does not exist')
            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats])  # pathlib
            assert self.img_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')

        # Check cache
        self.label_files = img2label_paths(self.img_files)  # labels
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')  # cached labels
        if cache_path.is_file():
            cache, exists = torch.load(cache_path), True  # load
            if cache['hash'] != get_hash(self.label_files + self.img_files) or 'version' not in cache:  # changed
                cache, exists = self.cache_labels(cache_path, prefix), False  # re-cache
        else:
            cache, exists = self.cache_labels(cache_path, prefix), False  # cache

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
        if exists:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
            tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'

        # Read cache
        cache.pop('hash')  # remove hash
        cache.pop('version')  # remove version
        labels, shapes, self.segments = zip(*cache.values())
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)
        self.img_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update
        if single_cls:
            for x in self.labels:
                x[:, 0] = 0

        n = len(shapes)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs = [None] * n
        if cache_images:
            gb = 0  # Gigabytes of cached images
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))  # 8 threads
            pbar = tqdm(enumerate(results), total=n)
            for i, x in pbar:
                self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
                gb += self.imgs[i].nbytes
                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
            pbar.close()

    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, corrupted
        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
        for i, (im_file, lb_file) in enumerate(pbar):
            try:
                # verify images
                im = Image.open(im_file)
                im.verify()  # PIL verify
                shape = exif_size(im)  # image size
                segments = []  # instance segments
                assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
                assert im.format.lower() in img_formats, f'invalid image format {im.format}'

                # verify labels
                if os.path.isfile(lb_file):
                    nf += 1  # label found
                    with open(lb_file, 'r') as f:
                        l = [x.split() for x in f.read().strip().splitlines()]
                        if any([len(x) > 8 for x in l]):  # is segment
                            classes = np.array([x[0] for x in l], dtype=np.float32)
                            segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # (cls, xy1...)
                            l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                        l = np.array(l, dtype=np.float32)
                    if len(l):
                        assert l.shape[1] == 5, 'labels require 5 columns each'
                        assert (l >= 0).all(), 'negative labels'
                        assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
                        assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
                    else:
                        ne += 1  # label empty
                        l = np.zeros((0, 5), dtype=np.float32)
                else:
                    nm += 1  # label missing
                    l = np.zeros((0, 5), dtype=np.float32)
                x[im_file] = [l, shape, segments]
            except Exception as e:
                nc += 1
                logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

            pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
                        f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
        pbar.close()

        if nf == 0:
            logging.info(f'{prefix}WARNING: No labels found in {path}. See {help_url}')

        x['hash'] = get_hash(self.label_files + self.img_files)
        x['results'] = nf, nm, ne, nc, i + 1
        x['version'] = 0.1  # cache version
        try:
            torch.save(x, path)  # save for next time
            logging.info(f'{prefix}New cache created: {path}')
        except Exception as e:
            logging.info(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}')  # path not writeable
        return x

    def __len__(self):
        return len(self.img_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     # self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        if mosaic:
            # Load mosaic
            img, labels = load_mosaic(self, index)
            shapes = None

            # MixUp https://arxiv.org/pdf/1710.09412.pdf
            if random.random() < hyp['mixup']:
                img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
                r = np.random.beta(8.0, 8.0)  # mixup ratio, alpha=beta=8.0
                img = (img * r + img2 * (1 - r)).astype(np.uint8)
                labels = np.concatenate((labels, labels2), 0)

        else:
            # Load image
            img, (h0, w0), (h, w) = load_image(self, index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

        if self.augment:
            # Augment imagespace
            if not mosaic:
                img, labels = random_perspective(img, labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

            # Augment colorspace
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Apply cutouts
            # if random.random() < 0.9:
            #     labels = cutout(img, labels)

        nL = len(labels)  # number of labels
        if nL:
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])  # convert xyxy to xywh
            labels[:, [2, 4]] /= img.shape[0]  # normalized height 0-1
            labels[:, [1, 3]] /= img.shape[1]  # normalized width 0-1

        if self.augment:
            # flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nL:
                    labels[:, 2] = 1 - labels[:, 2]

            # flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

        labels_out = torch.zeros((nL, 6))
        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

    @staticmethod
    def collate_fn(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes

    @staticmethod
    def collate_fn4(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        ho = torch.tensor([[0., 0, 0, 1, 0, 0]])
        wo = torch.tensor([[0., 0, 1, 0, 0, 0]])
        s = torch.tensor([[1, 1, .5, .5, .5, .5]])  # scale
        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
            i *= 4
            if random.random() < 0.5:
                im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
                    0].type(img[i].type())
                l = label[i]
            else:
                im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
                l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
            img4.append(im)
            label4.append(l)

        for i, l in enumerate(label4):
            l[:, 0] = i  # add target image index for build_targets()

        return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4
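
# Note: collate_fn4 backs the `quad` option of create_dataloader() above. For every group of 4 samples
# it emits either one image upsampled 2x or a 2x2 tile of all four, so output tensors have 2x the side
# length at a quarter of the batch count. The ho/wo/s tensors shift and rescale the normalized
# (image_idx, cls, x, y, w, h) labels of the tiled images into the combined canvas.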


# Ancillary functions --------------------------------------------------------------------------------------------------
def load_image(self, index):
    # loads 1 image from dataset, returns img, original hw, resized hw
    img = self.imgs[index]
    if img is None:  # not cached
        path = self.img_files[index]
        img = cv2.imread(path)  # BGR
        assert img is not None, 'Image Not Found ' + path
        h0, w0 = img.shape[:2]  # orig hw
        r = self.img_size / max(h0, w0)  # ratio
        if r != 1:  # if sizes are not equal
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)),
                             interpolation=cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR)
        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
    else:
        return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized


def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
    # HSV color-space augmentation, applied in place on a BGR uint8 image
    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    dtype = img.dtype  # uint8

    x = np.arange(0, 256, dtype=np.int16)
    lut_hue = ((x * r[0]) % 180).astype(dtype)
    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed
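
# Example (illustrative): jitter a BGR frame in place, using the gains that YOLOv5's scratch
# hyperparameter file (data/hyp.scratch.yaml) sets for hsv_h/hsv_s/hsv_v at the time of writing:
#   img = cv2.imread('example.jpg')  # hypothetical image path
#   augment_hsv(img, hgain=0.015, sgain=0.7, vgain=0.4)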


def hist_equalize(img, clahe=True, bgr=False):
    # Equalize histogram on BGR image 'img' with img.shape(n,m,3) and range 0-255
    yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
    if clahe:
        c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        yuv[:, :, 0] = c.apply(yuv[:, :, 0])
    else:
        yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0])  # equalize Y channel histogram
    return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB)  # convert YUV image to RGB


def load_mosaic(self, index):
    # loads images in a 4-mosaic

    labels4, segments4 = [], []
    s = self.img_size
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
    indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
        labels4.append(labels)
        segments4.extend(segments)

    # Concat/clip labels
    labels4 = np.concatenate(labels4, 0)
    for x in (labels4[:, 1:], *segments4):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img4, labels4 = replicate(img4, labels4)  # replicate

    # Augment
    img4, labels4 = random_perspective(img4, labels4, segments4,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4


def load_mosaic9(self, index):
    # loads images in a 9-mosaic

    labels9, segments9 = [], []
    s = self.img_size
    indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img9
        if i == 0:  # center
            img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 9 tiles
            h0, w0 = h, w
            c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
        elif i == 1:  # top
            c = s, s - h, s + w, s
        elif i == 2:  # top right
            c = s + wp, s - h, s + wp + w, s
        elif i == 3:  # right
            c = s + w0, s, s + w0 + w, s + h
        elif i == 4:  # bottom right
            c = s + w0, s + hp, s + w0 + w, s + hp + h
        elif i == 5:  # bottom
            c = s + w0 - w, s + h0, s + w0, s + h0 + h
        elif i == 6:  # bottom left
            c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
        elif i == 7:  # left
            c = s - w, s + h0 - h, s, s + h0
        elif i == 8:  # top left
            c = s - w, s + h0 - hp - h, s, s + h0 - hp

        padx, pady = c[:2]
        x1, y1, x2, y2 = [max(x, 0) for x in c]  # allocate coords

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
        labels9.append(labels)
        segments9.extend(segments)

        # Image
        img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
        hp, wp = h, w  # height, width previous

    # Offset
    yc, xc = [int(random.uniform(0, s)) for _ in self.mosaic_border]  # mosaic center x, y
    img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

    # Concat/clip labels
    labels9 = np.concatenate(labels9, 0)
    labels9[:, [1, 3]] -= xc
    labels9[:, [2, 4]] -= yc
    c = np.array([xc, yc])  # centers
    segments9 = [x - c for x in segments9]

    for x in (labels9[:, 1:], *segments9):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img9, labels9 = replicate(img9, labels9)  # replicate

    # Augment
    img9, labels9 = random_perspective(img9, labels9, segments9,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img9, labels9


def replicate(img, labels):
    # Replicate labels
    h, w = img.shape[:2]
    boxes = labels[:, 1:].astype(int)
    x1, y1, x2, y2 = boxes.T
    s = ((x2 - x1) + (y2 - y1)) / 2  # side length (pixels)
    for i in s.argsort()[:round(s.size * 0.5)]:  # smallest indices
        x1b, y1b, x2b, y2b = boxes[i]
        bh, bw = y2b - y1b, x2b - x1b
        yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset x, y
        x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
        img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)

    return img, labels


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    # Resize and pad image while meeting stride-multiple constraints
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
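
# Example (illustrative): letterbox a 1080x1920 BGR frame down to a stride-32 rectangle
#   img = np.zeros((1080, 1920, 3), dtype=np.uint8)  # stand-in frame
#   out, ratio, (dw, dh) = letterbox(img, 640)
#   # out.shape == (384, 640, 3): the long side becomes 640 and the short side is padded up to the
#   # next multiple of 32 (auto=True); (dw, dh) is the per-side padding needed to undo the transform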


def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0,
                       border=(0, 0)):
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined transformation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(img[:, :, ::-1])  # base
    # ax[1].imshow(img2[:, :, ::-1])  # warped

    # Transform label coordinates
    n = len(targets)
    if n:
        use_segments = any(x.any() for x in segments)
        new = np.zeros((n, 4))
        if use_segments:  # warp segments
            segments = resample_segments(segments)  # upsample
            for i, segment in enumerate(segments):
                xy = np.ones((len(segment), 3))
                xy[:, :2] = segment
                xy = xy @ M.T  # transform
                xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]  # perspective rescale or affine

                # clip
                new[i] = segment2box(xy, width, height)

        else:  # warp boxes
            xy = np.ones((n * 4, 3))
            xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
            xy = xy @ M.T  # transform
            xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine

            # create new boxes
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

            # clip
            new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
            new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)

        # filter candidates
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
        targets = targets[i]
        targets[:, 1:5] = new[i]

    return img, targets


def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
    # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)  # candidates
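
# Example (illustrative): a box that shrinks from 100x100 to 50x2 pixels after augmentation fails both
# the 2-pixel height test and the 20:1 aspect-ratio test, so it is filtered out:
#   box1 = np.array([[0], [0], [100], [100]], dtype=float)  # (4, n) xyxy before augment
#   box2 = np.array([[0], [0], [50], [2]], dtype=float)     # (4, n) xyxy after augment
#   box_candidates(box1, box2)  # -> array([False])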


def cutout(image, labels):
    # Applies image cutout augmentation https://arxiv.org/abs/1708.04552

    h, w = image.shape[:2]

    def bbox_ioa(box1, box2):
        # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
        box2 = box2.transpose()

        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]

        # Intersection area
        inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                     (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

        # box2 area
        box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16

        # Intersection over box2 area
        return inter_area / box2_area

    # create random masks
    scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
    for s in scales:
        mask_h = random.randint(1, int(h * s))
        mask_w = random.randint(1, int(w * s))

        # box
        xmin = max(0, random.randint(0, w) - mask_w // 2)
        ymin = max(0, random.randint(0, h) - mask_h // 2)
        xmax = min(w, xmin + mask_w)
        ymax = min(h, ymin + mask_h)

        # apply random color mask
        image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]

        # return unobscured labels
        if len(labels) and s > 0.03:
            box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            labels = labels[ioa < 0.60]  # remove >60% obscured labels

    return labels
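
# Example (illustrative): cutout is wired in via the commented-out hook in __getitem__ above, applied
# after the geometric augmentations while labels are still in pixel xyxy format:
#   if random.random() < 0.9:
#       labels = cutout(img, labels)  # img is masked in place; heavily occluded labels are dropped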


def create_folder(path='./new'):
    # Create folder
    if os.path.exists(path):
        shutil.rmtree(path)  # delete output folder
    os.makedirs(path)  # make new output folder


def flatten_recursive(path='../coco128'):
    # Flatten a recursive directory by bringing all files to top level
    new_path = Path(path + '_flat')
    create_folder(new_path)
    for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
        shutil.copyfile(file, new_path / Path(file).name)


def extract_boxes(path='../coco128/'):  # from utils.datasets import *; extract_boxes('../coco128')
    # Convert detection dataset into classification dataset, with one directory per class

    path = Path(path)  # images dir
    shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None  # remove existing
    files = list(path.rglob('*.*'))
    n = len(files)  # number of files
    for im_file in tqdm(files, total=n):
        if im_file.suffix[1:] in img_formats:
            # image
            im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
            h, w = im.shape[:2]

            # labels
            lb_file = Path(img2label_paths([str(im_file)])[0])
            if Path(lb_file).exists():
                with open(lb_file, 'r') as f:
                    lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32)  # labels

                for j, x in enumerate(lb):
                    c = int(x[0])  # class
                    f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg'  # new filename
                    if not f.parent.is_dir():
                        f.parent.mkdir(parents=True)

                    b = x[1:] * [w, h, w, h]  # box
                    # b[2:] = b[2:].max()  # rectangle to square
                    b[2:] = b[2:] * 1.2 + 3  # pad
                    b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

                    b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                    b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                    assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'


def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0), annotated_only=False):
    """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
    Usage: from utils.datasets import *; autosplit('../coco128')
    Arguments
        path:           Path to images directory
        weights:        Train, val, test weights (list)
        annotated_only: Only use images with an annotated txt file
    """
    path = Path(path)  # images dir
    files = sum([list(path.rglob(f"*.{img_ext}")) for img_ext in img_formats], [])  # image files only
    n = len(files)  # number of files
    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split

    txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
    [(path / x).unlink() for x in txt if (path / x).exists()]  # remove existing

    print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
    for i, img in tqdm(zip(indices, files), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            with open(path / txt[i], 'a') as f:
                f.write(str(img) + '\n')  # add image to txt file