tasks.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib
from copy import deepcopy
from pathlib import Path

import torch
import torch.nn as nn

from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
                                    Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
                                    Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
                                    RTDETRDecoder, Segment)
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
from ultralytics.utils.plotting import feature_visualization
from ultralytics.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts,
                                           make_divisible, model_info, scale_img, time_sync)

try:
    import thop
except ImportError:
    thop = None


class BaseModel(nn.Module):
    """The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""

    def forward(self, x, *args, **kwargs):
        """
        Forward pass of the model on a single scale. Wrapper for `_forward_once` method.

        Args:
            x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels.

        Returns:
            (torch.Tensor): The output of the network.
        """
        if isinstance(x, dict):  # for cases of training and validating while training.
            return self.loss(x, *args, **kwargs)
        return self.predict(x, *args, **kwargs)

    def predict(self, x, profile=False, visualize=False, augment=False):
        """
        Perform a forward pass through the network.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool): Print the computation time of each layer if True, defaults to False.
            visualize (bool): Save the feature maps of the model if True, defaults to False.
            augment (bool): Augment image during prediction, defaults to False.

        Returns:
            (torch.Tensor): The last output of the model.
        """
        if augment:
            return self._predict_augment(x)
        return self._predict_once(x, profile, visualize)

    def _predict_once(self, x, profile=False, visualize=False):
        """
        Perform a forward pass through the network.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool): Print the computation time of each layer if True, defaults to False.
            visualize (bool): Save the feature maps of the model if True, defaults to False.

        Returns:
            (torch.Tensor): The last output of the model.
        """
        y, dt = [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
        return x

    def _predict_augment(self, x):
        """Perform augmentations on input image x and return augmented inference."""
        LOGGER.warning(f'WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. '
                       f'Reverting to single-scale inference instead.')
        return self._predict_once(x)

    def _profile_one_layer(self, m, x, dt):
        """
        Profile the computation time and FLOPs of a single layer of the model on a given input. Appends the results to
        the provided list.

        Args:
            m (nn.Module): The layer to be profiled.
            x (torch.Tensor): The input data to the layer.
            dt (list): A list to store the computation time of the layer.

        Returns:
            None
        """
        c = m == self.model[-1] and isinstance(x, list)  # is final layer list, copy input as inplace fix
        flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs
        t = time_sync()
        for _ in range(10):
            m(x.copy() if c else x)
        dt.append((time_sync() - t) * 100)
        if m == self.model[0]:
            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
        LOGGER.info(f'{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}')
        if c:
            LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")

    def fuse(self, verbose=True):
        """
        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer to improve computation
        efficiency.

        Returns:
            (nn.Module): The fused model is returned.
        """
        if not self.is_fused():
            for m in self.model.modules():
                if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, 'bn'):
                    if isinstance(m, Conv2):
                        m.fuse_convs()
                    m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                    delattr(m, 'bn')  # remove batchnorm
                    m.forward = m.forward_fuse  # update forward
                if isinstance(m, ConvTranspose) and hasattr(m, 'bn'):
                    m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
                    delattr(m, 'bn')  # remove batchnorm
                    m.forward = m.forward_fuse  # update forward
                if isinstance(m, RepConv):
                    m.fuse_convs()
                    m.forward = m.forward_fuse  # update forward
            self.info(verbose=verbose)

        return self

    def is_fused(self, thresh=10):
        """
        Check if the model has fewer than a certain threshold of BatchNorm layers.

        Args:
            thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.

        Returns:
            (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
        """
        bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
        return sum(isinstance(v, bn) for v in self.modules()) < thresh  # True if < 'thresh' BatchNorm layers in model

    def info(self, detailed=False, verbose=True, imgsz=640):
        """
        Print model information.

        Args:
            detailed (bool): If True, prints out detailed information about the model. Defaults to False.
            verbose (bool): If True, prints out the model information. Defaults to True.
            imgsz (int): The size of the image that the model will be trained on. Defaults to 640.
        """
        return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)

    def _apply(self, fn):
        """
        Apply a function to all tensors in the model that are not parameters or registered buffers.

        Args:
            fn (function): The function to apply to the model.

        Returns:
            (BaseModel): An updated BaseModel object.
        """
        self = super()._apply(fn)
        m = self.model[-1]  # Detect()
        if isinstance(m, (Detect, Segment)):
            m.stride = fn(m.stride)
            m.anchors = fn(m.anchors)
            m.strides = fn(m.strides)
        return self

    def load(self, weights, verbose=True):
        """
        Load the weights into the model.

        Args:
            weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
            verbose (bool, optional): Whether to log the transfer progress. Defaults to True.
        """
        model = weights['model'] if isinstance(weights, dict) else weights  # torchvision models are not dicts
        csd = model.float().state_dict()  # checkpoint state_dict as FP32
        csd = intersect_dicts(csd, self.state_dict())  # intersect
        self.load_state_dict(csd, strict=False)  # load
        if verbose:
            LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')

    def loss(self, batch, preds=None):
        """
        Compute loss.

        Args:
            batch (dict): Batch to compute loss on.
            preds (torch.Tensor | List[torch.Tensor]): Predictions.
        """
        if not hasattr(self, 'criterion'):
            self.criterion = self.init_criterion()

        preds = self.forward(batch['img']) if preds is None else preds
        return self.criterion(preds, batch)

    def init_criterion(self):
        """Initialize the loss criterion; must be implemented by task-specific subclasses."""
        raise NotImplementedError('compute_loss() needs to be implemented by task heads')


class DetectionModel(BaseModel):
    """YOLOv8 detection model."""

    def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True):  # model, input channels, number of classes
        super().__init__()
        self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict

        # Define model
        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
        if nc and nc != self.yaml['nc']:
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml['nc'] = nc  # override YAML value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
        self.names = {i: f'{i}' for i in range(self.yaml['nc'])}  # default names dict
        self.inplace = self.yaml.get('inplace', True)

        # Build strides
        m = self.model[-1]  # Detect()
        if isinstance(m, (Detect, Segment, Pose)):
            s = 256  # 2x min stride
            m.inplace = self.inplace
            forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose)) else self.forward(x)
            m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))])  # forward
            self.stride = m.stride
            m.bias_init()  # only run once
        else:
            self.stride = torch.Tensor([32])  # default stride for i.e. RTDETR

        # Init weights, biases
        initialize_weights(self)
        if verbose:
            self.info()
            LOGGER.info('')

    def _predict_augment(self, x):
        """Perform augmentations on input image x and return augmented inference and train outputs."""
        img_size = x.shape[-2:]  # height, width
        s = [1, 0.83, 0.67]  # scales
        f = [None, 3, None]  # flips (2-ud, 3-lr)
        y = []  # outputs
        for si, fi in zip(s, f):
            xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
            yi = super().predict(xi)[0]  # forward
            yi = self._descale_pred(yi, fi, si, img_size)
            y.append(yi)
        y = self._clip_augmented(y)  # clip augmented tails
        return torch.cat(y, -1), None  # augmented inference, train

    @staticmethod
    def _descale_pred(p, flips, scale, img_size, dim=1):
        """De-scale predictions following augmented inference (inverse operation)."""
        p[:, :4] /= scale  # de-scale
        x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim)
        if flips == 2:
            y = img_size[0] - y  # de-flip ud
        elif flips == 3:
            x = img_size[1] - x  # de-flip lr
        return torch.cat((x, y, wh, cls), dim)

    def _clip_augmented(self, y):
        """Clip YOLOv5 augmented inference tails."""
        nl = self.model[-1].nl  # number of detection layers (P3-P5)
        g = sum(4 ** x for x in range(nl))  # grid points
        e = 1  # exclude layer count
        i = (y[0].shape[-1] // g) * sum(4 ** x for x in range(e))  # indices
        y[0] = y[0][..., :-i]  # large
        i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # indices
        y[-1] = y[-1][..., i:]  # small
        return y

    def init_criterion(self):
        return v8DetectionLoss(self)
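

# A minimal usage sketch (illustrative, not part of the original module) of building a DetectionModel from
# a packaged YAML config and running a dummy forward pass. It assumes 'yolov8n.yaml' resolves through
# yaml_model_load()/check_yaml() to a bundled Ultralytics config.
def _example_detection_model():
    model = DetectionModel(cfg='yolov8n.yaml', ch=3, nc=80, verbose=False)
    model.eval()  # inference mode so the Detect head returns decoded predictions
    x = torch.zeros(1, 3, 640, 640)  # dummy BCHW image batch
    return model.predict(x)  # single-scale inference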


class SegmentationModel(DetectionModel):
    """YOLOv8 segmentation model."""

    def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
        """Initialize YOLOv8 segmentation model with given config and parameters."""
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def init_criterion(self):
        return v8SegmentationLoss(self)


class PoseModel(DetectionModel):
    """YOLOv8 pose model."""

    def __init__(self, cfg='yolov8n-pose.yaml', ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
        """Initialize YOLOv8 Pose model."""
        if not isinstance(cfg, dict):
            cfg = yaml_model_load(cfg)  # load model YAML
        if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg['kpt_shape']):
            LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
            cfg['kpt_shape'] = data_kpt_shape
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def init_criterion(self):
        return v8PoseLoss(self)
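

# A small sketch (illustrative, not part of the original module) of the PoseModel kpt_shape override:
# a dataset keypoint shape, e.g. COCO's 17 keypoints with (x, y, visibility), replaces the kpt_shape
# declared in the model YAML before the network is built.
def _example_pose_model():
    return PoseModel(cfg='yolov8n-pose.yaml', ch=3, nc=1, data_kpt_shape=(17, 3), verbose=False)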


class ClassificationModel(BaseModel):
    """YOLOv8 classification model."""

    def __init__(self,
                 cfg='yolov8n-cls.yaml',
                 model=None,
                 ch=3,
                 nc=None,
                 cutoff=10,
                 verbose=True):  # YAML, model, channels, number of classes, cutoff index, verbose flag
        super().__init__()
        if model is not None:
            self._from_detection_model(model, nc, cutoff)
        else:
            self._from_yaml(cfg, ch, nc, verbose)

    def _from_detection_model(self, model, nc=1000, cutoff=10):
        """Create a YOLOv5 classification model from a YOLOv5 detection model."""
        from ultralytics.nn.autobackend import AutoBackend
        if isinstance(model, AutoBackend):
            model = model.model  # unwrap DetectMultiBackend
        model.model = model.model[:cutoff]  # backbone
        m = model.model[-1]  # last layer
        ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels  # ch into module
        c = Classify(ch, nc)  # Classify()
        c.i, c.f, c.type = m.i, m.f, 'models.common.Classify'  # index, from, type
        model.model[-1] = c  # replace
        self.model = model.model
        self.stride = model.stride
        self.save = []
        self.nc = nc

    def _from_yaml(self, cfg, ch, nc, verbose):
        """Set YOLOv8 model configurations and define the model architecture."""
        self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict

        # Define model
        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
        if nc and nc != self.yaml['nc']:
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml['nc'] = nc  # override YAML value
        elif not nc and not self.yaml.get('nc', None):
            raise ValueError('nc not specified. Must specify nc in model.yaml or function arguments.')
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
        self.stride = torch.Tensor([1])  # no stride constraints
        self.names = {i: f'{i}' for i in range(self.yaml['nc'])}  # default names dict
        self.info()

    @staticmethod
    def reshape_outputs(model, nc):
        """Update a TorchVision classification model to class count 'nc' if required."""
        name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1]  # last module
        if isinstance(m, Classify):  # YOLO Classify() head
            if m.linear.out_features != nc:
                m.linear = nn.Linear(m.linear.in_features, nc)
        elif isinstance(m, nn.Linear):  # ResNet, EfficientNet
            if m.out_features != nc:
                setattr(model, name, nn.Linear(m.in_features, nc))
        elif isinstance(m, nn.Sequential):
            types = [type(x) for x in m]
            if nn.Linear in types:
                i = types.index(nn.Linear)  # nn.Linear index
                if m[i].out_features != nc:
                    m[i] = nn.Linear(m[i].in_features, nc)
            elif nn.Conv2d in types:
                i = types.index(nn.Conv2d)  # nn.Conv2d index
                if m[i].out_channels != nc:
                    m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)

    def init_criterion(self):
        """Compute the classification loss between predictions and true labels."""
        return v8ClassificationLoss()
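

# A minimal sketch (illustrative, not part of the original module) of reshape_outputs() retargeting a
# torchvision classifier to a new class count. It assumes torchvision is installed; resnet18 ends in an
# nn.Linear module, which matches the second branch above.
def _example_reshape_outputs():
    from torchvision.models import resnet18  # sketch-only dependency
    model = resnet18()
    ClassificationModel.reshape_outputs(model, nc=10)  # replace the 1000-class head with a 10-class one
    return model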


class RTDETRDetectionModel(DetectionModel):
    """RT-DETR detection model."""

    def __init__(self, cfg='rtdetr-l.yaml', ch=3, nc=None, verbose=True):
        """Initialize the RT-DETR detection model with the given config and parameters."""
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def init_criterion(self):
        """Build the RT-DETR detection loss criterion."""
        from ultralytics.models.utils.loss import RTDETRDetectionLoss

        return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)

    def loss(self, batch, preds=None):
        """Compute the loss for the given batch of data."""
        if not hasattr(self, 'criterion'):
            self.criterion = self.init_criterion()

        img = batch['img']
        # NOTE: preprocess gt_bbox and gt_labels to list.
        bs = len(img)
        batch_idx = batch['batch_idx']
        gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
        targets = {
            'cls': batch['cls'].to(img.device, dtype=torch.long).view(-1),
            'bboxes': batch['bboxes'].to(device=img.device),
            'batch_idx': batch_idx.to(img.device, dtype=torch.long).view(-1),
            'gt_groups': gt_groups}

        preds = self.predict(img, batch=targets) if preds is None else preds
        dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]
        if dn_meta is None:
            dn_bboxes, dn_scores = None, None
        else:
            dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta['dn_num_split'], dim=2)
            dn_scores, dec_scores = torch.split(dec_scores, dn_meta['dn_num_split'], dim=2)

        dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes])  # (7, bs, 300, 4)
        dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])

        loss = self.criterion((dec_bboxes, dec_scores),
                              targets,
                              dn_bboxes=dn_bboxes,
                              dn_scores=dn_scores,
                              dn_meta=dn_meta)
        # NOTE: RTDETR computes roughly a dozen loss terms; all contribute to backward, but only the main
        # three are reported.
        return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
                                                   device=img.device)

    def predict(self, x, profile=False, visualize=False, batch=None, augment=False):
        """
        Perform a forward pass through the network.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool): Print the computation time of each layer if True, defaults to False.
            visualize (bool): Save the feature maps of the model if True, defaults to False.
            batch (dict, optional): A dict including gt boxes and labels from dataloader. Defaults to None.
            augment (bool): Augment image during prediction (currently unused here), defaults to False.

        Returns:
            (torch.Tensor): The last output of the model.
        """
        y, dt = [], []  # outputs
        for m in self.model[:-1]:  # except the head part
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
        head = self.model[-1]
        x = head([y[j] for j in head.f], batch)  # head inference
        return x


class Ensemble(nn.ModuleList):
    """Ensemble of models."""

    def __init__(self):
        """Initialize an ensemble of models."""
        super().__init__()

    def forward(self, x, augment=False, profile=False, visualize=False):
        """Generate the YOLO network's final layer by concatenating the outputs of all ensemble members."""
        y = [module(x, augment, profile, visualize)[0] for module in self]
        # y = torch.stack(y).max(0)[0]  # max ensemble
        # y = torch.stack(y).mean(0)  # mean ensemble
        y = torch.cat(y, 2)  # nms ensemble, y shape(B, HW, C)
        return y, None  # inference, train output


# Functions ------------------------------------------------------------------------------------------------------------


@contextlib.contextmanager
def temporary_modules(modules=None):
    """
    Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).

    This function can be used to change the module paths during runtime. It's useful when refactoring code,
    where you've moved a module from one location to another, but you still want to support the old import
    paths for backwards compatibility.

    Args:
        modules (dict, optional): A dictionary mapping old module paths to new module paths.

    Example:
        ```python
        with temporary_modules({'old.module.path': 'new.module.path'}):
            import old.module.path  # this will now import new.module.path
        ```

    Note:
        The changes are only in effect inside the context manager and are undone once the context manager exits.
        Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
        applications or libraries. Use this function with caution.
    """
    if not modules:
        modules = {}

    import importlib
    import sys
    try:
        # Set modules in sys.modules under their old name
        for old, new in modules.items():
            sys.modules[old] = importlib.import_module(new)

        yield
    finally:
        # Remove the temporary module paths
        for old in modules:
            if old in sys.modules:
                del sys.modules[old]


def torch_safe_load(weight):
    """
    Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches
    the error, logs a warning message, and attempts to install the missing module via the check_requirements()
    function. After installation, the function again attempts to load the model using torch.load().

    Args:
        weight (str): The file path of the PyTorch model.

    Returns:
        (dict, str): The loaded checkpoint dictionary and the resolved file path.
    """
    from ultralytics.utils.downloads import attempt_download_asset
    check_suffix(file=weight, suffix='.pt')
    file = attempt_download_asset(weight)  # search online if missing locally
    try:
        with temporary_modules({
                'ultralytics.yolo.utils': 'ultralytics.utils',
                'ultralytics.yolo.v8': 'ultralytics.models.yolo',
                'ultralytics.yolo.data': 'ultralytics.data'}):  # for legacy 8.0 Classify and Pose models
            return torch.load(file, map_location='cpu'), file  # load
    except ModuleNotFoundError as e:  # e.name is missing module name
        if e.name == 'models':
            raise TypeError(
                emojis(f'ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained '
                       f'with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with '
                       f'YOLOv8 at https://github.com/ultralytics/ultralytics.'
                       f"\nRecommended fixes are to train a new model using the latest 'ultralytics' package or to "
                       f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")) from e
        LOGGER.warning(f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
                       f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
                       f"\nRecommended fixes are to train a new model using the latest 'ultralytics' package or to "
                       f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")
        check_requirements(e.name)  # install missing module

        return torch.load(file, map_location='cpu'), file  # load
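

# A minimal sketch (illustrative, not part of the original module) of torch_safe_load(): the returned
# checkpoint is a plain dict whose keys typically include 'model' and optionally 'ema' and 'train_args',
# which the loaders below rely on. 'yolov8n.pt' is only an example weight name.
def _example_torch_safe_load():
    ckpt, file = torch_safe_load('yolov8n.pt')  # downloads the asset if not found locally
    model = ckpt.get('ema') or ckpt['model']  # prefer EMA weights when present
    return model, file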


def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
    """Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
    ensemble = Ensemble()
    for w in weights if isinstance(weights, list) else [weights]:
        ckpt, w = torch_safe_load(w)  # load ckpt
        args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} if 'train_args' in ckpt else None  # combined args
        model = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model

        # Model compatibility updates
        model.args = args  # attach args to model
        model.pt_path = w  # attach *.pt file path to model
        model.task = guess_model_task(model)
        if not hasattr(model, 'stride'):
            model.stride = torch.tensor([32.])

        # Append
        ensemble.append(model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval())  # model in eval mode

    # Module updates
    for m in ensemble.modules():
        t = type(m)
        if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
            m.inplace = inplace
        elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
            m.recompute_scale_factor = None  # torch 1.11.0 compatibility

    # Return model
    if len(ensemble) == 1:
        return ensemble[-1]

    # Return ensemble
    LOGGER.info(f'Ensemble created with {weights}\n')
    for k in 'names', 'nc', 'yaml':
        setattr(ensemble, k, getattr(ensemble[0], k))
    ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride
    assert all(ensemble[0].nc == m.nc for m in ensemble), f'Models differ in class counts {[m.nc for m in ensemble]}'
    return ensemble


def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
    """Load a single model's weights."""
    ckpt, weight = torch_safe_load(weight)  # load ckpt
    args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))}  # combine model and default args, preferring model args
    model = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model

    # Model compatibility updates
    model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # attach args to model
    model.pt_path = weight  # attach *.pt file path to model
    model.task = guess_model_task(model)
    if not hasattr(model, 'stride'):
        model.stride = torch.tensor([32.])

    model = model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval()  # model in eval mode

    # Module updates
    for m in model.modules():
        t = type(m)
        if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
            m.inplace = inplace
        elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
            m.recompute_scale_factor = None  # torch 1.11.0 compatibility

    # Return model and ckpt
    return model, ckpt
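

# A minimal sketch (illustrative, not part of the original module) contrasting the two loaders:
# attempt_load_one_weight() returns (model, ckpt) for a single weight file, while attempt_load_weights()
# returns either a single model or an Ensemble when given a list. The weight names are examples.
def _example_attempt_load():
    model, ckpt = attempt_load_one_weight('yolov8n.pt', device='cpu', fuse=True)
    ensemble = attempt_load_weights(['yolov8n.pt', 'yolov8s.pt'], device='cpu')  # Ensemble of two models
    return model, ensemble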


def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
    """Parse a YOLO model.yaml dictionary into a PyTorch model."""
    import ast

    # Args
    max_channels = float('inf')
    nc, act, scales = (d.get(x) for x in ('nc', 'activation', 'scales'))
    depth, width, kpt_shape = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple', 'kpt_shape'))
    if scales:
        scale = d.get('scale')
        if not scale:
            scale = tuple(scales.keys())[0]
            LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
        depth, width, max_channels = scales[scale]

    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        if verbose:
            LOGGER.info(f"{colorstr('activation:')} {act}")  # print

    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m]  # get module
        for j, a in enumerate(args):
            if isinstance(a, str):
                with contextlib.suppress(ValueError):
                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)

        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
                 BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3):
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)

            args = [c1, c2, *args[1:]]
            if m in (BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x, RepC3):
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is AIFI:
            args = [ch[f], *args]
        elif m in (HGStem, HGBlock):
            c1, cm, c2 = ch[f], args[0], args[1]
            args = [c1, cm, c2, *args[2:]]
            if m is HGBlock:
                args.insert(4, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m in (Detect, Segment, Pose):
            args.append([ch[x] for x in f])
            if m is Segment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
        elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
            args.insert(1, [ch[x] for x in f])
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        m.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
        if verbose:
            LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}')  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
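

# A minimal sketch (illustrative, not part of the original module) of parse_model() on a tiny hand-written
# config dict. The two-layer backbone and Classify head below are invented for illustration; real configs
# live in the packaged *.yaml files.
def _example_parse_model():
    d = {
        'nc': 80,  # number of classes
        'backbone': [
            [-1, 1, 'Conv', [16, 3, 2]],  # 0: from previous layer, 1 repeat, 16 ch out, k=3, s=2
            [-1, 1, 'Conv', [32, 3, 2]]],  # 1
        'head': [
            [-1, 1, 'Classify', [80]]]}  # 2: classification head with nc outputs
    model, save = parse_model(deepcopy(d), ch=3, verbose=False)
    return model, save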


def yaml_model_load(path):
    """Load a YOLOv8 model from a YAML file."""
    import re

    path = Path(path)
    if path.stem in (f'yolov{d}{x}6' for x in 'nsmlx' for d in (5, 8)):
        new_stem = re.sub(r'(\d+)([nslmx])6(.+)?$', r'\1\2-p6\3', path.stem)
        LOGGER.warning(f'WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.')
        path = path.with_name(new_stem + path.suffix)

    unified_path = re.sub(r'(\d+)([nslmx])(.+)?$', r'\1\3', str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
    yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
    d = yaml_load(yaml_file)  # model dict
    d['scale'] = guess_model_scale(path)
    d['yaml_file'] = str(path)
    return d


def guess_model_scale(model_path):
    """
    Takes a path to a YOLO model's YAML file as input and extracts the size character of the model's scale. The
    function uses regular expression matching to find the pattern of the model scale in the YAML file name, which is
    denoted by n, s, m, l, or x. The function returns the size character of the model scale as a string.

    Args:
        model_path (str | Path): The path to the YOLO model's YAML file.

    Returns:
        (str): The size character of the model's scale, which can be n, s, m, l, or x.
    """
    with contextlib.suppress(AttributeError):
        import re
        return re.search(r'yolov\d+([nslmx])', Path(model_path).stem).group(1)  # n, s, m, l, or x
    return ''
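

# A quick sketch (illustrative, not part of the original module) of the YAML-name helpers above:
# yaml_model_load() maps a scaled name like 'yolov8m.yaml' onto the unified 'yolov8.yaml' config and
# records the scale, which guess_model_scale() extracts from the filename stem.
def _example_yaml_helpers():
    assert guess_model_scale('yolov8m.yaml') == 'm'  # scale char parsed from the stem
    assert guess_model_scale('unknown.yaml') == ''  # no match -> AttributeError suppressed -> ''
    return yaml_model_load('yolov8n.yaml')  # resolves to the unified yolov8.yaml with scale='n'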


def guess_model_task(model):
    """
    Guess the task of a PyTorch model from its architecture or configuration.

    Args:
        model (nn.Module | dict): PyTorch model or model configuration in YAML format.

    Returns:
        (str): Task of the model ('detect', 'segment', 'classify', 'pose'). Falls back to 'detect' with a
            warning if the task cannot be determined.
    """

    def cfg2task(cfg):
        """Guess from YAML dictionary."""
        m = cfg['head'][-1][-2].lower()  # output module name
        if m in ('classify', 'classifier', 'cls', 'fc'):
            return 'classify'
        if m == 'detect':
            return 'detect'
        if m == 'segment':
            return 'segment'
        if m == 'pose':
            return 'pose'

    # Guess from model cfg
    if isinstance(model, dict):
        with contextlib.suppress(Exception):
            return cfg2task(model)

    # Guess from PyTorch model
    if isinstance(model, nn.Module):  # PyTorch model
        for x in 'model.args', 'model.model.args', 'model.model.model.args':
            with contextlib.suppress(Exception):
                return eval(x)['task']
        for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
            with contextlib.suppress(Exception):
                return cfg2task(eval(x))

        for m in model.modules():
            if isinstance(m, Detect):
                return 'detect'
            elif isinstance(m, Segment):
                return 'segment'
            elif isinstance(m, Classify):
                return 'classify'
            elif isinstance(m, Pose):
                return 'pose'

    # Guess from model filename
    if isinstance(model, (str, Path)):
        model = Path(model)
        if '-seg' in model.stem or 'segment' in model.parts:
            return 'segment'
        elif '-cls' in model.stem or 'classify' in model.parts:
            return 'classify'
        elif '-pose' in model.stem or 'pose' in model.parts:
            return 'pose'
        elif 'detect' in model.parts:
            return 'detect'

    # Unable to determine task from model
    LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
                   "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify', or 'pose'.")
    return 'detect'  # assume detect
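

# A short sketch (illustrative, not part of the original module) of guess_model_task() on filenames:
# the '-seg', '-cls' and '-pose' stem suffixes are recognized, and anything else falls back to 'detect'.
def _example_guess_model_task():
    assert guess_model_task('yolov8n-seg.pt') == 'segment'
    assert guess_model_task('yolov8n-cls.pt') == 'classify'
    assert guess_model_task('yolov8n-pose.pt') == 'pose'
    return guess_model_task('yolov8n.pt')  # falls through to the 'detect' default (with a warning)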