validator.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Check a model's accuracy on a test or val split of a dataset.

Usage:
    $ yolo mode=val model=yolov8n.pt data=coco128.yaml imgsz=640

Usage - formats:
    $ yolo mode=val model=yolov8n.pt                 # PyTorch
                          yolov8n.torchscript        # TorchScript
                          yolov8n.onnx               # ONNX Runtime or OpenCV DNN with dnn=True
                          yolov8n_openvino_model     # OpenVINO
                          yolov8n.engine             # TensorRT
                          yolov8n.mlpackage          # CoreML (macOS-only)
                          yolov8n_saved_model        # TensorFlow SavedModel
                          yolov8n.pb                 # TensorFlow GraphDef
                          yolov8n.tflite             # TensorFlow Lite
                          yolov8n_edgetpu.tflite     # TensorFlow Edge TPU
                          yolov8n_paddle_model       # PaddlePaddle
"""
  19. import json
  20. import time
  21. from pathlib import Path
  22. import numpy as np
  23. import torch
  24. from tqdm import tqdm
  25. from ultralytics.cfg import get_cfg
  26. from ultralytics.data.utils import check_cls_dataset, check_det_dataset
  27. from ultralytics.nn.autobackend import AutoBackend
  28. from ultralytics.utils import LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
  29. from ultralytics.utils.checks import check_imgsz
  30. from ultralytics.utils.files import increment_path
  31. from ultralytics.utils.ops import Profile
  32. from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
  33. class BaseValidator:
  34. """
  35. BaseValidator
  36. A base class for creating validators.
  37. Attributes:
  38. args (SimpleNamespace): Configuration for the validator.
  39. dataloader (DataLoader): Dataloader to use for validation.
  40. pbar (tqdm): Progress bar to update during validation.
  41. model (nn.Module): Model to validate.
  42. data (dict): Data dictionary.
  43. device (torch.device): Device to use for validation.
  44. batch_i (int): Current batch index.
  45. training (bool): Whether the model is in training mode.
  46. names (dict): Class names.
  47. seen: Records the number of images seen so far during validation.
  48. stats: Placeholder for statistics during validation.
  49. confusion_matrix: Placeholder for a confusion matrix.
  50. nc: Number of classes.
  51. iouv: (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
  52. jdict (dict): Dictionary to store JSON validation results.
  53. speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
  54. batch processing times in milliseconds.
  55. save_dir (Path): Directory to save results.
  56. plots (dict): Dictionary to store plots for visualization.
  57. callbacks (dict): Dictionary to store various callback functions.
  58. """
  59. def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
  60. """
  61. Initializes a BaseValidator instance.
  62. Args:
  63. dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
  64. save_dir (Path): Directory to save results.
  65. pbar (tqdm.tqdm): Progress bar for displaying progress.
  66. args (SimpleNamespace): Configuration for the validator.
  67. _callbacks (dict): Dictionary to store various callback functions.
  68. """
  69. self.args = get_cfg(overrides=args)
  70. self.dataloader = dataloader
  71. self.pbar = pbar
  72. self.model = None
  73. self.data = None
  74. self.device = None
  75. self.batch_i = None
  76. self.training = True
  77. self.names = None
  78. self.seen = None
  79. self.stats = None
  80. self.confusion_matrix = None
  81. self.nc = None
  82. self.iouv = None
  83. self.jdict = None
  84. self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
  85. project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
  86. name = self.args.name or f'{self.args.mode}'
  87. self.save_dir = save_dir or increment_path(Path(project) / name,
  88. exist_ok=self.args.exist_ok if RANK in (-1, 0) else True)
  89. (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
  90. if self.args.conf is None:
  91. self.args.conf = 0.001 # default conf=0.001
  92. self.plots = {}
  93. self.callbacks = _callbacks or callbacks.get_default_callbacks()
  94. @smart_inference_mode()
  95. def __call__(self, trainer=None, model=None):
  96. """
  97. Supports validation of a pre-trained model if passed or a model being trained
  98. if trainer is passed (trainer gets priority).
  99. """
  100. self.training = trainer is not None
  101. augment = self.args.augment and (not self.training)
  102. if self.training:
  103. self.device = trainer.device
  104. self.data = trainer.data
  105. model = trainer.ema.ema or trainer.model
  106. self.args.half = self.device.type != 'cpu' # force FP16 val during training
  107. model = model.half() if self.args.half else model.float()
  108. self.model = model
  109. self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
  110. self.args.plots = trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
  111. model.eval()
  112. else:
  113. callbacks.add_integration_callbacks(self)
  114. self.run_callbacks('on_val_start')
  115. model = AutoBackend(model or self.args.model,
  116. device=select_device(self.args.device, self.args.batch),
  117. dnn=self.args.dnn,
  118. data=self.args.data,
  119. fp16=self.args.half)
  120. self.model = model
  121. self.device = model.device # update device
  122. self.args.half = model.fp16 # update half
  123. stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
  124. imgsz = check_imgsz(self.args.imgsz, stride=stride)
  125. if engine:
  126. self.args.batch = model.batch_size
  127. elif not pt and not jit:
  128. self.args.batch = 1 # export.py models default to batch-size 1
  129. LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
  130. if isinstance(self.args.data, str) and self.args.data.split('.')[-1] in ('yaml', 'yml'):
  131. self.data = check_det_dataset(self.args.data)
  132. elif self.args.task == 'classify':
  133. self.data = check_cls_dataset(self.args.data, split=self.args.split)
  134. else:
  135. raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
  136. if self.device.type == 'cpu':
  137. self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
  138. if not pt:
  139. self.args.rect = False
  140. self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)
  141. model.eval()
  142. model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)) # warmup
  143. dt = Profile(), Profile(), Profile(), Profile()
  144. n_batches = len(self.dataloader)
  145. desc = self.get_desc()
  146. # NOTE: keeping `not self.training` in tqdm will eliminate pbar after segmentation evaluation during training,
  147. # which may affect classification task since this arg is in yolov5/classify/val.py.
  148. # bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
  149. bar = tqdm(self.dataloader, desc, n_batches, bar_format=TQDM_BAR_FORMAT)
  150. self.init_metrics(de_parallel(model))
  151. self.jdict = [] # empty before each val
  152. for batch_i, batch in enumerate(bar):
  153. self.run_callbacks('on_val_batch_start')
  154. self.batch_i = batch_i
  155. # Preprocess
  156. with dt[0]:
  157. batch = self.preprocess(batch)
  158. # Inference
  159. with dt[1]:
  160. preds = model(batch['img'], augment=augment)
  161. # Loss
  162. with dt[2]:
  163. if self.training:
  164. self.loss += model.loss(batch, preds)[1]
  165. # Postprocess
  166. with dt[3]:
  167. preds = self.postprocess(preds)
  168. self.update_metrics(preds, batch)
  169. if self.args.plots and batch_i < 3:
  170. self.plot_val_samples(batch, batch_i)
  171. self.plot_predictions(batch, preds, batch_i)
  172. self.run_callbacks('on_val_batch_end')
  173. stats = self.get_stats()
  174. self.check_stats(stats)
  175. self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt)))
  176. self.finalize_metrics()
  177. self.print_results()
  178. self.run_callbacks('on_val_end')
  179. if self.training:
  180. model.float()
  181. results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
  182. return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
  183. else:
  184. LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
  185. tuple(self.speed.values()))
  186. if self.args.save_json and self.jdict:
  187. with open(str(self.save_dir / 'predictions.json'), 'w') as f:
  188. LOGGER.info(f'Saving {f.name}...')
  189. json.dump(self.jdict, f) # flatten and save
  190. stats = self.eval_json(stats) # update stats
  191. if self.args.plots or self.args.save_json:
  192. LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
  193. return stats
  194. def match_predictions(self, pred_classes, true_classes, iou):
  195. """
  196. Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.
  197. Args:
  198. pred_classes (torch.Tensor): Predicted class indices of shape(N,).
  199. true_classes (torch.Tensor): Target class indices of shape(M,).
  200. iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground of truth
  201. Returns:
  202. (torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
  203. """
  204. correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
  205. correct_class = true_classes[:, None] == pred_classes
  206. for i, iouv in enumerate(self.iouv):
  207. x = torch.nonzero(iou.ge(iouv) & correct_class) # IoU > threshold and classes match
  208. if x.shape[0]:
  209. # Concatenate [label, detect, iou]
  210. matches = torch.cat((x, iou[x[:, 0], x[:, 1]].unsqueeze(1)), 1).cpu().numpy()
  211. if x.shape[0] > 1:
  212. matches = matches[matches[:, 2].argsort()[::-1]]
  213. matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
  214. # matches = matches[matches[:, 2].argsort()[::-1]]
  215. matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
  216. correct[matches[:, 1].astype(int), i] = True
  217. return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
  218. def add_callback(self, event: str, callback):
  219. """Appends the given callback."""
  220. self.callbacks[event].append(callback)
  221. def run_callbacks(self, event: str):
  222. """Runs all callbacks associated with a specified event."""
  223. for callback in self.callbacks.get(event, []):
  224. callback(self)
  225. def get_dataloader(self, dataset_path, batch_size):
  226. """Get data loader from dataset path and batch size."""
  227. raise NotImplementedError('get_dataloader function not implemented for this validator')
  228. def build_dataset(self, img_path):
  229. """Build dataset"""
  230. raise NotImplementedError('build_dataset function not implemented in validator')
  231. def preprocess(self, batch):
  232. """Preprocesses an input batch."""
  233. return batch
  234. def postprocess(self, preds):
  235. """Describes and summarizes the purpose of 'postprocess()' but no details mentioned."""
  236. return preds
  237. def init_metrics(self, model):
  238. """Initialize performance metrics for the YOLO model."""
  239. pass
  240. def update_metrics(self, preds, batch):
  241. """Updates metrics based on predictions and batch."""
  242. pass
  243. def finalize_metrics(self, *args, **kwargs):
  244. """Finalizes and returns all metrics."""
  245. pass
  246. def get_stats(self):
  247. """Returns statistics about the model's performance."""
  248. return {}
  249. def check_stats(self, stats):
  250. """Checks statistics."""
  251. pass
  252. def print_results(self):
  253. """Prints the results of the model's predictions."""
  254. pass
  255. def get_desc(self):
  256. """Get description of the YOLO model."""
  257. pass
  258. @property
  259. def metric_keys(self):
  260. """Returns the metric keys used in YOLO training/validation."""
  261. return []
  262. def on_plot(self, name, data=None):
  263. """Registers plots (e.g. to be consumed in callbacks)"""
  264. path = Path(name)
  265. self.plots[path] = {'data': data, 'timestamp': time.time()}
  266. # TODO: may need to put these following functions into callback
  267. def plot_val_samples(self, batch, ni):
  268. """Plots validation samples during training."""
  269. pass
  270. def plot_predictions(self, batch, preds, ni):
  271. """Plots YOLO model predictions on batch images."""
  272. pass
  273. def pred_to_json(self, preds, batch):
  274. """Convert predictions to JSON format."""
  275. pass
  276. def eval_json(self, stats):
  277. """Evaluate and return JSON format of prediction statistics."""
  278. pass