base.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. """
  3. Base callbacks
  4. """
  5. from collections import defaultdict
  6. from copy import deepcopy
  7. # Trainer callbacks ----------------------------------------------------------------------------------------------------
  8. def on_pretrain_routine_start(trainer):
  9. """Called before the pretraining routine starts."""
  10. pass
  11. def on_pretrain_routine_end(trainer):
  12. """Called after the pretraining routine ends."""
  13. pass
  14. def on_train_start(trainer):
  15. """Called when the training starts."""
  16. pass
  17. def on_train_epoch_start(trainer):
  18. """Called at the start of each training epoch."""
  19. pass
  20. def on_train_batch_start(trainer):
  21. """Called at the start of each training batch."""
  22. pass
  23. def optimizer_step(trainer):
  24. """Called when the optimizer takes a step."""
  25. pass
  26. def on_before_zero_grad(trainer):
  27. """Called before the gradients are set to zero."""
  28. pass
  29. def on_train_batch_end(trainer):
  30. """Called at the end of each training batch."""
  31. pass
  32. def on_train_epoch_end(trainer):
  33. """Called at the end of each training epoch."""
  34. pass
  35. def on_fit_epoch_end(trainer):
  36. """Called at the end of each fit epoch (train + val)."""
  37. pass
  38. def on_model_save(trainer):
  39. """Called when the model is saved."""
  40. pass
  41. def on_train_end(trainer):
  42. """Called when the training ends."""
  43. pass
  44. def on_params_update(trainer):
  45. """Called when the model parameters are updated."""
  46. pass
  47. def teardown(trainer):
  48. """Called during the teardown of the training process."""
  49. pass
  50. # Validator callbacks --------------------------------------------------------------------------------------------------
  51. def on_val_start(validator):
  52. """Called when the validation starts."""
  53. pass
  54. def on_val_batch_start(validator):
  55. """Called at the start of each validation batch."""
  56. pass
  57. def on_val_batch_end(validator):
  58. """Called at the end of each validation batch."""
  59. pass
  60. def on_val_end(validator):
  61. """Called when the validation ends."""
  62. pass
  63. # Predictor callbacks --------------------------------------------------------------------------------------------------
  64. def on_predict_start(predictor):
  65. """Called when the prediction starts."""
  66. pass
  67. def on_predict_batch_start(predictor):
  68. """Called at the start of each prediction batch."""
  69. pass
  70. def on_predict_batch_end(predictor):
  71. """Called at the end of each prediction batch."""
  72. pass
  73. def on_predict_postprocess_end(predictor):
  74. """Called after the post-processing of the prediction ends."""
  75. pass
  76. def on_predict_end(predictor):
  77. """Called when the prediction ends."""
  78. pass
  79. # Exporter callbacks ---------------------------------------------------------------------------------------------------
  80. def on_export_start(exporter):
  81. """Called when the model export starts."""
  82. pass
  83. def on_export_end(exporter):
  84. """Called when the model export ends."""
  85. pass
  86. default_callbacks = {
  87. # Run in trainer
  88. 'on_pretrain_routine_start': [on_pretrain_routine_start],
  89. 'on_pretrain_routine_end': [on_pretrain_routine_end],
  90. 'on_train_start': [on_train_start],
  91. 'on_train_epoch_start': [on_train_epoch_start],
  92. 'on_train_batch_start': [on_train_batch_start],
  93. 'optimizer_step': [optimizer_step],
  94. 'on_before_zero_grad': [on_before_zero_grad],
  95. 'on_train_batch_end': [on_train_batch_end],
  96. 'on_train_epoch_end': [on_train_epoch_end],
  97. 'on_fit_epoch_end': [on_fit_epoch_end], # fit = train + val
  98. 'on_model_save': [on_model_save],
  99. 'on_train_end': [on_train_end],
  100. 'on_params_update': [on_params_update],
  101. 'teardown': [teardown],
  102. # Run in validator
  103. 'on_val_start': [on_val_start],
  104. 'on_val_batch_start': [on_val_batch_start],
  105. 'on_val_batch_end': [on_val_batch_end],
  106. 'on_val_end': [on_val_end],
  107. # Run in predictor
  108. 'on_predict_start': [on_predict_start],
  109. 'on_predict_batch_start': [on_predict_batch_start],
  110. 'on_predict_postprocess_end': [on_predict_postprocess_end],
  111. 'on_predict_batch_end': [on_predict_batch_end],
  112. 'on_predict_end': [on_predict_end],
  113. # Run in exporter
  114. 'on_export_start': [on_export_start],
  115. 'on_export_end': [on_export_end]}
  116. def get_default_callbacks():
  117. """
  118. Return a copy of the default_callbacks dictionary with lists as default values.
  119. Returns:
  120. (defaultdict): A defaultdict with keys from default_callbacks and empty lists as default values.
  121. """
  122. return defaultdict(list, deepcopy(default_callbacks))
  123. def add_integration_callbacks(instance):
  124. """
  125. Add integration callbacks from various sources to the instance's callbacks.
  126. Args:
  127. instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary
  128. of callback lists.
  129. """
  130. from .clearml import callbacks as clearml_cb
  131. from .comet import callbacks as comet_cb
  132. from .dvc import callbacks as dvc_cb
  133. from .hub import callbacks as hub_cb
  134. from .mlflow import callbacks as mlflow_cb
  135. from .neptune import callbacks as neptune_cb
  136. from .raytune import callbacks as tune_cb
  137. from .tensorboard import callbacks as tensorboard_cb
  138. from .wb import callbacks as wb_cb
  139. for x in clearml_cb, comet_cb, hub_cb, mlflow_cb, neptune_cb, tune_cb, tensorboard_cb, wb_cb, dvc_cb:
  140. for k, v in x.items():
  141. if v not in instance.callbacks[k]: # prevent duplicate callbacks addition
  142. instance.callbacks[k].append(v) # callback[name].append(func)