raytune.py 608 B

123456789101112131415161718192021222324
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. from ultralytics.utils import SETTINGS
  3. try:
  4. import ray
  5. from ray import tune
  6. from ray.air import session
  7. assert SETTINGS['raytune'] is True # verify integration is enabled
  8. except (ImportError, AssertionError):
  9. tune = None
  10. def on_fit_epoch_end(trainer):
  11. """Sends training metrics to Ray Tune at end of each epoch."""
  12. if ray.tune.is_session_enabled():
  13. metrics = trainer.metrics
  14. metrics['epoch'] = trainer.epoch
  15. session.report(metrics)
  16. callbacks = {
  17. 'on_fit_epoch_end': on_fit_epoch_end, } if tune else {}