dvc.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import os
  3. import re
  4. from pathlib import Path
  5. import pkg_resources as pkg
  6. from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
  7. from ultralytics.utils.torch_utils import model_info_for_loggers
  8. try:
  9. from importlib.metadata import version
  10. import dvclive
  11. assert not TESTS_RUNNING # do not log pytest
  12. assert SETTINGS['dvc'] is True # verify integration is enabled
  13. ver = version('dvclive')
  14. if pkg.parse_version(ver) < pkg.parse_version('2.11.0'):
  15. LOGGER.debug(f'DVCLive is detected but version {ver} is incompatible (>=2.11 required).')
  16. dvclive = None # noqa: F811
  17. except (ImportError, AssertionError, TypeError):
  18. dvclive = None
  19. # DVCLive logger instance
  20. live = None
  21. _processed_plots = {}
  22. # `on_fit_epoch_end` is called on final validation (probably need to be fixed)
  23. # for now this is the way we distinguish final evaluation of the best model vs
  24. # last epoch validation
  25. _training_epoch = False
  26. def _log_images(path, prefix=''):
  27. if live:
  28. name = path.name
  29. # Group images by batch to enable sliders in UI
  30. if m := re.search(r'_batch(\d+)', name):
  31. ni = m[1]
  32. new_stem = re.sub(r'_batch(\d+)', '_batch', path.stem)
  33. name = (Path(new_stem) / ni).with_suffix(path.suffix)
  34. live.log_image(os.path.join(prefix, name), path)
  35. def _log_plots(plots, prefix=''):
  36. for name, params in plots.items():
  37. timestamp = params['timestamp']
  38. if _processed_plots.get(name) != timestamp:
  39. _log_images(name, prefix)
  40. _processed_plots[name] = timestamp
  41. def _log_confusion_matrix(validator):
  42. targets = []
  43. preds = []
  44. matrix = validator.confusion_matrix.matrix
  45. names = list(validator.names.values())
  46. if validator.confusion_matrix.task == 'detect':
  47. names += ['background']
  48. for ti, pred in enumerate(matrix.T.astype(int)):
  49. for pi, num in enumerate(pred):
  50. targets.extend([names[ti]] * num)
  51. preds.extend([names[pi]] * num)
  52. live.log_sklearn_plot('confusion_matrix', targets, preds, name='cf.json', normalized=True)
  53. def on_pretrain_routine_start(trainer):
  54. try:
  55. global live
  56. live = dvclive.Live(save_dvc_exp=True, cache_images=True)
  57. LOGGER.info(
  58. f'DVCLive is detected and auto logging is enabled (can be disabled in the {SETTINGS.file} with `dvc: false`).'
  59. )
  60. except Exception as e:
  61. LOGGER.warning(f'WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}')
  62. def on_pretrain_routine_end(trainer):
  63. _log_plots(trainer.plots, 'train')
  64. def on_train_start(trainer):
  65. if live:
  66. live.log_params(trainer.args)
  67. def on_train_epoch_start(trainer):
  68. global _training_epoch
  69. _training_epoch = True
  70. def on_fit_epoch_end(trainer):
  71. global _training_epoch
  72. if live and _training_epoch:
  73. all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
  74. for metric, value in all_metrics.items():
  75. live.log_metric(metric, value)
  76. if trainer.epoch == 0:
  77. for metric, value in model_info_for_loggers(trainer).items():
  78. live.log_metric(metric, value, plot=False)
  79. _log_plots(trainer.plots, 'train')
  80. _log_plots(trainer.validator.plots, 'val')
  81. live.next_step()
  82. _training_epoch = False
  83. def on_train_end(trainer):
  84. if live:
  85. # At the end log the best metrics. It runs validator on the best model internally.
  86. all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
  87. for metric, value in all_metrics.items():
  88. live.log_metric(metric, value, plot=False)
  89. _log_plots(trainer.plots, 'val')
  90. _log_plots(trainer.validator.plots, 'val')
  91. _log_confusion_matrix(trainer.validator)
  92. if trainer.best.exists():
  93. live.log_artifact(trainer.best, copy=True, type='model')
  94. live.end()
  95. callbacks = {
  96. 'on_pretrain_routine_start': on_pretrain_routine_start,
  97. 'on_pretrain_routine_end': on_pretrain_routine_end,
  98. 'on_train_start': on_train_start,
  99. 'on_train_epoch_start': on_train_epoch_start,
  100. 'on_fit_epoch_end': on_fit_epoch_end,
  101. 'on_train_end': on_train_end} if dvclive else {}