  1. """Utilities and tools for tracking runs with Weights & Biases."""
  2. import json
  3. import sys
  4. from pathlib import Path
  5. import torch
  6. import yaml
  7. from tqdm import tqdm
  8. sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path
  9. from utils.datasets import LoadImagesAndLabels
  10. from utils.datasets import img2label_paths
  11. from utils.general import colorstr, xywh2xyxy, check_dataset, check_file
  12. try:
  13. import wandb
  14. from wandb import init, finish
  15. except ImportError:
  16. wandb = None
  17. WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
  18. def remove_prefix(from_string, prefix=WANDB_ARTIFACT_PREFIX):
  19. return from_string[len(prefix):]
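
# Illustrative example of the prefix convention (the entity/project/run id are made-up values):
#   remove_prefix('wandb-artifact://my-team/my-project/3abc12de') -> 'my-team/my-project/3abc12de'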


def check_wandb_config_file(data_config_file):
    wandb_config = '_wandb.'.join(data_config_file.rsplit('.', 1))  # updated data.yaml path
    if Path(wandb_config).is_file():
        return wandb_config
    return data_config_file
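
# For instance, 'data/coco128.yaml' maps to 'data/coco128_wandb.yaml'; the W&B-ready config is
# preferred only when it already exists on disk, otherwise the original path is returned unchanged.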


def get_run_info(run_path):
    run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
    run_id = run_path.stem
    project = run_path.parent.stem
    entity = run_path.parent.parent.stem
    model_artifact_name = 'run_' + run_id + '_model'
    return entity, project, run_id, model_artifact_name
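
# Illustrative parse (made-up names):
#   get_run_info('wandb-artifact://my-team/my-project/3abc12de')
#   -> ('my-team', 'my-project', '3abc12de', 'run_3abc12de_model')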


def check_wandb_resume(opt):
    if opt.global_rank not in [-1, 0]:  # non-master DDP ranks resolve artifact paths locally first
        process_wandb_config_ddp_mode(opt)
    if isinstance(opt.resume, str):
        if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
            if opt.global_rank not in [-1, 0]:  # For resuming DDP runs
                entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
                api = wandb.Api()
                artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest')
                modeldir = artifact.download()
                opt.weights = str(Path(modeldir) / "last.pt")
            return True
    return None


def process_wandb_config_ddp_mode(opt):
    with open(check_file(opt.data)) as f:
        data_dict = yaml.safe_load(f)  # data dict
    train_dir, val_dir = None, None
    if isinstance(data_dict['train'], str) and data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX):
        api = wandb.Api()
        train_artifact = api.artifact(remove_prefix(data_dict['train']) + ':' + opt.artifact_alias)
        train_dir = train_artifact.download()
        train_path = Path(train_dir) / 'data/images/'
        data_dict['train'] = str(train_path)
    if isinstance(data_dict['val'], str) and data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX):
        api = wandb.Api()
        val_artifact = api.artifact(remove_prefix(data_dict['val']) + ':' + opt.artifact_alias)
        val_dir = val_artifact.download()
        val_path = Path(val_dir) / 'data/images/'
        data_dict['val'] = str(val_path)
    if train_dir or val_dir:
        # write the resolved local paths next to a downloaded artifact; fall back to train_dir
        # so a train-only config does not crash on Path(None)
        ddp_data_path = str(Path(val_dir or train_dir) / 'wandb_local_data.yaml')
        with open(ddp_data_path, 'w') as f:
            yaml.safe_dump(data_dict, f)
        opt.data = ddp_data_path


class WandbLogger():
    """Log training runs, datasets, models, and predictions to Weights & Biases.

    This logger sends information to W&B at wandb.ai. By default, this information
    includes hyperparameters, system configuration and metrics, model metrics,
    and basic data metrics and analyses.

    By providing additional command line arguments to train.py, datasets,
    models and predictions can also be logged.

    For more on how this logger is used, see the Weights & Biases documentation:
    https://docs.wandb.com/guides/integrations/yolov5
    """

    def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
        # Pre-training routine --
        self.job_type = job_type
        self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
        # It's more elegant to stick to one wandb.init call, but useful config data would be
        # overwritten by the WandbLogger's own wandb.init call
        if isinstance(opt.resume, str):  # checks resume from artifact
            if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
                entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
                model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name
                assert wandb, 'install wandb to resume wandb runs'
                # Resume wandb-artifact:// runs here; workaround for not overwriting wandb.config
                self.wandb_run = wandb.init(id=run_id, project=project, entity=entity, resume='allow')
                opt.resume = model_artifact_name
        elif self.wandb:
            self.wandb_run = wandb.init(config=opt,
                                        resume="allow",
                                        project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
                                        entity=opt.entity,
                                        name=name,
                                        job_type=job_type,
                                        id=run_id) if not wandb.run else wandb.run
        if self.wandb_run:
            if self.job_type == 'Training':
                if not opt.resume:
                    wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
                    # Info useful for resuming from artifacts
                    self.wandb_run.config.opt = vars(opt)
                    self.wandb_run.config.data_dict = wandb_data_dict
                self.data_dict = self.setup_training(opt, data_dict)
            if self.job_type == 'Dataset Creation':
                self.data_dict = self.check_and_upload_dataset(opt)
        else:
            prefix = colorstr('wandb: ')
            print(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")

    def check_and_upload_dataset(self, opt):
        assert wandb, 'Install wandb to upload dataset'
        check_dataset(self.data_dict)
        config_path = self.log_dataset_artifact(check_file(opt.data),
                                                opt.single_cls,
                                                'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
        print("Created dataset config file ", config_path)
        with open(config_path) as f:
            wandb_data_dict = yaml.safe_load(f)
        return wandb_data_dict

    def setup_training(self, opt, data_dict):
        self.log_dict, self.current_epoch, self.log_imgs = {}, 0, 16  # Logging constants
        self.bbox_interval = opt.bbox_interval
        if isinstance(opt.resume, str):
            modeldir, _ = self.download_model_artifact(opt)
            if modeldir:
                self.weights = Path(modeldir) / "last.pt"
                config = self.wandb_run.config
                opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str(
                    self.weights), config.save_period, config.total_batch_size, config.bbox_interval, config.epochs, \
                    config.opt['hyp']
            data_dict = dict(self.wandb_run.config.data_dict)  # eliminates the need for config file to resume
        if 'val_artifact' not in self.__dict__:  # If --upload_dataset is set, use the existing artifact, don't download
            self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
                                                                                           opt.artifact_alias)
            self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'),
                                                                                       opt.artifact_alias)
            self.result_artifact, self.result_table, self.val_table, self.weights = None, None, None, None
            if self.train_artifact_path is not None:
                train_path = Path(self.train_artifact_path) / 'data/images/'
                data_dict['train'] = str(train_path)
            if self.val_artifact_path is not None:
                val_path = Path(self.val_artifact_path) / 'data/images/'
                data_dict['val'] = str(val_path)
                self.val_table = self.val_artifact.get("val")
                self.map_val_table_path()
        if self.val_artifact is not None:
            self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
            self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
        if opt.bbox_interval == -1:
            self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
        return data_dict

    def download_dataset_artifact(self, path, alias):
        if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX):
            artifact_path = Path(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
            dataset_artifact = wandb.use_artifact(artifact_path.as_posix())
            assert dataset_artifact is not None, "Error: W&B dataset artifact doesn't exist"
            datadir = dataset_artifact.download()
            return datadir, dataset_artifact
        return None, None

    def download_model_artifact(self, opt):
        if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
            model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest")
            assert model_artifact is not None, "Error: W&B model artifact doesn't exist"
            modeldir = model_artifact.download()
            epochs_trained = model_artifact.metadata.get('epochs_trained')
            total_epochs = model_artifact.metadata.get('total_epochs')
            is_finished = total_epochs is None  # a finished run's artifact carries no 'total_epochs' metadata
            assert not is_finished, 'training is finished, can only resume incomplete runs.'
            return modeldir, model_artifact
        return None, None

    def log_model(self, path, opt, epoch, fitness_score, best_model=False):
        model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
            'original_url': str(path),
            'epochs_trained': epoch + 1,
            'save period': opt.save_period,
            'project': opt.project,
            'total_epochs': opt.epochs,
            'fitness_score': fitness_score
        })
        model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
        wandb.log_artifact(model_artifact,
                           aliases=['latest', 'last', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
        print("Saving model artifact on epoch ", epoch + 1)

    def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
        with open(data_file) as f:
            data = yaml.safe_load(f)  # data dict
        nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
        names = {k: v for k, v in enumerate(names)}  # to index dictionary
        self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
            data['train'], rect=True, batch_size=1), names, name='train') if data.get('train') else None
        self.val_artifact = self.create_dataset_table(LoadImagesAndLabels(
            data['val'], rect=True, batch_size=1), names, name='val') if data.get('val') else None
        if data.get('train'):
            data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train')
        if data.get('val'):
            data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val')
        path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1))  # updated data.yaml path
        data.pop('download', None)
        with open(path, 'w') as f:
            yaml.safe_dump(data, f)

        if self.job_type == 'Training':  # builds correct artifact pipeline graph
            self.wandb_run.use_artifact(self.val_artifact)
            self.wandb_run.use_artifact(self.train_artifact)
            self.val_artifact.wait()
            self.val_table = self.val_artifact.get('val')
            self.map_val_table_path()
        else:
            self.wandb_run.log_artifact(self.train_artifact)
            self.wandb_run.log_artifact(self.val_artifact)
        return path
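
    # The rewritten config then points both splits at W&B instead of local disk; e.g. with the
    # default project name it would contain (illustrative, other keys left as-is):
    #
    #   train: wandb-artifact://YOLOv5/train
    #   val: wandb-artifact://YOLOv5/val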

    def map_val_table_path(self):
        # table columns are ["id", "train_image", "Classes", "name"], so this maps filename -> row id
        self.val_table_map = {}
        print("Mapping dataset")
        for i, data in enumerate(tqdm(self.val_table.data)):
            self.val_table_map[data[3]] = data[0]

    def create_dataset_table(self, dataset, class_to_id, name='dataset'):
        # TODO: Explore multiprocessing to split this loop and run it in parallel; this is essential for speeding up logging
        artifact = wandb.Artifact(name=name, type="dataset")
        img_files = tqdm([dataset.path]) if isinstance(dataset.path, str) and Path(dataset.path).is_dir() else None
        img_files = tqdm(dataset.img_files) if not img_files else img_files
        for img_file in img_files:
            if Path(img_file).is_dir():
                artifact.add_dir(img_file, name='data/images')
                labels_path = 'labels'.join(dataset.path.rsplit('images', 1))
                artifact.add_dir(labels_path, name='data/labels')
            else:
                artifact.add_file(img_file, name='data/images/' + Path(img_file).name)
                label_file = Path(img2label_paths([img_file])[0])
                if label_file.exists():
                    artifact.add_file(str(label_file), name='data/labels/' + label_file.name)
        table = wandb.Table(columns=["id", "train_image", "Classes", "name"])
        class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
        for si, (img, labels, paths, shapes) in enumerate(tqdm(dataset)):
            box_data, img_classes = [], {}
            for cls, *xywh in labels[:, 1:].tolist():
                cls = int(cls)
                box_data.append({"position": {"middle": [xywh[0], xywh[1]], "width": xywh[2], "height": xywh[3]},
                                 "class_id": cls,
                                 "box_caption": "%s" % (class_to_id[cls])})
                img_classes[cls] = class_to_id[cls]
            boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}}  # inference-space
            table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes),
                           Path(paths).name)
        artifact.add(table, name)
        return artifact

    def log_training_progress(self, predn, path, names):
        if self.val_table and self.result_table:
            class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
            box_data = []
            total_conf = 0
            for *xyxy, conf, cls in predn.tolist():
                if conf >= 0.25:
                    box_data.append(
                        {"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
                         "class_id": int(cls),
                         "box_caption": "%s %.3f" % (names[int(cls)], conf),
                         "scores": {"class_score": conf},
                         "domain": "pixel"})
                    total_conf = total_conf + conf
            boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space
            id = self.val_table_map[Path(path).name]
            self.result_table.add_data(self.current_epoch,
                                       id,
                                       wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set),
                                       total_conf / max(1, len(box_data)))

    def log(self, log_dict):
        if self.wandb_run:
            for key, value in log_dict.items():
                self.log_dict[key] = value

    def end_epoch(self, best_result=False):
        if self.wandb_run:
            wandb.log(self.log_dict)
            self.log_dict = {}
            if self.result_artifact:
                train_results = wandb.JoinedTable(self.val_table, self.result_table, "id")
                self.result_artifact.add(train_results, 'result')
                wandb.log_artifact(self.result_artifact, aliases=['latest', 'last', 'epoch ' + str(self.current_epoch),
                                                                  ('best' if best_result else '')])
                self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
                self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")

    def finish_run(self):
        if self.wandb_run:
            if self.log_dict:
                wandb.log(self.log_dict)
            wandb.run.finish()