# Ultralytics YOLO 🚀, AGPL-3.0 license

from pathlib import Path

import numpy as np
import torch

from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
from ultralytics.utils.plotting import output_to_target, plot_images

class PoseValidator(DetectionValidator):
    """
    A class extending the DetectionValidator class for validation based on a pose model.

    Example:
        ```python
        from ultralytics.models.yolo.pose import PoseValidator

        args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml')
        validator = PoseValidator(args=args)
        validator()
        ```
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
        self.sigma = None
        self.kpt_shape = None
        self.args.task = 'pose'
        self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
        if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
            LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                           'See https://github.com/ultralytics/ultralytics/issues/4031.')

    def preprocess(self, batch):
        """Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
        batch = super().preprocess(batch)
        batch['keypoints'] = batch['keypoints'].to(self.device).float()
        return batch
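
    # Note: batch['keypoints'] arrives with shape (num_instances, nkpt, ndim) and xy
    # coordinates still normalized to 0-1; update_metrics() below rescales them to
    # pixels (tkpts[..., 0] *= width, tkpts[..., 1] *= height) before computing OKS.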

    def get_desc(self):
        """Returns description of evaluation metrics in string format."""
        return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Pose(P',
                                         'R', 'mAP50', 'mAP50-95)')

    def postprocess(self, preds):
        """Apply non-maximum suppression and return detections with high confidence scores."""
        return ops.non_max_suppression(preds,
                                       self.args.conf,
                                       self.args.iou,
                                       labels=self.lb,
                                       multi_label=True,
                                       agnostic=self.args.single_cls,
                                       max_det=self.args.max_det,
                                       nc=self.nc)
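
    # Each NMS output row, as indexed throughout this class, is laid out as
    # (x1, y1, x2, y2, conf, cls, *kpts): pred[:, 4] is the confidence, pred[:, 5] the
    # class, and pred[:, 6:] the flattened keypoints (ndim values per keypoint).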

    def init_metrics(self, model):
        """Initiate pose estimation metrics for YOLO model."""
        super().init_metrics(model)
        self.kpt_shape = self.data['kpt_shape']
        is_pose = self.kpt_shape == [17, 3]
        nkpt = self.kpt_shape[0]
        self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
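
    # OKS_SIGMA holds the 17 per-keypoint sigmas of the COCO keypoint metric; any
    # other kpt_shape falls back to uniform weights of 1/nkpt per keypoint.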

    def update_metrics(self, preds, batch):
        """Update pose and box metrics with statistics from a batch of predictions."""
        for si, pred in enumerate(preds):
            idx = batch['batch_idx'] == si
            cls = batch['cls'][idx]
            bbox = batch['bboxes'][idx]
            kpts = batch['keypoints'][idx]
            nl, npr = cls.shape[0], pred.shape[0]  # number of labels, predictions
            nk = kpts.shape[1]  # number of keypoints
            shape = batch['ori_shape'][si]
            correct_kpts = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init
            correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init
            self.seen += 1

            if npr == 0:
                if nl:
                    self.stats.append((correct_bboxes, correct_kpts, *torch.zeros(
                        (2, 0), device=self.device), cls.squeeze(-1)))
                    if self.args.plots:
                        self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
                continue

            # Predictions
            if self.args.single_cls:
                pred[:, 5] = 0
            predn = pred.clone()
            ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
                            ratio_pad=batch['ratio_pad'][si])  # native-space pred
            pred_kpts = predn[:, 6:].view(npr, nk, -1)
            ops.scale_coords(batch['img'][si].shape[1:], pred_kpts, shape, ratio_pad=batch['ratio_pad'][si])

            # Evaluate
            if nl:
                height, width = batch['img'].shape[2:]
                tbox = ops.xywh2xyxy(bbox) * torch.tensor(
                    (width, height, width, height), device=self.device)  # target boxes
                ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
                                ratio_pad=batch['ratio_pad'][si])  # native-space labels
                tkpts = kpts.clone()
                tkpts[..., 0] *= width
                tkpts[..., 1] *= height
                tkpts = ops.scale_coords(batch['img'][si].shape[1:], tkpts, shape, ratio_pad=batch['ratio_pad'][si])
                labelsn = torch.cat((cls, tbox), 1)  # native-space labels
                correct_bboxes = self._process_batch(predn[:, :6], labelsn)
                correct_kpts = self._process_batch(predn[:, :6], labelsn, pred_kpts, tkpts)
                if self.args.plots:
                    self.confusion_matrix.process_batch(predn, labelsn)

            # Append correct_bboxes, correct_kpts, pconf, pcls, tcls
            self.stats.append((correct_bboxes, correct_kpts, pred[:, 4], pred[:, 5], cls.squeeze(-1)))

            # Save
            if self.args.save_json:
                self.pred_to_json(predn, batch['im_file'][si])
            # if self.args.save_txt:
            #     save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')

    def _process_batch(self, detections, labels, pred_kpts=None, gt_kpts=None):
        """
        Return correct prediction matrix.

        Args:
            detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
                Each detection is of the format: x1, y1, x2, y2, conf, class.
            labels (torch.Tensor): Tensor of shape [M, 5] representing labels.
                Each label is of the format: class, x1, y1, x2, y2.
            pred_kpts (torch.Tensor, optional): Tensor of shape [N, nkpt, ndim] representing predicted keypoints,
                e.g. 17 keypoints each with 3 values (x, y, visibility) for COCO pose.
            gt_kpts (torch.Tensor, optional): Tensor of shape [M, nkpt, ndim] representing ground-truth keypoints.

        Returns:
            torch.Tensor: Correct prediction matrix of shape [N, 10] for 10 IoU levels.
        """
        if pred_kpts is not None and gt_kpts is not None:
            # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
            area = ops.xyxy2xywh(labels[:, 1:])[:, 2:].prod(1) * 0.53
            iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)
        else:  # boxes
            iou = box_iou(labels[:, 1:], detections[:, :4])
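        # With M labels and N detections, `iou` is an M x N matrix; match_predictions()
        # thresholds it at the self.niou IoU levels (0.50:0.95 in steps of 0.05) to
        # produce the [N, 10] boolean matrix documented above.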
        return self.match_predictions(detections[:, 5], labels[:, 0], iou)

    def plot_val_samples(self, batch, ni):
        """Plots and saves validation set samples with ground-truth bounding boxes and keypoints."""
        plot_images(batch['img'],
                    batch['batch_idx'],
                    batch['cls'].squeeze(-1),
                    batch['bboxes'],
                    kpts=batch['keypoints'],
                    paths=batch['im_file'],
                    fname=self.save_dir / f'val_batch{ni}_labels.jpg',
                    names=self.names,
                    on_plot=self.on_plot)

    def plot_predictions(self, batch, preds, ni):
        """Plots predictions for YOLO model."""
        pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
        plot_images(batch['img'],
                    *output_to_target(preds, max_det=self.args.max_det),
                    kpts=pred_kpts,
                    paths=batch['im_file'],
                    fname=self.save_dir / f'val_batch{ni}_pred.jpg',
                    names=self.names,
                    on_plot=self.on_plot)  # pred

    def pred_to_json(self, predn, filename):
        """Converts YOLO predictions to COCO JSON format."""
        stem = Path(filename).stem
        image_id = int(stem) if stem.isnumeric() else stem
        box = ops.xyxy2xywh(predn[:, :4])  # xywh
        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
        for p, b in zip(predn.tolist(), box.tolist()):
            self.jdict.append({
                'image_id': image_id,
                'category_id': self.class_map[int(p[5])],
                'bbox': [round(x, 3) for x in b],
                'keypoints': p[6:],
                'score': round(p[4], 5)})
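
    # Shape of one appended record (values below are illustrative only):
    # {'image_id': 42, 'category_id': 1, 'bbox': [x, y, w, h],
    #  'keypoints': [x1, y1, v1, ..., xk, yk, vk], 'score': 0.98765}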

    def eval_json(self, stats):
        """Evaluates object detection model using COCO JSON format."""
        if self.args.save_json and self.is_coco and len(self.jdict):
            anno_json = self.data['path'] / 'annotations/person_keypoints_val2017.json'  # annotations
            pred_json = self.save_dir / 'predictions.json'  # predictions
            LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
                check_requirements('pycocotools>=2.0.6')
                from pycocotools.coco import COCO  # noqa
                from pycocotools.cocoeval import COCOeval  # noqa

                for x in anno_json, pred_json:
                    assert x.is_file(), f'{x} file not found'
                anno = COCO(str(anno_json))  # init annotations api
                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
                for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'keypoints')]):
                    if self.is_coco:
                        eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # im to eval
                    eval.evaluate()
                    eval.accumulate()
                    eval.summarize()
                    idx = i * 4 + 2
                    stats[self.metrics.keys[idx + 1]], stats[
                        self.metrics.keys[idx]] = eval.stats[:2]  # update mAP50-95 and mAP50
            except Exception as e:
                LOGGER.warning(f'pycocotools unable to run: {e}')
        return stats
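

# Minimal usage sketch, mirroring the class docstring example above ('yolov8n-pose.pt'
# and 'coco8-pose.yaml' are the stock Ultralytics model and dataset assets, fetched
# automatically if missing):
if __name__ == '__main__':
    args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml')
    validator = PoseValidator(args=args)
    validator()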