comet.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import os
from pathlib import Path

from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops
from ultralytics.utils.torch_utils import model_info_for_loggers

try:
    import comet_ml

    assert not TESTS_RUNNING  # do not log pytest
    assert hasattr(comet_ml, '__version__')  # verify package is not directory
    assert SETTINGS['comet'] is True  # verify integration is enabled
except (ImportError, AssertionError):
    comet_ml = None

# Ensures certain logging functions only run for supported tasks
COMET_SUPPORTED_TASKS = ['detect']

# Names of plots created by YOLOv8 that are logged to Comet
EVALUATION_PLOT_NAMES = 'F1_curve', 'P_curve', 'R_curve', 'PR_curve', 'confusion_matrix'
LABEL_PLOT_NAMES = 'labels', 'labels_correlogram'

_comet_image_prediction_count = 0


def _get_comet_mode():
    """Return the Comet mode ('online' or 'offline') from the environment, defaulting to 'online'."""
    return os.getenv('COMET_MODE', 'online')


def _get_comet_model_name():
    """Return the model name used when logging models to Comet, defaulting to 'YOLOv8'."""
    return os.getenv('COMET_MODEL_NAME', 'YOLOv8')


def _get_eval_batch_logging_interval():
    """Return the evaluation batch logging interval from the environment, defaulting to 1."""
    return int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1))


def _get_max_image_predictions_to_log():
    """Return the maximum number of image predictions to log, defaulting to 100."""
    return int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100))


def _scale_confidence_score(score):
    """Scale the given confidence score by the factor set in the environment, defaulting to 100.0."""
    scale = float(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100.0))
    return score * scale


def _should_log_confusion_matrix():
    """Determine whether to log the confusion matrix, defaulting to disabled."""
    return os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'false').lower() == 'true'


def _should_log_image_predictions():
    """Determine whether to log image predictions, defaulting to enabled."""
    return os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true'
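
# The helpers above read all Comet configuration from environment variables. A minimal
# sketch of the variables they understand (example values only; the defaults shown in
# each getter apply when a variable is unset, and COMET_PROJECT_NAME is read further
# below by _create_experiment):
#
#   export COMET_MODE=offline                     # 'online' (default) or 'offline'
#   export COMET_PROJECT_NAME=my-yolo-project     # Comet project to log the run under
#   export COMET_MODEL_NAME=YOLOv8                # name used when logging model checkpoints
#   export COMET_EVAL_BATCH_LOGGING_INTERVAL=1    # log every Nth evaluation batch
#   export COMET_MAX_IMAGE_PREDICTIONS=100        # cap on the number of logged prediction images
#   export COMET_MAX_CONFIDENCE_SCORE=100.0       # scale factor applied to confidence scores
#   export COMET_EVAL_LOG_CONFUSION_MATRIX=true   # 'true' or 'false'
#   export COMET_EVAL_LOG_IMAGE_PREDICTIONS=true  # 'true' or 'false'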


def _get_experiment_type(mode, project_name):
    """Return an experiment based on mode and project name."""
    if mode == 'offline':
        return comet_ml.OfflineExperiment(project_name=project_name)

    return comet_ml.Experiment(project_name=project_name)


def _create_experiment(args):
    """Ensures that the experiment object is only created in a single process during distributed training."""
    if RANK not in (-1, 0):
        return
    try:
        comet_mode = _get_comet_mode()
        _project_name = os.getenv('COMET_PROJECT_NAME', args.project)
        experiment = _get_experiment_type(comet_mode, _project_name)
        experiment.log_parameters(vars(args))
        experiment.log_others({
            'eval_batch_logging_interval': _get_eval_batch_logging_interval(),
            'log_confusion_matrix_on_eval': _should_log_confusion_matrix(),
            'log_image_predictions': _should_log_image_predictions(),
            'max_image_predictions': _get_max_image_predictions_to_log(), })
        experiment.log_other('Created from', 'yolov8')

    except Exception as e:
        LOGGER.warning(f'WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}')


def _fetch_trainer_metadata(trainer):
    """Returns metadata for YOLO training including epoch and asset saving status."""
    curr_epoch = trainer.epoch + 1

    train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size
    curr_step = curr_epoch * train_num_steps_per_epoch
    final_epoch = curr_epoch == trainer.epochs

    save = trainer.args.save
    save_period = trainer.args.save_period
    save_interval = curr_epoch % save_period == 0
    save_assets = save and save_period > 0 and save_interval and not final_epoch

    return dict(
        curr_epoch=curr_epoch,
        curr_step=curr_step,
        save_assets=save_assets,
        final_epoch=final_epoch,
    )


def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad):
    """YOLOv8 resizes images during training and the label values
    are normalized based on this resized shape. This function rescales the
    bounding box labels to the original image shape.
    """
    resized_image_height, resized_image_width = resized_image_shape

    # Convert normalized xywh format predictions to xyxy in resized scale format
    box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width)
    # Scale box predictions from resized image scale back to original image scale
    box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad)
    # Convert bounding box format from xyxy to xywh for Comet logging
    box = ops.xyxy2xywh(box)
    # Adjust xy center to correspond top-left corner
    box[:2] -= box[2:] / 2
    box = box.tolist()

    return box


def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None):
    """Format ground truth annotations for detection."""
    indices = batch['batch_idx'] == img_idx
    bboxes = batch['bboxes'][indices]
    if len(bboxes) == 0:
        LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding box labels')
        return None

    cls_labels = batch['cls'][indices].squeeze(1).tolist()
    if class_name_map:
        cls_labels = [str(class_name_map[label]) for label in cls_labels]

    original_image_shape = batch['ori_shape'][img_idx]
    resized_image_shape = batch['resized_shape'][img_idx]
    ratio_pad = batch['ratio_pad'][img_idx]

    data = []
    for box, label in zip(bboxes, cls_labels):
        box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)
        data.append({
            'boxes': [box],
            'label': f'gt_{label}',
            'score': _scale_confidence_score(1.0), })

    return {'name': 'ground_truth', 'data': data}


def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None):
    """Format YOLO predictions for object detection visualization."""
    stem = image_path.stem
    image_id = int(stem) if stem.isnumeric() else stem

    predictions = metadata.get(image_id)
    if not predictions:
        LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding box predictions')
        return None

    data = []
    for prediction in predictions:
        boxes = prediction['bbox']
        score = _scale_confidence_score(prediction['score'])
        cls_label = prediction['category_id']
        if class_label_map:
            cls_label = str(class_label_map[cls_label])

        data.append({'boxes': [boxes], 'label': cls_label, 'score': score})

    return {'name': 'prediction', 'data': data}


def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map):
    """Join the ground truth and prediction annotations if they exist."""
    ground_truth_annotations = _format_ground_truth_annotations_for_detection(img_idx, image_path, batch,
                                                                              class_label_map)
    prediction_annotations = _format_prediction_annotations_for_detection(image_path, prediction_metadata_map,
                                                                          class_label_map)

    annotations = [
        annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None]
    return [annotations] if annotations else None


def _create_prediction_metadata_map(model_predictions):
    """Create metadata map for model predictions by grouping them based on image ID."""
    pred_metadata_map = {}
    for prediction in model_predictions:
        pred_metadata_map.setdefault(prediction['image_id'], [])
        pred_metadata_map[prediction['image_id']].append(prediction)

    return pred_metadata_map
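
# Illustrative note: each entry in the validator's `jdict` is assumed to be a COCO-style
# prediction dict (the keys used in this module are 'image_id', 'bbox', 'score' and
# 'category_id'), so the map built above has roughly this shape:
#
#   {image_id: [{'image_id': 42, 'bbox': [x, y, w, h], 'score': 0.91, 'category_id': 0}, ...]}
#
# which _format_prediction_annotations_for_detection then looks up by image stem.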


def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
    """Log the confusion matrix to Comet experiment."""
    conf_mat = trainer.validator.confusion_matrix.matrix
    names = list(trainer.data['names'].values()) + ['background']
    experiment.log_confusion_matrix(
        matrix=conf_mat,
        labels=names,
        max_categories=len(names),
        epoch=curr_epoch,
        step=curr_step,
    )


def _log_images(experiment, image_paths, curr_step, annotations=None):
    """Logs images to the experiment with optional annotations."""
    if annotations:
        for image_path, annotation in zip(image_paths, annotations):
            experiment.log_image(image_path, name=image_path.stem, step=curr_step, annotations=annotation)

    else:
        for image_path in image_paths:
            experiment.log_image(image_path, name=image_path.stem, step=curr_step)


def _log_image_predictions(experiment, validator, curr_step):
    """Logs predicted boxes for a single image during training."""
    global _comet_image_prediction_count

    task = validator.args.task
    if task not in COMET_SUPPORTED_TASKS:
        return

    jdict = validator.jdict
    if not jdict:
        return

    predictions_metadata_map = _create_prediction_metadata_map(jdict)
    dataloader = validator.dataloader
    class_label_map = validator.names

    batch_logging_interval = _get_eval_batch_logging_interval()
    max_image_predictions = _get_max_image_predictions_to_log()

    for batch_idx, batch in enumerate(dataloader):
        if (batch_idx + 1) % batch_logging_interval != 0:
            continue

        image_paths = batch['im_file']
        for img_idx, image_path in enumerate(image_paths):
            if _comet_image_prediction_count >= max_image_predictions:
                return

            image_path = Path(image_path)
            annotations = _fetch_annotations(
                img_idx,
                image_path,
                batch,
                predictions_metadata_map,
                class_label_map,
            )
            _log_images(
                experiment,
                [image_path],
                curr_step,
                annotations=annotations,
            )
            _comet_image_prediction_count += 1


def _log_plots(experiment, trainer):
    """Logs evaluation plots and label plots for the experiment."""
    plot_filenames = [trainer.save_dir / f'{plots}.png' for plots in EVALUATION_PLOT_NAMES]
    _log_images(experiment, plot_filenames, None)

    label_plot_filenames = [trainer.save_dir / f'{labels}.jpg' for labels in LABEL_PLOT_NAMES]
    _log_images(experiment, label_plot_filenames, None)


def _log_model(experiment, trainer):
    """Log the best-trained model to Comet.ml."""
    model_name = _get_comet_model_name()
    experiment.log_model(
        model_name,
        file_or_folder=str(trainer.best),
        file_name='best.pt',
        overwrite=True,
    )


def on_pretrain_routine_start(trainer):
    """Creates or resumes a CometML experiment at the start of a YOLO pre-training routine."""
    experiment = comet_ml.get_global_experiment()
    is_alive = getattr(experiment, 'alive', False)
    if not experiment or not is_alive:
        _create_experiment(trainer.args)


def on_train_epoch_end(trainer):
    """Log metrics and save batch images at the end of training epochs."""
    experiment = comet_ml.get_global_experiment()
    if not experiment:
        return

    metadata = _fetch_trainer_metadata(trainer)
    curr_epoch = metadata['curr_epoch']
    curr_step = metadata['curr_step']

    experiment.log_metrics(
        trainer.label_loss_items(trainer.tloss, prefix='train'),
        step=curr_step,
        epoch=curr_epoch,
    )

    if curr_epoch == 1:
        _log_images(experiment, trainer.save_dir.glob('train_batch*.jpg'), curr_step)


def on_fit_epoch_end(trainer):
    """Logs model assets at the end of each epoch."""
    experiment = comet_ml.get_global_experiment()
    if not experiment:
        return

    metadata = _fetch_trainer_metadata(trainer)
    curr_epoch = metadata['curr_epoch']
    curr_step = metadata['curr_step']
    save_assets = metadata['save_assets']

    experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)
    experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)
    if curr_epoch == 1:
        experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)

    if not save_assets:
        return

    _log_model(experiment, trainer)
    if _should_log_confusion_matrix():
        _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
    if _should_log_image_predictions():
        _log_image_predictions(experiment, trainer.validator, curr_step)


def on_train_end(trainer):
    """Perform operations at the end of training."""
    experiment = comet_ml.get_global_experiment()
    if not experiment:
        return

    metadata = _fetch_trainer_metadata(trainer)
    curr_epoch = metadata['curr_epoch']
    curr_step = metadata['curr_step']
    plots = trainer.args.plots

    _log_model(experiment, trainer)
    if plots:
        _log_plots(experiment, trainer)

    _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
    _log_image_predictions(experiment, trainer.validator, curr_step)
    experiment.end()

    global _comet_image_prediction_count
    _comet_image_prediction_count = 0


callbacks = {
    'on_pretrain_routine_start': on_pretrain_routine_start,
    'on_train_epoch_end': on_train_epoch_end,
    'on_fit_epoch_end': on_fit_epoch_end,
    'on_train_end': on_train_end} if comet_ml else {}
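
# Minimal usage sketch (an assumption for illustration, not part of this module): with the
# `ultralytics` package installed, SETTINGS['comet'] enabled and a Comet API key configured
# (e.g. via the COMET_API_KEY environment variable), the trainer picks up the `callbacks`
# dict above automatically, so no manual wiring is required:
#
#   from ultralytics import YOLO
#
#   model = YOLO('yolov8n.pt')
#   model.train(data='coco128.yaml', epochs=3)  # metrics, plots and checkpoints stream to Comet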