torch_utils.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import math
  3. import os
  4. import platform
  5. import random
  6. import time
  7. from contextlib import contextmanager
  8. from copy import deepcopy
  9. from pathlib import Path
  10. from typing import Union
  11. import numpy as np
  12. import torch
  13. import torch.distributed as dist
  14. import torch.nn as nn
  15. import torch.nn.functional as F
  16. from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, __version__
  17. from ultralytics.utils.checks import check_version
  18. try:
  19. import thop
  20. except ImportError:
  21. thop = None
  22. TORCH_1_9 = check_version(torch.__version__, '1.9.0')
  23. TORCH_2_0 = check_version(torch.__version__, '2.0.0')
  24. @contextmanager
  25. def torch_distributed_zero_first(local_rank: int):
  26. """Decorator to make all processes in distributed training wait for each local_master to do something."""
  27. initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
  28. if initialized and local_rank not in (-1, 0):
  29. dist.barrier(device_ids=[local_rank])
  30. yield
  31. if initialized and local_rank == 0:
  32. dist.barrier(device_ids=[0])
  33. def smart_inference_mode():
  34. """Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator."""
  35. def decorate(fn):
  36. """Applies appropriate torch decorator for inference mode based on torch version."""
  37. return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
  38. return decorate
  39. def get_cpu_info():
  40. """Return a string with system CPU information, i.e. 'Apple M2'."""
  41. import cpuinfo # pip install py-cpuinfo
  42. k = 'brand_raw', 'hardware_raw', 'arch_string_raw' # info keys sorted by preference (not all keys always available)
  43. info = cpuinfo.get_cpu_info() # info dict
  44. string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], 'unknown')
  45. return string.replace('(R)', '').replace('CPU ', '').replace('@ ', '')
  46. def select_device(device='', batch=0, newline=False, verbose=True):
  47. """Selects PyTorch Device. Options are device = None or 'cpu' or 0 or '0' or '0,1,2,3'."""
  48. s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
  49. device = str(device).lower()
  50. for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
  51. device = device.replace(remove, '') # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
  52. cpu = device == 'cpu'
  53. mps = device == 'mps' # Apple Metal Performance Shaders (MPS)
  54. if cpu or mps:
  55. os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
  56. elif device: # non-cpu device requested
  57. if device == 'cuda':
  58. device = '0'
  59. visible = os.environ.get('CUDA_VISIBLE_DEVICES', None)
  60. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
  61. if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))):
  62. LOGGER.info(s)
  63. install = 'See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no ' \
  64. 'CUDA devices are seen by torch.\n' if torch.cuda.device_count() == 0 else ''
  65. raise ValueError(f"Invalid CUDA 'device={device}' requested."
  66. f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
  67. f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
  68. f'\ntorch.cuda.is_available(): {torch.cuda.is_available()}'
  69. f'\ntorch.cuda.device_count(): {torch.cuda.device_count()}'
  70. f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
  71. f'{install}')
  72. if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
  73. devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
  74. n = len(devices) # device count
  75. if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
  76. raise ValueError(f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
  77. f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}.")
  78. space = ' ' * (len(s) + 1)
  79. for i, d in enumerate(devices):
  80. p = torch.cuda.get_device_properties(i)
  81. s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
  82. arg = 'cuda:0'
  83. elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available() and TORCH_2_0:
  84. # Prefer MPS if available
  85. s += f'MPS ({get_cpu_info()})\n'
  86. arg = 'mps'
  87. else: # revert to CPU
  88. s += f'CPU ({get_cpu_info()})\n'
  89. arg = 'cpu'
  90. if verbose and RANK == -1:
  91. LOGGER.info(s if newline else s.rstrip())
  92. return torch.device(arg)
  93. def time_sync():
  94. """PyTorch-accurate time."""
  95. if torch.cuda.is_available():
  96. torch.cuda.synchronize()
  97. return time.time()
  98. def fuse_conv_and_bn(conv, bn):
  99. """Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""
  100. fusedconv = nn.Conv2d(conv.in_channels,
  101. conv.out_channels,
  102. kernel_size=conv.kernel_size,
  103. stride=conv.stride,
  104. padding=conv.padding,
  105. dilation=conv.dilation,
  106. groups=conv.groups,
  107. bias=True).requires_grad_(False).to(conv.weight.device)
  108. # Prepare filters
  109. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  110. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  111. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
  112. # Prepare spatial bias
  113. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  114. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  115. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  116. return fusedconv
  117. def fuse_deconv_and_bn(deconv, bn):
  118. """Fuse ConvTranspose2d() and BatchNorm2d() layers."""
  119. fuseddconv = nn.ConvTranspose2d(deconv.in_channels,
  120. deconv.out_channels,
  121. kernel_size=deconv.kernel_size,
  122. stride=deconv.stride,
  123. padding=deconv.padding,
  124. output_padding=deconv.output_padding,
  125. dilation=deconv.dilation,
  126. groups=deconv.groups,
  127. bias=True).requires_grad_(False).to(deconv.weight.device)
  128. # Prepare filters
  129. w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
  130. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  131. fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
  132. # Prepare spatial bias
  133. b_conv = torch.zeros(deconv.weight.size(1), device=deconv.weight.device) if deconv.bias is None else deconv.bias
  134. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  135. fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  136. return fuseddconv
  137. def model_info(model, detailed=False, verbose=True, imgsz=640):
  138. """Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]."""
  139. if not verbose:
  140. return
  141. n_p = get_num_params(model) # number of parameters
  142. n_g = get_num_gradients(model) # number of gradients
  143. n_l = len(list(model.modules())) # number of layers
  144. if detailed:
  145. LOGGER.info(
  146. f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
  147. for i, (name, p) in enumerate(model.named_parameters()):
  148. name = name.replace('module_list.', '')
  149. LOGGER.info('%5g %40s %9s %12g %20s %10.3g %10.3g %10s' %
  150. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype))
  151. flops = get_flops(model, imgsz)
  152. fused = ' (fused)' if getattr(model, 'is_fused', lambda: False)() else ''
  153. fs = f', {flops:.1f} GFLOPs' if flops else ''
  154. yaml_file = getattr(model, 'yaml_file', '') or getattr(model, 'yaml', {}).get('yaml_file', '')
  155. model_name = Path(yaml_file).stem.replace('yolo', 'YOLO') or 'Model'
  156. LOGGER.info(f'{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}')
  157. return n_l, n_p, n_g, flops
  158. def get_num_params(model):
  159. """Return the total number of parameters in a YOLO model."""
  160. return sum(x.numel() for x in model.parameters())
  161. def get_num_gradients(model):
  162. """Return the total number of parameters with gradients in a YOLO model."""
  163. return sum(x.numel() for x in model.parameters() if x.requires_grad)
  164. def model_info_for_loggers(trainer):
  165. """
  166. Return model info dict with useful model information.
  167. Example for YOLOv8n:
  168. {'model/parameters': 3151904,
  169. 'model/GFLOPs': 8.746,
  170. 'model/speed_ONNX(ms)': 41.244,
  171. 'model/speed_TensorRT(ms)': 3.211,
  172. 'model/speed_PyTorch(ms)': 18.755}
  173. """
  174. if trainer.args.profile: # profile ONNX and TensorRT times
  175. from ultralytics.utils.benchmarks import ProfileModels
  176. results = ProfileModels([trainer.last], device=trainer.device).profile()[0]
  177. results.pop('model/name')
  178. else: # only return PyTorch times from most recent validation
  179. results = {
  180. 'model/parameters': get_num_params(trainer.model),
  181. 'model/GFLOPs': round(get_flops(trainer.model), 3)}
  182. results['model/speed_PyTorch(ms)'] = round(trainer.validator.speed['inference'], 3)
  183. return results
  184. def get_flops(model, imgsz=640):
  185. """Return a YOLO model's FLOPs."""
  186. try:
  187. model = de_parallel(model)
  188. p = next(model.parameters())
  189. stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
  190. im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
  191. flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 if thop else 0 # stride GFLOPs
  192. imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
  193. return flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs
  194. except Exception:
  195. return 0
  196. def get_flops_with_torch_profiler(model, imgsz=640):
  197. """Compute model FLOPs (thop alternative)."""
  198. if TORCH_2_0:
  199. model = de_parallel(model)
  200. p = next(model.parameters())
  201. stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2 # max stride
  202. im = torch.zeros((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
  203. with torch.profiler.profile(with_flops=True) as prof:
  204. model(im)
  205. flops = sum(x.flops for x in prof.key_averages()) / 1E9
  206. imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
  207. flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs
  208. return flops
  209. return 0
  210. def initialize_weights(model):
  211. """Initialize model weights to random values."""
  212. for m in model.modules():
  213. t = type(m)
  214. if t is nn.Conv2d:
  215. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  216. elif t is nn.BatchNorm2d:
  217. m.eps = 1e-3
  218. m.momentum = 0.03
  219. elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
  220. m.inplace = True
  221. def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
  222. # Scales img(bs,3,y,x) by ratio constrained to gs-multiple
  223. if ratio == 1.0:
  224. return img
  225. h, w = img.shape[2:]
  226. s = (int(h * ratio), int(w * ratio)) # new size
  227. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  228. if not same_shape: # pad/crop img
  229. h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
  230. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  231. def make_divisible(x, divisor):
  232. """Returns nearest x divisible by divisor."""
  233. if isinstance(divisor, torch.Tensor):
  234. divisor = int(divisor.max()) # to int
  235. return math.ceil(x / divisor) * divisor
  236. def copy_attr(a, b, include=(), exclude=()):
  237. """Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
  238. for k, v in b.__dict__.items():
  239. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  240. continue
  241. else:
  242. setattr(a, k, v)
  243. def get_latest_opset():
  244. """Return second-most (for maturity) recently supported ONNX opset by this version of torch."""
  245. return max(int(k[14:]) for k in vars(torch.onnx) if 'symbolic_opset' in k) - 1 # opset
  246. def intersect_dicts(da, db, exclude=()):
  247. """Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values."""
  248. return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
  249. def is_parallel(model):
  250. """Returns True if model is of type DP or DDP."""
  251. return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
  252. def de_parallel(model):
  253. """De-parallelize a model: returns single-GPU model if model is of type DP or DDP."""
  254. return model.module if is_parallel(model) else model
  255. def one_cycle(y1=0.0, y2=1.0, steps=100):
  256. """Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf."""
  257. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  258. def init_seeds(seed=0, deterministic=False):
  259. """Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html."""
  260. random.seed(seed)
  261. np.random.seed(seed)
  262. torch.manual_seed(seed)
  263. torch.cuda.manual_seed(seed)
  264. torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
  265. # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
  266. if deterministic:
  267. if TORCH_2_0:
  268. torch.use_deterministic_algorithms(True, warn_only=True) # warn if deterministic is not possible
  269. torch.backends.cudnn.deterministic = True
  270. os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
  271. os.environ['PYTHONHASHSEED'] = str(seed)
  272. else:
  273. LOGGER.warning('WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.')
  274. else:
  275. torch.use_deterministic_algorithms(False)
  276. torch.backends.cudnn.deterministic = False
  277. class ModelEMA:
  278. """Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
  279. Keeps a moving average of everything in the model state_dict (parameters and buffers)
  280. For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  281. To disable EMA set the `enabled` attribute to `False`.
  282. """
  283. def __init__(self, model, decay=0.9999, tau=2000, updates=0):
  284. """Create EMA."""
  285. self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
  286. self.updates = updates # number of EMA updates
  287. self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
  288. for p in self.ema.parameters():
  289. p.requires_grad_(False)
  290. self.enabled = True
  291. def update(self, model):
  292. """Update EMA parameters."""
  293. if self.enabled:
  294. self.updates += 1
  295. d = self.decay(self.updates)
  296. msd = de_parallel(model).state_dict() # model state_dict
  297. for k, v in self.ema.state_dict().items():
  298. if v.dtype.is_floating_point: # true for FP16 and FP32
  299. v *= d
  300. v += (1 - d) * msd[k].detach()
  301. # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'
  302. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  303. """Updates attributes and saves stripped model with optimizer removed."""
  304. if self.enabled:
  305. copy_attr(self.ema, model, include, exclude)
  306. def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
  307. """
  308. Strip optimizer from 'f' to finalize training, optionally save as 's'.
  309. Args:
  310. f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
  311. s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
  312. Returns:
  313. None
  314. Example:
  315. ```python
  316. from pathlib import Path
  317. from ultralytics.utils.torch_utils import strip_optimizer
  318. for f in Path('path/to/weights').rglob('*.pt'):
  319. strip_optimizer(f)
  320. ```
  321. """
  322. # Use dill (if exists) to serialize the lambda functions where pickle does not do this
  323. try:
  324. import dill as pickle
  325. except ImportError:
  326. import pickle
  327. x = torch.load(f, map_location=torch.device('cpu'))
  328. if 'model' not in x:
  329. LOGGER.info(f'Skipping {f}, not a valid Ultralytics model.')
  330. return
  331. if hasattr(x['model'], 'args'):
  332. x['model'].args = dict(x['model'].args) # convert from IterableSimpleNamespace to dict
  333. args = {**DEFAULT_CFG_DICT, **x['train_args']} if 'train_args' in x else None # combine args
  334. if x.get('ema'):
  335. x['model'] = x['ema'] # replace model with ema
  336. for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
  337. x[k] = None
  338. x['epoch'] = -1
  339. x['model'].half() # to FP16
  340. for p in x['model'].parameters():
  341. p.requires_grad = False
  342. x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
  343. # x['model'].args = x['train_args']
  344. torch.save(x, s or f, pickle_module=pickle)
  345. mb = os.path.getsize(s or f) / 1E6 # filesize
  346. LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
  347. def profile(input, ops, n=10, device=None):
  348. """
  349. Ultralytics speed, memory and FLOPs profiler.
  350. Example:
  351. ```python
  352. from ultralytics.utils.torch_utils import profile
  353. input = torch.randn(16, 3, 640, 640)
  354. m1 = lambda x: x * torch.sigmoid(x)
  355. m2 = nn.SiLU()
  356. profile(input, [m1, m2], n=100) # profile over 100 iterations
  357. ```
  358. """
  359. results = []
  360. if not isinstance(device, torch.device):
  361. device = select_device(device)
  362. LOGGER.info(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
  363. f"{'input':>24s}{'output':>24s}")
  364. for x in input if isinstance(input, list) else [input]:
  365. x = x.to(device)
  366. x.requires_grad = True
  367. for m in ops if isinstance(ops, list) else [ops]:
  368. m = m.to(device) if hasattr(m, 'to') else m # device
  369. m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
  370. tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
  371. try:
  372. flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1E9 * 2 if thop else 0 # GFLOPs
  373. except Exception:
  374. flops = 0
  375. try:
  376. for _ in range(n):
  377. t[0] = time_sync()
  378. y = m(x)
  379. t[1] = time_sync()
  380. try:
  381. (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
  382. t[2] = time_sync()
  383. except Exception: # no backward method
  384. # print(e) # for debug
  385. t[2] = float('nan')
  386. tf += (t[1] - t[0]) * 1000 / n # ms per op forward
  387. tb += (t[2] - t[1]) * 1000 / n # ms per op backward
  388. mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
  389. s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes
  390. p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
  391. LOGGER.info(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
  392. results.append([p, flops, mem, tf, tb, s_in, s_out])
  393. except Exception as e:
  394. LOGGER.info(e)
  395. results.append(None)
  396. torch.cuda.empty_cache()
  397. return results
  398. class EarlyStopping:
  399. """
  400. Early stopping class that stops training when a specified number of epochs have passed without improvement.
  401. """
  402. def __init__(self, patience=50):
  403. """
  404. Initialize early stopping object
  405. Args:
  406. patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
  407. """
  408. self.best_fitness = 0.0 # i.e. mAP
  409. self.best_epoch = 0
  410. self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
  411. self.possible_stop = False # possible stop may occur next epoch
  412. def __call__(self, epoch, fitness):
  413. """
  414. Check whether to stop training
  415. Args:
  416. epoch (int): Current epoch of training
  417. fitness (float): Fitness value of current epoch
  418. Returns:
  419. (bool): True if training should stop, False otherwise
  420. """
  421. if fitness is None: # check if fitness=None (happens when val=False)
  422. return False
  423. if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
  424. self.best_epoch = epoch
  425. self.best_fitness = fitness
  426. delta = epoch - self.best_epoch # epochs without improvement
  427. self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
  428. stop = delta >= self.patience # stop training if patience exceeded
  429. if stop:
  430. LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
  431. f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
  432. f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
  433. f'i.e. `patience=300` or use `patience=0` to disable EarlyStopping.')
  434. return stop