# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path

import cv2
import numpy as np
import torch
import torchvision
from tqdm import tqdm

from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, colorstr, is_dir_writeable

from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
from .base import BaseDataset
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label

# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
DATASET_CACHE_VERSION = '1.0.3'


class YOLODataset(BaseDataset):
    """
    Dataset class for loading object detection and/or segmentation labels in YOLO format.

    Args:
        data (dict, optional): A dataset YAML dictionary. Defaults to None.
        use_segments (bool, optional): If True, segmentation masks are used as labels. Defaults to False.
        use_keypoints (bool, optional): If True, keypoints are used as labels. Defaults to False.

    Returns:
        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
    """

    def __init__(self, *args, data=None, use_segments=False, use_keypoints=False, **kwargs):
        """Initialize the dataset; segments and keypoints are mutually exclusive."""
        self.use_segments = use_segments
        self.use_keypoints = use_keypoints
        self.data = data
        assert not (self.use_segments and self.use_keypoints), 'Cannot use both segments and keypoints.'
        super().__init__(*args, **kwargs)
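
    # Usage sketch: constructor arguments other than those defined here come from BaseDataset and are
    # assumptions (img_path and imgsz in particular), as is the dataset path:
    #   data = yaml_load('coco128.yaml')  # dataset dict with 'names' and, for pose, 'kpt_shape'
    #   dataset = YOLODataset(img_path='datasets/coco128/images/train2017', imgsz=640, data=data,
    #                         use_segments=False, use_keypoints=False)
    #   labels = dataset.get_labels()  # verified, cached label dicts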

    def cache_labels(self, path=Path('./labels.cache')):
        """
        Cache dataset labels, check images and read shapes.

        Args:
            path (Path): Path where to save the cache file. Defaults to Path('./labels.cache').

        Returns:
            (dict): Dataset cache containing verified labels, hash, results and warnings.
        """
        x = {'labels': []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
        total = len(self.im_files)
        nkpt, ndim = self.data.get('kpt_shape', (0, 0))
        if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
            raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
                             "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'")
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(func=verify_image_label,
                                iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
                                             repeat(self.use_keypoints), repeat(len(self.data['names'])), repeat(nkpt),
                                             repeat(ndim)))
            pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x['labels'].append(
                        dict(
                            im_file=im_file,
                            shape=shape,
                            cls=lb[:, 0:1],  # n, 1
                            bboxes=lb[:, 1:],  # n, 4
                            segments=segments,
                            keypoints=keypoint,
                            normalized=True,
                            bbox_format='xywh'))
                if msg:
                    msgs.append(msg)
                pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
            pbar.close()

        if msgs:
            LOGGER.info('\n'.join(msgs))
        if nf == 0:
            LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
        x['hash'] = get_hash(self.label_files + self.im_files)
        x['results'] = nf, nm, ne, nc, len(self.im_files)
        x['msgs'] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x)
        return x
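
    # Sketch of the cache dict produced by cache_labels() above (keys taken from the code;
    # the concrete values are illustrative assumptions):
    #   {
    #       'labels': [{'im_file': 'images/train/0001.jpg', 'shape': (480, 640),
    #                   'cls': ndarray (n, 1), 'bboxes': ndarray (n, 4),
    #                   'segments': [...], 'keypoints': None,
    #                   'normalized': True, 'bbox_format': 'xywh'}, ...],
    #       'hash': '...',                       # get_hash(label_files + im_files)
    #       'results': (nf, nm, ne, nc, total),  # found, missing, empty, corrupt, total
    #       'msgs': [...],                       # warnings from verify_image_label
    #       'version': DATASET_CACHE_VERSION     # added by save_dataset_cache_file
    #   }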

    def get_labels(self):
        """Return a list of label dicts for YOLO training."""
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
        try:
            cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
            assert cache['version'] == DATASET_CACHE_VERSION  # matches current version
            assert cache['hash'] == get_hash(self.label_files + self.im_files)  # identical hash
        except (FileNotFoundError, AssertionError, AttributeError):
            cache, exists = self.cache_labels(cache_path), False  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in (-1, 0):
            d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display results
            if cache['msgs']:
                LOGGER.info('\n'.join(cache['msgs']))  # display warnings
        if nf == 0:  # number of labels found
            raise FileNotFoundError(f'{self.prefix}No labels found in {cache_path}, cannot start training. {HELP_URL}')

        # Read cache
        [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items
        labels = cache['labels']
        assert len(labels), f'No valid labels found, please check your dataset. {HELP_URL}'
        self.im_files = [lb['im_file'] for lb in labels]  # update im_files

        # Check whether the dataset is all boxes or all segments
        lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in labels)
        len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
        if len_segments and len_boxes != len_segments:
            LOGGER.warning(
                f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
                f'len(boxes) = {len_boxes}. To resolve this, only boxes will be used and all segments will be '
                'removed. To avoid this, please supply either a detect or segment dataset, not a detect-segment '
                'mixed dataset.')
            for lb in labels:
                lb['segments'] = []
        if len_cls == 0:
            raise ValueError(f'All labels empty in {cache_path}, cannot start training without labels. {HELP_URL}')
        return labels

    # TODO: use hyp config to set all these augmentations
    def build_transforms(self, hyp=None):
        """Builds and appends transforms to the list."""
        if self.augment:
            hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
            hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
            transforms = v8_transforms(self, self.imgsz, hyp)
        else:
            transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
        transforms.append(
            Format(bbox_format='xywh',
                   normalize=True,
                   return_mask=self.use_segments,
                   return_keypoint=self.use_keypoints,
                   batch_idx=True,
                   mask_ratio=hyp.mask_ratio,
                   mask_overlap=hyp.overlap_mask))
        return transforms
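
    # Resulting pipelines (composition follows build_transforms() above; v8_transforms internals are assumptions):
    #   train (augment=True, rect=False): v8_transforms (mosaic, mixup and related augmentations) + Format
    #   val   (augment=False):            LetterBox(new_shape=(imgsz, imgsz), scaleup=False) + Format
    # Format appends a zero-filled batch_idx per image so collate_fn() below can merge variable-length targets.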

    def close_mosaic(self, hyp):
        """Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
        hyp.mosaic = 0.0  # set mosaic ratio=0.0
        hyp.copy_paste = 0.0  # keep the same behavior as previous v8 close-mosaic
        hyp.mixup = 0.0  # keep the same behavior as previous v8 close-mosaic
        self.transforms = self.build_transforms(hyp)

    def update_labels_info(self, label):
        """Customize your label format here."""
        # NOTE: cls is not grouped with bboxes here; classification and semantic segmentation need an independent
        # cls label. Support for both could be added by adding or removing dict keys here.
        bboxes = label.pop('bboxes')
        segments = label.pop('segments')
        keypoints = label.pop('keypoints', None)
        bbox_format = label.pop('bbox_format')
        normalized = label.pop('normalized')
        label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label
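
    # For example (shapes are illustrative assumptions), update_labels_info() turns
    #   {'im_file': ..., 'cls': (n, 1), 'bboxes': (n, 4), 'segments': [...], 'keypoints': None,
    #    'bbox_format': 'xywh', 'normalized': True}
    # into
    #   {'im_file': ..., 'cls': (n, 1), 'instances': Instances(bboxes, segments, keypoints, ...)}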

    @staticmethod
    def collate_fn(batch):
        """Collates data samples into batches."""
        new_batch = {}
        keys = batch[0].keys()
        values = list(zip(*[list(b.values()) for b in batch]))
        for i, k in enumerate(keys):
            value = values[i]
            if k == 'img':
                value = torch.stack(value, 0)
            if k in ['masks', 'keypoints', 'bboxes', 'cls']:
                value = torch.cat(value, 0)
            new_batch[k] = value
        new_batch['batch_idx'] = list(new_batch['batch_idx'])
        for i in range(len(new_batch['batch_idx'])):
            new_batch['batch_idx'][i] += i  # add target image index for build_targets()
        new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0)
        return new_batch
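
    # Worked example of the batch_idx bookkeeping above (target counts are illustrative):
    # with 3 samples holding 2, 1 and 3 targets, Format emits zeros([2]), zeros([1]), zeros([3]);
    # adding the sample index i to each tensor and concatenating yields
    #   batch_idx = tensor([0., 0., 1., 2., 2., 2.])
    # so every row of the concatenated 'bboxes'/'cls' tensors can be traced back to its image.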


# Classification dataloaders -------------------------------------------------------------------------------------------
class ClassificationDataset(torchvision.datasets.ImageFolder):
    """
    YOLO Classification Dataset.

    Args:
        root (str): Dataset path.

    Attributes:
        cache_ram (bool): True if images should be cached in RAM, False otherwise.
        cache_disk (bool): True if images should be cached on disk, False otherwise.
        samples (list): List of samples containing file, index, npy, and im.
        torch_transforms (callable): torchvision transforms applied to the dataset.
        album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
    """

    def __init__(self, root, args, augment=False, cache=False, prefix=''):
        """
        Initialize the classification dataset with root, image size, augmentations, and cache settings.

        Args:
            root (str): Dataset path.
            args (Namespace): Argument parser containing dataset-related settings.
            augment (bool, optional): True if dataset should be augmented, False otherwise. Defaults to False.
            cache (bool | str, optional): Cache setting, can be True, False, 'ram' or 'disk'. Defaults to False.
            prefix (str, optional): Prefix for log messages. Defaults to ''.
        """
        super().__init__(root=root)
        if augment and args.fraction < 1.0:  # reduce training fraction
            self.samples = self.samples[:round(len(self.samples) * args.fraction)]
        self.prefix = colorstr(f'{prefix}: ') if prefix else ''
        self.cache_ram = cache is True or cache == 'ram'
        self.cache_disk = cache == 'disk'
        self.samples = self.verify_images()  # filter out bad images
        self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im
        self.torch_transforms = classify_transforms(args.imgsz)
        self.album_transforms = classify_albumentations(
            augment=augment,
            size=args.imgsz,
            scale=(1.0 - args.scale, 1.0),  # (0.08, 1.0)
            hflip=args.fliplr,
            vflip=args.flipud,
            hsv_h=args.hsv_h,  # HSV-Hue augmentation (fraction)
            hsv_s=args.hsv_s,  # HSV-Saturation augmentation (fraction)
            hsv_v=args.hsv_v,  # HSV-Value augmentation (fraction)
            mean=(0.0, 0.0, 0.0),  # IMAGENET_MEAN
            std=(1.0, 1.0, 1.0),  # IMAGENET_STD
            auto_aug=False) if augment else None
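
    # Usage sketch: only the attributes read above (imgsz, fraction, scale, fliplr, flipud, hsv_h/s/v)
    # are required on args; the values and dataset path here are illustrative assumptions:
    #   from types import SimpleNamespace
    #   args = SimpleNamespace(imgsz=224, fraction=1.0, scale=0.5, fliplr=0.5, flipud=0.0,
    #                          hsv_h=0.015, hsv_s=0.7, hsv_v=0.4)
    #   dataset = ClassificationDataset(root='datasets/imagenette/train', args=args, augment=True)
    #   item = dataset[0]  # {'img': CxHxW tensor, 'cls': int class index}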

    def __getitem__(self, i):
        """Return the transformed image and class index for the sample at index i."""
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram and im is None:
            im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # save *.npy on first use
                np.save(fn.as_posix(), cv2.imread(f))
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        if self.album_transforms:
            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
        else:
            sample = self.torch_transforms(im)
        return {'img': sample, 'cls': j}
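
    # Caching behavior of __getitem__ (summarized from the branches above):
    #   cache='ram'  -> decoded BGR array kept in self.samples[i][3] after the first read
    #   cache='disk' -> image saved once as <file>.npy, then np.load on each access
    #   cache=False  -> cv2.imread on every access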

    def __len__(self) -> int:
        """Return the total number of samples in the dataset."""
        return len(self.samples)

    def verify_images(self):
        """Verify all images in dataset."""
        desc = f'{self.prefix}Scanning {self.root}...'
        path = Path(self.root).with_suffix('.cache')  # *.cache file path

        with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
            cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
            assert cache['version'] == DATASET_CACHE_VERSION  # matches current version
            assert cache['hash'] == get_hash([im[0] for im in self.samples])  # identical hash
            nf, nc, n, samples = cache.pop('results')  # found, corrupt, total, samples
            if LOCAL_RANK in (-1, 0):
                d = f'{desc} {nf} images, {nc} corrupt'
                tqdm(None, desc=d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)
                if cache['msgs']:
                    LOGGER.info('\n'.join(cache['msgs']))  # display warnings
            return samples

        # Run scan if *.cache retrieval failed
        nf, nc, msgs, samples, x = 0, 0, [], [], {}
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
            pbar = tqdm(results, desc=desc, total=len(self.samples), bar_format=TQDM_BAR_FORMAT)
            for sample, nf_f, nc_f, msg in pbar:
                if nf_f:
                    samples.append(sample)
                if msg:
                    msgs.append(msg)
                nf += nf_f
                nc += nc_f
                pbar.desc = f'{desc} {nf} images, {nc} corrupt'
            pbar.close()
        if msgs:
            LOGGER.info('\n'.join(msgs))
        x['hash'] = get_hash([im[0] for im in self.samples])
        x['results'] = nf, nc, len(samples), samples
        x['msgs'] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x)
        return samples


def load_dataset_cache_file(path):
    """Load an Ultralytics *.cache dictionary from path."""
    import gc

    gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
    cache = np.load(str(path), allow_pickle=True).item()  # load dict
    gc.enable()
    return cache


def save_dataset_cache_file(prefix, path, x):
    """Save an Ultralytics dataset *.cache dictionary x to path."""
    x['version'] = DATASET_CACHE_VERSION  # add cache version
    if is_dir_writeable(path.parent):
        if path.exists():
            path.unlink()  # remove *.cache file if exists
        np.save(str(path), x)  # save cache for next time
        path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
        LOGGER.info(f'{prefix}New cache created: {path}')
    else:
        LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
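
# Round-trip sketch for the two helpers above (the path and payload are illustrative assumptions;
# note that save_dataset_cache_file mutates x by adding the 'version' key):
#   cache_path = Path('datasets/coco128/labels/train2017.cache')
#   save_dataset_cache_file('train: ', cache_path,
#                           {'labels': [], 'hash': 'abc', 'results': (0, 0, 0, 0, 0), 'msgs': []})
#   cache = load_dataset_cache_file(cache_path)
#   assert cache['version'] == DATASET_CACHE_VERSION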


# TODO: support semantic segmentation
class SemanticDataset(BaseDataset):
    """Placeholder dataset for semantic segmentation; not yet implemented."""

    def __init__(self):
        """Initialize a SemanticDataset object."""
        super().__init__()