mlflow.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import os
  3. import re
  4. from pathlib import Path
  5. from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
  6. try:
  7. import mlflow
  8. assert not TESTS_RUNNING # do not log pytest
  9. assert hasattr(mlflow, '__version__') # verify package is not directory
  10. assert SETTINGS['mlflow'] is True # verify integration is enabled
  11. except (ImportError, AssertionError):
  12. mlflow = None
  13. def on_pretrain_routine_end(trainer):
  14. """Logs training parameters to MLflow."""
  15. global mlflow, run, experiment_name
  16. if os.environ.get('MLFLOW_TRACKING_URI') is None:
  17. mlflow = None
  18. if mlflow:
  19. mlflow_location = os.environ['MLFLOW_TRACKING_URI'] # "http://192.168.xxx.xxx:5000"
  20. mlflow.set_tracking_uri(mlflow_location)
  21. experiment_name = os.environ.get('MLFLOW_EXPERIMENT_NAME') or trainer.args.project or '/Shared/YOLOv8'
  22. run_name = os.environ.get('MLFLOW_RUN') or trainer.args.name
  23. experiment = mlflow.get_experiment_by_name(experiment_name)
  24. if experiment is None:
  25. mlflow.create_experiment(experiment_name)
  26. mlflow.set_experiment(experiment_name)
  27. prefix = colorstr('MLFlow: ')
  28. try:
  29. run, active_run = mlflow, mlflow.active_run()
  30. if not active_run:
  31. active_run = mlflow.start_run(experiment_id=experiment.experiment_id, run_name=run_name)
  32. LOGGER.info(f'{prefix}Using run_id({active_run.info.run_id}) at {mlflow_location}')
  33. run.log_params(vars(trainer.model.args))
  34. except Exception as err:
  35. LOGGER.error(f'{prefix}Failing init - {repr(err)}')
  36. LOGGER.warning(f'{prefix}Continuing without Mlflow')
  37. def on_fit_epoch_end(trainer):
  38. """Logs training metrics to Mlflow."""
  39. if mlflow:
  40. metrics_dict = {f"{re.sub('[()]', '', k)}": float(v) for k, v in trainer.metrics.items()}
  41. run.log_metrics(metrics=metrics_dict, step=trainer.epoch)
  42. def on_train_end(trainer):
  43. """Called at end of train loop to log model artifact info."""
  44. if mlflow:
  45. root_dir = Path(__file__).resolve().parents[3]
  46. run.log_artifact(trainer.last)
  47. run.log_artifact(trainer.best)
  48. run.pyfunc.log_model(artifact_path=experiment_name,
  49. code_path=[str(root_dir)],
  50. artifacts={'model_path': str(trainer.save_dir)},
  51. python_model=run.pyfunc.PythonModel())
  52. callbacks = {
  53. 'on_pretrain_routine_end': on_pretrain_routine_end,
  54. 'on_fit_epoch_end': on_fit_epoch_end,
  55. 'on_train_end': on_train_end} if mlflow else {}