# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Train a model on a dataset.

Usage:
    $ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
"""
import math
import os
import subprocess
import time
import warnings
from copy import deepcopy
from datetime import datetime, timedelta
from pathlib import Path

import numpy as np
import torch
from torch import distributed as dist
from torch import nn, optim
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm

from ultralytics.cfg import get_cfg
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks, clean_url,
                               colorstr, emojis, yaml_save)
from ultralytics.utils.autobatch import check_train_batch_size
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.utils.files import get_latest_run, increment_path
from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
                                           strip_optimizer)


class BaseTrainer:
    """
    BaseTrainer.

    A base class for creating trainers.

    Attributes:
        args (SimpleNamespace): Configuration for the trainer.
        check_resume (method): Method to check if training should be resumed from a saved checkpoint.
        validator (BaseValidator): Validator instance.
        model (nn.Module): Model instance.
        callbacks (defaultdict): Dictionary of callbacks.
        save_dir (Path): Directory to save results.
        wdir (Path): Directory to save weights.
        last (Path): Path to the last checkpoint.
        best (Path): Path to the best checkpoint.
        save_period (int): Save checkpoint every x epochs (disabled if < 1).
        batch_size (int): Batch size for training.
        epochs (int): Number of epochs to train for.
        start_epoch (int): Starting epoch for training.
        device (torch.device): Device to use for training.
        amp (bool): Flag to enable AMP (Automatic Mixed Precision).
        scaler (amp.GradScaler): Gradient scaler for AMP.
        data (str): Path to data.
        trainset (torch.utils.data.Dataset): Training dataset.
        testset (torch.utils.data.Dataset): Testing dataset.
        ema (nn.Module): EMA (Exponential Moving Average) of the model.
        lf (callable): Learning rate lambda passed to the scheduler.
        scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
        best_fitness (float): The best fitness value achieved.
        fitness (float): Current fitness value.
        loss (float): Current loss value.
        tloss (float): Total loss value.
        loss_names (list): List of loss names.
        csv (Path): Path to results CSV file.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the BaseTrainer class.

        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
        """
        self.args = get_cfg(cfg, overrides)
        self.check_resume(overrides)
        self.device = select_device(self.args.device, self.args.batch)
        self.validator = None
        self.model = None
        self.metrics = None
        self.plots = {}
        init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

        # Dirs
        project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
        name = self.args.name or f'{self.args.mode}'
        if hasattr(self.args, 'save_dir'):
            self.save_dir = Path(self.args.save_dir)
        else:
            self.save_dir = Path(
                increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in (-1, 0) else True))
        self.wdir = self.save_dir / 'weights'  # weights dir
        if RANK in (-1, 0):
            self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
            self.args.save_dir = str(self.save_dir)
            yaml_save(self.save_dir / 'args.yaml', vars(self.args))  # save run args
        self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths
        self.save_period = self.args.save_period

        self.batch_size = self.args.batch
        self.epochs = self.args.epochs
        self.start_epoch = 0
        if RANK == -1:
            print_args(vars(self.args))

        # Device
        if self.device.type == 'cpu':
            self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

        # Model and Dataset
        self.model = self.args.model
        try:
            if self.args.task == 'classify':
                self.data = check_cls_dataset(self.args.data)
            elif self.args.data.split('.')[-1] in ('yaml', 'yml') or self.args.task in ('detect', 'segment'):
                self.data = check_det_dataset(self.args.data)
                if 'yaml_file' in self.data:
                    self.args.data = self.data['yaml_file']  # for validating 'yolo train data=url.zip' usage
        except Exception as e:
            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e

        self.trainset, self.testset = self.get_dataset(self.data)
        self.ema = None

        # Optimization utils init
        self.lf = None
        self.scheduler = None

        # Epoch level metrics
        self.best_fitness = None
        self.fitness = None
        self.loss = None
        self.tloss = None
        self.loss_names = ['Loss']
        self.csv = self.save_dir / 'results.csv'
        self.plot_idx = [0, 1, 2]

        # Callbacks
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        if RANK in (-1, 0):
            callbacks.add_integration_callbacks(self)

    def add_callback(self, event: str, callback):
        """Appends the given callback."""
        self.callbacks[event].append(callback)

    def set_callback(self, event: str, callback):
        """Overrides the existing callbacks with the given callback."""
        self.callbacks[event] = [callback]

    def run_callbacks(self, event: str):
        """Run all existing callbacks associated with a particular event."""
        for callback in self.callbacks.get(event, []):
            callback(self)
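
    # Example callback registration (illustrative; `trainer` stands in for any BaseTrainer subclass instance):
    #     def log_epoch(trainer):
    #         print(f'epoch {trainer.epoch} finished')
    #     trainer.add_callback('on_train_epoch_end', log_epoch)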

    def train(self):
        """Run training; device='' or device=None on Multi-GPU systems defaults to device=0."""
        if isinstance(self.args.device, int) or self.args.device:  # i.e. device=0 or device=[0,1,2,3]
            world_size = torch.cuda.device_count()
        elif torch.cuda.is_available():  # i.e. device=None or device=''
            world_size = 1  # default to device 0
        else:  # i.e. device='cpu' or 'mps'
            world_size = 0

        # Run subprocess if DDP training, else train normally
        if world_size > 1 and 'LOCAL_RANK' not in os.environ:
            # Argument checks
            if self.args.rect:
                LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting rect=False")
                self.args.rect = False

            # Command
            cmd, file = generate_ddp_command(world_size, self)
            try:
                LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
                subprocess.run(cmd, check=True)
            except Exception as e:
                raise e
            finally:
                ddp_cleanup(self, str(file))
        else:
            self._do_train(world_size)

    def _setup_ddp(self, world_size):
        """Initializes and sets the DistributedDataParallel parameters for training."""
        torch.cuda.set_device(RANK)
        self.device = torch.device('cuda', RANK)
        # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
        os.environ['NCCL_BLOCKING_WAIT'] = '1'  # set to enforce timeout
        dist.init_process_group(
            'nccl' if dist.is_nccl_available() else 'gloo',
            timeout=timedelta(seconds=10800),  # 3 hours
            rank=RANK,
            world_size=world_size)

    def _setup_train(self, world_size):
        """Builds dataloaders and optimizer on correct rank process."""
        # Model
        self.run_callbacks('on_pretrain_routine_start')
        ckpt = self.setup_model()
        self.model = self.model.to(self.device)
        self.set_model_attributes()

        # Freeze layers
        freeze_list = self.args.freeze if isinstance(
            self.args.freeze, list) else range(self.args.freeze) if isinstance(self.args.freeze, int) else []
        always_freeze_names = ['.dfl']  # always freeze these layers
        freeze_layer_names = [f'model.{x}.' for x in freeze_list] + always_freeze_names
        for k, v in self.model.named_parameters():
            # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
            if any(x in k for x in freeze_layer_names):
                LOGGER.info(f"Freezing layer '{k}'")
                v.requires_grad = False
            elif not v.requires_grad:
                LOGGER.info(f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
                            'See ultralytics.engine.trainer for customization of frozen layers.')
                v.requires_grad = True

        # Check AMP
        self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
        if self.amp and RANK in (-1, 0):  # Single-GPU and DDP
            callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
            self.amp = torch.tensor(check_amp(self.model), device=self.device)
            callbacks.default_callbacks = callbacks_backup  # restore callbacks
        if RANK > -1 and world_size > 1:  # DDP
            dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
        self.amp = bool(self.amp)  # as boolean
        self.scaler = amp.GradScaler(enabled=self.amp)
        if world_size > 1:
            self.model = DDP(self.model, device_ids=[RANK])

        # Check imgsz
        gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32)  # grid size (max stride)
        self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)

        # Batch size
        if self.batch_size == -1:
            if RANK == -1:  # single-GPU only, estimate best batch size
                self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
            else:
                raise SyntaxError('batch=-1 to use AutoBatch is only available in Single-GPU training. '
                                  'Please pass a valid batch size value for Multi-GPU DDP training, i.e. batch=16')

        # Dataloaders
        batch_size = self.batch_size // max(world_size, 1)
        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
        if RANK in (-1, 0):
            self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
            self.validator = self.get_validator()
            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()?
            self.ema = ModelEMA(self.model)
            if self.args.plots:
                self.plot_training_labels()

        # Optimizer
        self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
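        # e.g. with the default nbs=64 and batch=16 this gives accumulate=4, so gradients are accumulated over
        # 4 batches before each optimizer step (an effective batch size of ~64)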
        weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
        iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
        self.optimizer = self.build_optimizer(model=self.model,
                                              name=self.args.optimizer,
                                              lr=self.args.lr0,
                                              momentum=self.args.momentum,
                                              decay=weight_decay,
                                              iterations=iterations)
        # Scheduler
        if self.args.cos_lr:
            self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
        else:
            self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
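        # For reference: with e.g. lrf=0.01 the linear lambda above scales lr0 by 1.0 at epoch 0 down to 0.01 at
        # the final epoch; one_cycle() spans the same 1.0 -> lrf range along a cosine curve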
        self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
        self.resume_training(ckpt)
        self.scheduler.last_epoch = self.start_epoch - 1  # do not move
        self.run_callbacks('on_pretrain_routine_end')

    def _do_train(self, world_size=1):
        """Train the model; when training completes, evaluate and plot if specified by arguments."""
        if world_size > 1:
            self._setup_ddp(world_size)
        self._setup_train(world_size)

        self.epoch_time = None
        self.epoch_time_start = time.time()
        self.train_time_start = time.time()
        nb = len(self.train_loader)  # number of batches
        nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
        last_opt_step = -1
        self.run_callbacks('on_train_start')
        LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
                    f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
                    f"Logging results to {colorstr('bold', self.save_dir)}\n"
                    f'Starting training for {self.epochs} epochs...')
        if self.args.close_mosaic:
            base_idx = (self.epochs - self.args.close_mosaic) * nb
            self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
        epoch = self.epochs  # predefine for resume fully trained model edge cases
        for epoch in range(self.start_epoch, self.epochs):
            self.epoch = epoch
            self.run_callbacks('on_train_epoch_start')
            self.model.train()
            if RANK != -1:
                self.train_loader.sampler.set_epoch(epoch)
            pbar = enumerate(self.train_loader)
            # Update dataloader attributes (optional)
            if epoch == (self.epochs - self.args.close_mosaic):
                LOGGER.info('Closing dataloader mosaic')
                if hasattr(self.train_loader.dataset, 'mosaic'):
                    self.train_loader.dataset.mosaic = False
                if hasattr(self.train_loader.dataset, 'close_mosaic'):
                    self.train_loader.dataset.close_mosaic(hyp=self.args)
                self.train_loader.reset()

            if RANK in (-1, 0):
                LOGGER.info(self.progress_string())
                pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT)
            self.tloss = None
            self.optimizer.zero_grad()
            for i, batch in pbar:
                self.run_callbacks('on_train_batch_start')
                # Warmup
                ni = i + nb * epoch
                if ni <= nw:
                    xi = [0, nw]  # x interp
                    self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())
                    for j, x in enumerate(self.optimizer.param_groups):
                        # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                        x['lr'] = np.interp(
                            ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)])
                        if 'momentum' in x:
                            x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])

                # Forward
                with torch.cuda.amp.autocast(self.amp):
                    batch = self.preprocess_batch(batch)
                    self.loss, self.loss_items = self.model(batch)
                    if RANK != -1:
                        self.loss *= world_size
                    self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
                        else self.loss_items

                # Backward
                self.scaler.scale(self.loss).backward()

                # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
                if ni - last_opt_step >= self.accumulate:
                    self.optimizer_step()
                    last_opt_step = ni

                # Log
                mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
                losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
                if RANK in (-1, 0):
                    pbar.set_description(
                        ('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
                        (f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
                    self.run_callbacks('on_batch_end')
                    if self.args.plots and ni in self.plot_idx:
                        self.plot_training_samples(batch, ni)

                self.run_callbacks('on_train_batch_end')

            self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers

            with warnings.catch_warnings():
                warnings.simplefilter('ignore')  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
                self.scheduler.step()
            self.run_callbacks('on_train_epoch_end')

            if RANK in (-1, 0):

                # Validation
                self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
                final_epoch = (epoch + 1 == self.epochs) or self.stopper.possible_stop

                if self.args.val or final_epoch:
                    self.metrics, self.fitness = self.validate()
                self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
                self.stop = self.stopper(epoch + 1, self.fitness)

                # Save model
                if self.args.save or (epoch + 1 == self.epochs):
                    self.save_model()
                    self.run_callbacks('on_model_save')

            tnow = time.time()
            self.epoch_time = tnow - self.epoch_time_start
            self.epoch_time_start = tnow
            self.run_callbacks('on_fit_epoch_end')
            torch.cuda.empty_cache()  # clears GPU vRAM at end of epoch, can help with out of memory errors

            # Early Stopping
            if RANK != -1:  # if DDP training
                broadcast_list = [self.stop if RANK == 0 else None]
                dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
                if RANK != 0:
                    self.stop = broadcast_list[0]
            if self.stop:
                break  # must break all DDP ranks

        if RANK in (-1, 0):
            # Do final val with best.pt
            LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
                        f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
            self.final_eval()
            if self.args.plots:
                self.plot_metrics()
            self.run_callbacks('on_train_end')
        torch.cuda.empty_cache()
        self.run_callbacks('teardown')

    def save_model(self):
        """Save model checkpoints based on various conditions."""
        ckpt = {
            'epoch': self.epoch,
            'best_fitness': self.best_fitness,
            'model': deepcopy(de_parallel(self.model)).half(),
            'ema': deepcopy(self.ema.ema).half(),
            'updates': self.ema.updates,
            'optimizer': self.optimizer.state_dict(),
            'train_args': vars(self.args),  # save as dict
            'date': datetime.now().isoformat(),
            'version': __version__}

        # Use dill (if exists) to serialize the lambda functions where pickle does not do this
        try:
            import dill as pickle
        except ImportError:
            import pickle

        # Save last, best and delete
        torch.save(ckpt, self.last, pickle_module=pickle)
        if self.best_fitness == self.fitness:
            torch.save(ckpt, self.best, pickle_module=pickle)
        if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):
            torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt', pickle_module=pickle)
        del ckpt

    @staticmethod
    def get_dataset(data):
        """
        Get train, val path from data dict if it exists.

        Returns None if data format is not recognized.
        """
        return data['train'], data.get('val') or data.get('test')

    def setup_model(self):
        """Load/create/download model for any task."""
        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
            return

        model, weights = self.model, None
        ckpt = None
        if str(model).endswith('.pt'):
            weights, ckpt = attempt_load_one_weight(model)
            cfg = ckpt['model'].yaml
        else:
            cfg = model
        self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
        return ckpt

    def optimizer_step(self):
        """Perform a single step of the training optimizer with gradient clipping and EMA update."""
        self.scaler.unscale_(self.optimizer)  # unscale gradients
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()
        if self.ema:
            self.ema.update(self.model)

    def preprocess_batch(self, batch):
        """Allows custom preprocessing of model inputs and ground truths depending on task type."""
        return batch

    def validate(self):
        """
        Runs validation on test set using self.validator.

        The returned dict is expected to contain the 'fitness' key.
        """
        metrics = self.validator(self)
        fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
        if not self.best_fitness or self.best_fitness < fitness:
            self.best_fitness = fitness
        return metrics, fitness

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Get model and raise NotImplementedError for loading cfg files."""
        raise NotImplementedError("This task trainer doesn't support loading cfg files")

    def get_validator(self):
        """Returns a NotImplementedError when the get_validator function is called."""
        raise NotImplementedError('get_validator function not implemented in trainer')

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
        """Returns dataloader derived from torch.utils.data.DataLoader."""
        raise NotImplementedError('get_dataloader function not implemented in trainer')

    def build_dataset(self, img_path, mode='train', batch=None):
        """Build dataset."""
        raise NotImplementedError('build_dataset function not implemented in trainer')

    def label_loss_items(self, loss_items=None, prefix='train'):
        """Returns a loss dict with labelled training loss items tensor."""
        # Not needed for classification but necessary for segmentation & detection
        return {'loss': loss_items} if loss_items is not None else ['loss']

    def set_model_attributes(self):
        """To set or update model parameters before training."""
        self.model.names = self.data['names']

    def build_targets(self, preds, targets):
        """Builds target tensors for training YOLO model."""
        pass

    def progress_string(self):
        """Returns a string describing training progress."""
        return ''

    # TODO: may need to put these following functions into callback
    def plot_training_samples(self, batch, ni):
        """Plots training samples during YOLO training."""
        pass

    def plot_training_labels(self):
        """Plots training labels for YOLO model."""
        pass

    def save_metrics(self, metrics):
        """Saves training metrics to a CSV file."""
        keys, vals = list(metrics.keys()), list(metrics.values())
        n = len(metrics) + 1  # number of cols
        s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')  # header
        with open(self.csv, 'a') as f:
            f.write(s + ('%23.5g,' * n % tuple([self.epoch + 1] + vals)).rstrip(',') + '\n')
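        # The result is a plain CSV: a header row ('epoch' plus the metric keys, each padded to 23 characters)
        # followed by one row per epoch, e.g. columns such as train/box_loss, metrics/mAP50(B) and lr/pg0 for detection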

    def plot_metrics(self):
        """Plot and display metrics visually."""
        pass

    def on_plot(self, name, data=None):
        """Registers plots (e.g. to be consumed in callbacks)."""
        path = Path(name)
        self.plots[path] = {'data': data, 'timestamp': time.time()}

    def final_eval(self):
        """Performs final evaluation and validation for object detection YOLO model."""
        for f in self.last, self.best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is self.best:
                    LOGGER.info(f'\nValidating {f}...')
                    self.metrics = self.validator(model=f)
                    self.metrics.pop('fitness', None)
                    self.run_callbacks('on_fit_epoch_end')

    def check_resume(self, overrides):
        """Check if resume checkpoint exists and update arguments accordingly."""
        resume = self.args.resume
        if resume:
            try:
                exists = isinstance(resume, (str, Path)) and Path(resume).exists()
                last = Path(check_file(resume) if exists else get_latest_run())

                # Check that resume data YAML exists, otherwise strip to force re-download of dataset
                ckpt_args = attempt_load_weights(last).args
                if not Path(ckpt_args['data']).exists():
                    ckpt_args['data'] = self.args.data

                resume = True
                self.args = get_cfg(ckpt_args)
                self.args.model = str(last)  # reinstate model
                for k in 'imgsz', 'batch':  # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
                    if k in overrides:
                        setattr(self.args, k, overrides[k])

            except Exception as e:
                raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
                                        "i.e. 'yolo train resume model=path/to/last.pt'") from e
        self.resume = resume

    def resume_training(self, ckpt):
        """Resume YOLO training from given epoch and best fitness."""
        if ckpt is None:
            return
        best_fitness = 0.0
        start_epoch = ckpt['epoch'] + 1
        if ckpt['optimizer'] is not None:
            self.optimizer.load_state_dict(ckpt['optimizer'])  # optimizer
            best_fitness = ckpt['best_fitness']
        if self.ema and ckpt.get('ema'):
            self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
            self.ema.updates = ckpt['updates']
        if self.resume:
            assert start_epoch > 0, \
                f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
                f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
            LOGGER.info(
                f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs')
        if self.epochs < start_epoch:
            LOGGER.info(
                f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")
            self.epochs += ckpt['epoch']  # finetune additional epochs
        self.best_fitness = best_fitness
        self.start_epoch = start_epoch
        if start_epoch > (self.epochs - self.args.close_mosaic):
            LOGGER.info('Closing dataloader mosaic')
            if hasattr(self.train_loader.dataset, 'mosaic'):
                self.train_loader.dataset.mosaic = False
            if hasattr(self.train_loader.dataset, 'close_mosaic'):
                self.train_loader.dataset.close_mosaic(hyp=self.args)

    def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
        """
        Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
        weight decay, and number of iterations.

        Args:
            model (torch.nn.Module): The model for which to build an optimizer.
            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
                based on the number of iterations. Default: 'auto'.
            lr (float, optional): The learning rate for the optimizer. Default: 0.001.
            momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
            decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
            iterations (float, optional): The number of iterations, which determines the optimizer if
                name is 'auto'. Default: 1e5.

        Returns:
            (torch.optim.Optimizer): The constructed optimizer.
        """
        g = [], [], []  # optimizer parameter groups
        bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
        if name == 'auto':
            nc = getattr(model, 'nc', 10)  # number of classes
            lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
            name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)
            self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam

        for module_name, module in model.named_modules():
            for param_name, param in module.named_parameters(recurse=False):
                fullname = f'{module_name}.{param_name}' if module_name else param_name
                if 'bias' in fullname:  # bias (no decay)
                    g[2].append(param)
                elif isinstance(module, bn):  # weight (no decay)
                    g[1].append(param)
                else:  # weight (with decay)
                    g[0].append(param)

        if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
            optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
        elif name == 'RMSProp':
            optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
        elif name == 'SGD':
            optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
        else:
            raise NotImplementedError(
                f"Optimizer '{name}' not found in list of available optimizers "
                f'[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]. '
                'To request support for additional optimizers please visit https://github.com/ultralytics/ultralytics.')

        optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
        optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
        LOGGER.info(
            f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
            f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)')
        return optimizer