session.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import signal
  3. import sys
  4. from pathlib import Path
  5. from time import sleep
  6. import requests
  7. from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, smart_request
  8. from ultralytics.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
  9. from ultralytics.utils.errors import HUBModelError
  # Agent label reported in the heartbeat payload: records the package version and whether we run on Colab or locally.
  AGENT_NAME = f"python-{__version__}-{'colab' if is_colab() else 'local'}"
  class HUBTrainingSession:
      """
      HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.

      Args:
          url (str): Model identifier used to initialize the HUB training session.

      Attributes:
          agent_id (str): Identifier for the instance communicating with the server.
          model_id (str): Identifier for the YOLO model being trained.
          model_url (str): URL for the model in Ultralytics HUB.
          api_url (str): API URL for the model in Ultralytics HUB.
          auth_header (dict): Authentication header for the Ultralytics HUB API requests.
          rate_limits (dict): Rate limits for different API calls (in seconds).
          timers (dict): Timers for rate limiting.
          metrics_queue (dict): Queue for the model's metrics.
          model (dict): Model data fetched from Ultralytics HUB.
          alive (bool): Indicates if the heartbeat loop is active.
          train_args (dict): Training arguments built from the model data. Only set when the model status is
              'new' or 'training'.
          model_file (str): Model cfg/weights (status 'new') or resume checkpoint (status 'training'). Only set
              for those two statuses.
      """

      def __init__(self, url):
          """
          Initialize the HUBTrainingSession with the provided model identifier.

          Args:
              url (str): Model identifier used to initialize the HUB training session.
                  It can be a HUB model URL, a 20-character model ID, or a '<42-character key>_<20-character ID>'
                  string.

          Raises:
              HUBModelError: If the provided model identifier does not match any accepted format.
              ValueError: If the model is already trained, or its model/dataset data is not available yet.
              ConnectionRefusedError: If the HUB server cannot be reached while fetching the model.

          Note:
              Authentication happens in `Auth(key)`, which is defined outside this file and may raise its own
              errors (the previous docstring mentioned ConnectionError for global API keys - TODO confirm).
          """
          from ultralytics.hub.auth import Auth

          # Parse input: strip the web prefix so only the identifier part remains
          if url.startswith(f'{HUB_WEB_ROOT}/models/'):
              url = url.split(f'{HUB_WEB_ROOT}/models/')[-1]
          parts = url.split('_')
          if [len(x) for x in parts] == [42, 20]:  # '<key>_<model_id>'
              key, model_id = parts
          elif len(url) == 20:  # bare model ID, no key supplied
              key, model_id = '', url
          else:
              raise HUBModelError(f"model='{url}' not found. Check format is correct, i.e. "
                                  f"model='{HUB_WEB_ROOT}/models/MODEL_ID' and try again.")

          # Authorize
          auth = Auth(key)
          self.agent_id = None  # identifies which instance is communicating with server
          self.model_id = model_id
          self.model_url = f'{HUB_WEB_ROOT}/models/{model_id}'
          self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
          self.auth_header = auth.get_auth_header()
          self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0}  # rate limits (seconds)
          self.timers = {}  # rate limit timers (seconds)
          self.metrics_queue = {}  # metrics queue
          self.model = self._get_model()
          self.alive = True
          self._start_heartbeat()  # start heartbeats
          self._register_signal_handlers()
          LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')

      def _register_signal_handlers(self):
          """Register signal handlers for SIGTERM and SIGINT signals to gracefully handle termination."""
          signal.signal(signal.SIGTERM, self._handle_signal)
          signal.signal(signal.SIGINT, self._handle_signal)

      def _handle_signal(self, signum, frame):
          """
          Handle kill signals and prevent heartbeats from being sent on Colab after termination.

          This method does not use frame, it is included as it is passed by signal.

          Args:
              signum (int): Number of the received signal; used as the process exit status.
              frame: Current stack frame supplied by the signal module (unused).
          """
          if self.alive:
              LOGGER.info(f'{PREFIX}Kill signal received! ❌')
              self._stop_heartbeat()
              sys.exit(signum)

      def _stop_heartbeat(self):
          """Terminate the heartbeat loop."""
          self.alive = False

      def upload_metrics(self):
          """Upload model metrics to Ultralytics HUB."""
          # Send a copy so the queue can keep changing while the request is in flight
          payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'}
          smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2)

      def _get_model(self):
          """
          Fetch and return model data from Ultralytics HUB.

          As a side effect this refreshes `self.model_id` and, for models with status 'new' or 'training',
          sets `self.train_args` and `self.model_file`.

          Returns:
              (dict): The 'data' payload of the HUB model response.

          Raises:
              ValueError: If the response carries no model data, the model is already trained, or the dataset
                  entry is missing/empty.
              ConnectionRefusedError: If the request fails with a requests ConnectionError.
          """
          api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'

          try:
              response = smart_request('get', api_url, headers=self.auth_header, thread=False, code=0)
              data = response.json().get('data', None)

              if data is None:
                  # Fail with a clear message instead of an AttributeError on `None.get` below
                  raise ValueError(f"No model data returned by Ultralytics HUB for model='{self.model_id}'. "
                                   'Check the model ID and try again.')

              if data.get('status', None) == 'trained':
                  raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))

              if not data.get('data', None):
                  raise ValueError('Dataset may still be processing. Please wait a minute and try again.')  # RF fix
              self.model_id = data['id']

              if data['status'] == 'new':  # new model to start training
                  self.train_args = {
                      # TODO: deprecate 'batch_size' key for 'batch' in 3Q23
                      'batch': data['batch' if ('batch' in data) else 'batch_size'],
                      'epochs': data['epochs'],
                      'imgsz': data['imgsz'],
                      'patience': data['patience'],
                      'device': data['device'],
                      'cache': data['cache'],
                      'data': data['data']}
                  self.model_file = data.get('cfg') or data.get('weights')  # cfg for pretrained=False
                  self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u
              elif data['status'] == 'training':  # existing model to resume training
                  self.train_args = {'data': data['data'], 'resume': True}
                  self.model_file = data['resume']

              return data
          except requests.exceptions.ConnectionError as e:
              raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e

      def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
          """
          Upload a model checkpoint to Ultralytics HUB.

          Args:
              epoch (int): The current training epoch.
              weights (str): Path to the model weights file.
              is_best (bool): Indicates if the current model is the best one so far.
              map (float): Mean average precision of the model.
              final (bool): Indicates if the model is the final model after training.
          """
          weights_path = Path(weights)
          if weights_path.is_file():
              file = weights_path.read_bytes()
          else:
              # The request below is still issued, with no file content attached
              LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')
              file = None

          url = f'{self.api_url}/upload'
          data = {'epoch': epoch}
          if final:
              # Final upload: sent as 'best.pt' with a long timeout and retries, not in a background thread
              data.update({'type': 'final', 'map': map})
              smart_request('post',
                            url,
                            data=data,
                            files={'best.pt': file},
                            headers=self.auth_header,
                            retry=10,
                            timeout=3600,
                            thread=False,
                            progress=True,
                            code=4)
          else:
              # Per-epoch checkpoint: sent as 'last.pt' with smart_request defaults
              data.update({'type': 'epoch', 'isBest': bool(is_best)})
              smart_request('post', url, data=data, files={'last.pt': file}, headers=self.auth_header, code=3)

      @threaded
      def _start_heartbeat(self):
          """
          Begin a threaded heartbeat loop to report the agent's status to Ultralytics HUB.

          The loop runs while `self.alive` is true and sleeps `rate_limits['heartbeat']` seconds between beats.
          A heartbeat whose response cannot be read is logged and skipped rather than ending the loop.
          """
          while self.alive:
              r = smart_request('post',
                                f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
                                json={
                                    'agent': AGENT_NAME,
                                    'agentId': self.agent_id},
                                headers=self.auth_header,
                                retry=0,
                                code=5,
                                thread=False)  # already in a thread
              try:
                  self.agent_id = r.json().get('data', {}).get('agentId', None)
              except (AttributeError, ValueError):
                  # No usable response object (AttributeError) or a non-JSON body (ValueError): keep the previous
                  # agent_id and try again on the next beat instead of letting the heartbeat thread die.
                  LOGGER.warning(f'{PREFIX}WARNING ⚠️ Heartbeat response could not be read, will retry.')
              sleep(self.rate_limits['heartbeat'])