12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813 |
- # Ultralytics YOLO 🚀, AGPL-3.0 license
- import contextlib
- from copy import deepcopy
- from pathlib import Path
- import torch
- import torch.nn as nn
- from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
- Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
- Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
- RTDETRDecoder, Segment)
- from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
- from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
- from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
- from ultralytics.utils.plotting import feature_visualization
- from ultralytics.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts,
- make_divisible, model_info, scale_img, time_sync)
- try:
- import thop
- except ImportError:
- thop = None
class BaseModel(nn.Module):
    """Shared base class of the Ultralytics YOLO models: inference, profiling, fusing, loading and loss plumbing."""

    def forward(self, x, *args, **kwargs):
        """
        Dispatch to loss computation or inference depending on the input type.

        Args:
            x (torch.Tensor | dict): An image tensor, or a batch dict (training / validation while training).

        Returns:
            The result of `self.loss` for dict inputs, otherwise the result of `self.predict`.
        """
        if not isinstance(x, dict):
            return self.predict(x, *args, **kwargs)
        return self.loss(x, *args, **kwargs)

    def predict(self, x, profile=False, visualize=False, augment=False):
        """
        Run inference, optionally with augmentation.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool): Log the computation time of each layer if True, defaults to False.
            visualize (bool): Save the feature maps of the model if truthy, defaults to False.
            augment (bool): Use augmented inference if True, defaults to False.

        Returns:
            (torch.Tensor): The last output of the model.
        """
        return self._predict_augment(x) if augment else self._predict_once(x, profile, visualize)

    def _predict_once(self, x, profile=False, visualize=False):
        """
        Run every layer of `self.model` once, routing inputs according to each layer's `f` attribute.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool): Log the computation time of each layer if True, defaults to False.
            visualize (bool): Save the feature maps of the model if truthy, defaults to False.

        Returns:
            (torch.Tensor): The last output of the model.
        """
        outputs, timings = [], []
        for layer in self.model:
            source = layer.f
            if source != -1:  # input does not come (only) from the previous layer
                if isinstance(source, int):
                    x = outputs[source]
                else:
                    x = [x if j == -1 else outputs[j] for j in source]
            if profile:
                self._profile_one_layer(layer, x, timings)
            x = layer(x)
            # Only outputs listed in self.save are kept; the rest become placeholders.
            outputs.append(x if layer.i in self.save else None)
            if visualize:
                feature_visualization(x, layer.type, layer.i, save_dir=visualize)
        return x

    def _predict_augment(self, x):
        """Warn that augmented inference is unsupported here and fall back to a single plain pass."""
        LOGGER.warning(f'WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. '
                       f'Reverting to single-scale inference instead.')
        return self._predict_once(x)

    def _profile_one_layer(self, m, x, dt):
        """
        Time a single layer (and count FLOPs when `thop` is available), logging and recording the result.

        Args:
            m (nn.Module): The layer to be profiled.
            x (torch.Tensor): The input data to the layer.
            dt (list): Running list of layer timings; the new timing is appended.

        Returns:
            None
        """
        # For a final layer fed with a list, pass a copy each time (in-place fix).
        copy_input = m == self.model[-1] and isinstance(x, list)
        if thop:
            flops = thop.profile(m, inputs=[x.copy() if copy_input else x], verbose=False)[0] / 1E9 * 2
        else:
            flops = 0
        start = time_sync()
        for _ in range(10):
            m(x.copy() if copy_input else x)
        dt.append((time_sync() - start) * 100)
        if m == self.model[0]:
            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
        LOGGER.info(f'{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}')
        if copy_input:
            LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")

    def fuse(self, verbose=True):
        """
        Fuse convolution and batch-norm layers of the model, unless it is already considered fused.

        Returns:
            (nn.Module): The model itself.
        """
        if self.is_fused():
            return self
        for module in self.model.modules():
            if isinstance(module, (Conv, Conv2, DWConv)) and hasattr(module, 'bn'):
                if isinstance(module, Conv2):
                    module.fuse_convs()
                module.conv = fuse_conv_and_bn(module.conv, module.bn)
                delattr(module, 'bn')
                module.forward = module.forward_fuse
            if isinstance(module, ConvTranspose) and hasattr(module, 'bn'):
                module.conv_transpose = fuse_deconv_and_bn(module.conv_transpose, module.bn)
                delattr(module, 'bn')
                module.forward = module.forward_fuse
            if isinstance(module, RepConv):
                module.fuse_convs()
                module.forward = module.forward_fuse
        self.info(verbose=verbose)
        return self

    def is_fused(self, thresh=10):
        """
        Report whether the model holds fewer than `thresh` normalization layers.

        Args:
            thresh (int, optional): The threshold number of normalization layers. Default is 10.

        Returns:
            (bool): True if the count of `torch.nn` '*Norm*' layers is below the threshold, False otherwise.
        """
        norm_types = tuple(obj for name, obj in vars(nn).items() if 'Norm' in name)
        norm_count = sum(isinstance(module, norm_types) for module in self.modules())
        return norm_count < thresh

    def info(self, detailed=False, verbose=True, imgsz=640):
        """
        Forward to `model_info` for this model.

        Args:
            detailed (bool): Passed through as `detailed`. Defaults to False.
            verbose (bool): Passed through as `verbose`. Defaults to True.
            imgsz (int): Passed through as `imgsz`. Defaults to 640.
        """
        return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)

    def _apply(self, fn):
        """
        Apply `fn` through `nn.Module._apply` and additionally to the head's stride/anchor tensors.

        Args:
            fn (function): The function to apply to the model.

        Returns:
            The object returned by the parent `_apply`.
        """
        applied = super()._apply(fn)
        head = applied.model[-1]
        if isinstance(head, (Detect, Segment)):
            head.stride = fn(head.stride)
            head.anchors = fn(head.anchors)
            head.strides = fn(head.strides)
        return applied

    def load(self, weights, verbose=True):
        """
        Copy matching entries of a pretrained state dict into this model.

        Args:
            weights (dict | torch.nn.Module): Checkpoint dict holding a 'model' entry, or the model itself.
            verbose (bool, optional): Whether to log the transfer progress. Defaults to True.
        """
        source = weights['model'] if isinstance(weights, dict) else weights
        state = source.float().state_dict()  # checkpoint state_dict as FP32
        state = intersect_dicts(state, self.state_dict())
        self.load_state_dict(state, strict=False)
        if verbose:
            LOGGER.info(f'Transferred {len(state)}/{len(self.model.state_dict())} items from pretrained weights')

    def loss(self, batch, preds=None):
        """
        Compute the loss for a batch, creating the criterion on first use.

        Args:
            batch (dict): Batch to compute loss on.
            preds (torch.Tensor | List[torch.Tensor]): Predictions; computed from batch['img'] when None.
        """
        if not hasattr(self, 'criterion'):
            self.criterion = self.init_criterion()
        if preds is None:
            preds = self.forward(batch['img'])
        return self.criterion(preds, batch)

    def init_criterion(self):
        """Create the loss criterion; subclasses must override this."""
        raise NotImplementedError('compute_loss() needs to be implemented by task heads')
class DetectionModel(BaseModel):
    """YOLOv8 detection model."""

    def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True):
        """Build the model from a YAML path or dict (cfg), input channels (ch) and optional class count (nc)."""
        super().__init__()
        if isinstance(cfg, dict):
            self.yaml = cfg
        else:
            self.yaml = yaml_model_load(cfg)

        # Define model
        ch = self.yaml.get('ch', ch)  # input channels, the YAML value wins when present
        self.yaml['ch'] = ch
        if nc and nc != self.yaml['nc']:
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml['nc'] = nc
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)
        self.names = {i: f'{i}' for i in range(self.yaml['nc'])}  # default names dict
        self.inplace = self.yaml.get('inplace', True)

        # Build strides
        head = self.model[-1]
        if isinstance(head, (Detect, Segment, Pose)):
            probe = 256  # side length of the zero image used to measure strides
            head.inplace = self.inplace

            def _probe_forward(img):
                """Forward pass keeping only the first output for Segment/Pose heads."""
                out = self.forward(img)
                return out[0] if isinstance(head, (Segment, Pose)) else out

            head.stride = torch.tensor([probe / o.shape[-2] for o in _probe_forward(torch.zeros(1, ch, probe, probe))])
            self.stride = head.stride
            head.bias_init()  # only run once
        else:
            self.stride = torch.Tensor([32])  # default stride for i.e. RTDETR

        # Init weights, biases
        initialize_weights(self)
        if verbose:
            self.info()
            LOGGER.info('')

    def _predict_augment(self, x):
        """Run inference at several scales/flips and return the merged predictions plus None."""
        img_size = x.shape[-2:]  # height, width
        scales = [1, 0.83, 0.67]
        flips = [None, 3, None]  # 2-ud, 3-lr
        results = []
        for scale, flip in zip(scales, flips):
            source = x.flip(flip) if flip else x
            scaled = scale_img(source, scale, gs=int(self.stride.max()))
            pred = super().predict(scaled)[0]
            results.append(self._descale_pred(pred, flip, scale, img_size))
        results = self._clip_augmented(results)  # clip augmented tails
        return torch.cat(results, -1), None  # augmented inference, train

    @staticmethod
    def _descale_pred(p, flips, scale, img_size, dim=1):
        """De-scale predictions following augmented inference (inverse operation); modifies `p` in place."""
        p[:, :4] /= scale
        x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim)
        if flips == 2:  # undo up-down flip
            y = img_size[0] - y
        elif flips == 3:  # undo left-right flip
            x = img_size[1] - x
        return torch.cat((x, y, wh, cls), dim)

    def _clip_augmented(self, y):
        """Clip YOLOv5 augmented inference tails."""
        nl = self.model[-1].nl  # number of detection layers
        grid_points = sum(4 ** k for k in range(nl))
        excluded = 1  # exclude layer count
        tail = (y[0].shape[-1] // grid_points) * sum(4 ** k for k in range(excluded))
        y[0] = y[0][..., :-tail]  # large
        lead = (y[-1].shape[-1] // grid_points) * sum(4 ** (nl - 1 - k) for k in range(excluded))
        y[-1] = y[-1][..., lead:]  # small
        return y

    def init_criterion(self):
        """Return the detection loss built for this model."""
        return v8DetectionLoss(self)
class SegmentationModel(DetectionModel):
    """YOLOv8 segmentation model."""

    def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
        """Build a segmentation model; all arguments are handed to DetectionModel unchanged."""
        super().__init__(cfg, ch, nc, verbose)

    def init_criterion(self):
        """Return the segmentation loss built for this model."""
        return v8SegmentationLoss(self)
class PoseModel(DetectionModel):
    """YOLOv8 pose model."""

    def __init__(self, cfg='yolov8n-pose.yaml', ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
        """Build a pose model, letting a non-empty data_kpt_shape replace the config's 'kpt_shape'."""
        if not isinstance(cfg, dict):
            cfg = yaml_model_load(cfg)  # load model YAML
        override_requested = any(data_kpt_shape)
        if override_requested and list(cfg['kpt_shape']) != list(data_kpt_shape):
            LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
            cfg['kpt_shape'] = data_kpt_shape
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def init_criterion(self):
        """Return the pose loss built for this model."""
        return v8PoseLoss(self)
class ClassificationModel(BaseModel):
    """YOLOv8 classification model."""

    def __init__(self,
                 cfg='yolov8n-cls.yaml',
                 model=None,
                 ch=3,
                 nc=None,
                 cutoff=10,
                 verbose=True):
        """Build from an existing detection model when `model` is given, otherwise from the YAML config."""
        super().__init__()
        if model is not None:
            self._from_detection_model(model, nc, cutoff)
        else:
            self._from_yaml(cfg, ch, nc, verbose)

    def _from_detection_model(self, model, nc=1000, cutoff=10):
        """Create a YOLOv5 classification model from a YOLOv5 detection model."""
        from ultralytics.nn.autobackend import AutoBackend

        if isinstance(model, AutoBackend):
            model = model.model  # unwrap the backend wrapper
        model.model = model.model[:cutoff]  # keep the backbone only
        last = model.model[-1]
        # Channels going into the last kept layer.
        in_channels = last.conv.in_channels if hasattr(last, 'conv') else last.cv1.conv.in_channels
        head = Classify(in_channels, nc)
        head.i, head.f, head.type = last.i, last.f, 'models.common.Classify'  # index, from, type
        model.model[-1] = head  # swap the last layer for the classifier
        self.model = model.model
        self.stride = model.stride
        self.save = []
        self.nc = nc

    def _from_yaml(self, cfg, ch, nc, verbose):
        """Set YOLOv8 model configurations and define the model architecture."""
        if isinstance(cfg, dict):
            self.yaml = cfg
        else:
            self.yaml = yaml_model_load(cfg)

        # Define model
        ch = self.yaml.get('ch', ch)  # input channels, the YAML value wins when present
        self.yaml['ch'] = ch
        if nc and nc != self.yaml['nc']:
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml['nc'] = nc
        elif not nc and not self.yaml.get('nc', None):
            raise ValueError('nc not specified. Must specify nc in model.yaml or function arguments.')
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)
        self.stride = torch.Tensor([1])  # no stride constraints
        self.names = {i: f'{i}' for i in range(self.yaml['nc'])}  # default names dict
        self.info()

    @staticmethod
    def reshape_outputs(model, nc):
        """Update a TorchVision classification model to class count 'n' if required."""
        container = model.model if hasattr(model, 'model') else model
        name, m = list(container.named_children())[-1]  # last child module

        if isinstance(m, Classify):  # YOLO Classify() head
            if m.linear.out_features != nc:
                m.linear = nn.Linear(m.linear.in_features, nc)
            return

        if isinstance(m, nn.Linear):  # ResNet, EfficientNet
            if m.out_features != nc:
                setattr(model, name, nn.Linear(m.in_features, nc))
            return

        if not isinstance(m, nn.Sequential):
            return

        kinds = [type(layer) for layer in m]
        if nn.Linear in kinds:
            idx = kinds.index(nn.Linear)  # first nn.Linear
            old = m[idx]
            if old.out_features != nc:
                m[idx] = nn.Linear(old.in_features, nc)
        elif nn.Conv2d in kinds:
            idx = kinds.index(nn.Conv2d)  # first nn.Conv2d
            old = m[idx]
            if old.out_channels != nc:
                m[idx] = nn.Conv2d(old.in_channels, nc, old.kernel_size, old.stride, bias=old.bias is not None)

    def init_criterion(self):
        """Return the classification loss."""
        return v8ClassificationLoss()
class RTDETRDetectionModel(DetectionModel):
    """RT-DETR detection model with its own loss and a head that also receives the training targets."""

    def __init__(self, cfg='rtdetr-l.yaml', ch=3, nc=None, verbose=True):
        """Build the model; all arguments are handed to DetectionModel."""
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def init_criterion(self):
        """Return the RT-DETR detection loss configured with this model's class count."""
        from ultralytics.models.utils.loss import RTDETRDetectionLoss

        return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)

    def loss(self, batch, preds=None):
        """
        Compute the RT-DETR loss for a batch.

        Args:
            batch (dict): Batch providing 'img', 'batch_idx', 'cls' and 'bboxes'.
            preds: Optional precomputed predictions; computed with `self.predict` when None.

        Returns:
            (tuple): The summed loss and a tensor holding the detached giou, class and bbox losses.
        """
        if not hasattr(self, 'criterion'):
            self.criterion = self.init_criterion()

        img = batch['img']
        batch_idx = batch['batch_idx']
        # NOTE: preprocess gt_bbox and gt_labels to list.
        num_images = len(img)
        per_image_counts = [(batch_idx == i).sum().item() for i in range(num_images)]
        targets = {
            'cls': batch['cls'].to(img.device, dtype=torch.long).view(-1),
            'bboxes': batch['bboxes'].to(device=img.device),
            'batch_idx': batch_idx.to(img.device, dtype=torch.long).view(-1),
            'gt_groups': per_image_counts}

        if preds is None:
            preds = self.predict(img, batch=targets)
        outputs = preds if self.training else preds[1]
        dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = outputs

        dn_bboxes, dn_scores = None, None
        if dn_meta is not None:
            dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta['dn_num_split'], dim=2)
            dn_scores, dec_scores = torch.split(dec_scores, dn_meta['dn_num_split'], dim=2)

        # Prepend the encoder outputs to the decoder outputs.
        dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes])
        dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])

        loss = self.criterion((dec_bboxes, dec_scores),
                              targets,
                              dn_bboxes=dn_bboxes,
                              dn_scores=dn_scores,
                              dn_meta=dn_meta)
        # NOTE: backward uses every loss term, but only the main three are reported.
        reported = [loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']]
        return sum(loss.values()), torch.as_tensor(reported, device=img.device)

    def predict(self, x, profile=False, visualize=False, batch=None, augment=False):
        """
        Run all layers but the last as usual, then call the head with the saved features and `batch`.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool): Log the computation time of each layer if True, defaults to False.
            visualize (bool): Save the feature maps of the model if truthy, defaults to False.
            batch (dict): Targets handed to the head as its second argument.
            augment (bool): Accepted for interface compatibility; not used here.

        Returns:
            The output of the head.
        """
        outputs, timings = [], []
        for layer in self.model[:-1]:  # everything except the head
            source = layer.f
            if source != -1:  # input does not come (only) from the previous layer
                if isinstance(source, int):
                    x = outputs[source]
                else:
                    x = [x if j == -1 else outputs[j] for j in source]
            if profile:
                self._profile_one_layer(layer, x, timings)
            x = layer(x)
            outputs.append(x if layer.i in self.save else None)
            if visualize:
                feature_visualization(x, layer.type, layer.i, save_dir=visualize)
        head = self.model[-1]
        features = [outputs[j] for j in head.f]
        return head(features, batch)  # head inference
class Ensemble(nn.ModuleList):
    """Ensemble of models."""

    def __init__(self):
        """Initialize an ensemble of models."""
        super().__init__()

    def forward(self, x, augment=False, profile=False, visualize=False):
        """
        Run every member model on `x` and concatenate their first outputs along dimension 2.

        The flags are forwarded by keyword. The member models' `predict` methods in this file take
        `profile`/`visualize` before `augment`, so forwarding them positionally in
        (augment, profile, visualize) order would hand each flag to the wrong parameter.

        Args:
            x (torch.Tensor): Input passed to each member model.
            augment (bool): Forwarded to each member as `augment`.
            profile (bool): Forwarded to each member as `profile`.
            visualize (bool): Forwarded to each member as `visualize`.

        Returns:
            (tuple): The concatenated predictions and None (no train output).
        """
        y = [module(x, augment=augment, profile=profile, visualize=visualize)[0] for module in self]
        # y = torch.stack(y).max(0)[0]  # max ensemble
        # y = torch.stack(y).mean(0)  # mean ensemble
        y = torch.cat(y, 2)  # nms ensemble, y shape(B, HW, C)
        return y, None  # inference, train output
- # Functions ------------------------------------------------------------------------------------------------------------
@contextlib.contextmanager
def temporary_modules(modules=None):
    """
    Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).

    This function can be used to change the module paths during runtime. It's useful when refactoring code,
    where you've moved a module from one location to another, but you still want to support the old import
    paths for backwards compatibility.

    Args:
        modules (dict, optional): A dictionary mapping old module paths to new module paths.

    Example:
        ```python
        with temporary_modules({'old.module.path': 'new.module.path'}):
            import old.module.path  # this will now import new.module.path
        ```

    Note:
        The changes are only in effect inside the context manager and are undone once the context manager exits:
        aliases that did not exist before are removed, and entries that already existed under an old name are
        put back. Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially
        in larger applications or libraries. Use this function with caution.
    """
    if not modules:
        modules = {}

    import importlib
    import sys

    absent = object()  # sentinel: the old name had no entry in sys.modules beforehand
    previous = {}  # old name -> prior sys.modules entry (or `absent`), only for aliases actually installed
    try:
        # Set modules in sys.modules under their old name
        for old, new in modules.items():
            replacement = importlib.import_module(new)  # may raise before anything is changed for `old`
            previous.setdefault(old, sys.modules.get(old, absent))
            sys.modules[old] = replacement
        yield
    finally:
        # Undo exactly what was installed, restoring any entry that was shadowed
        for old, prior in previous.items():
            if prior is absent:
                sys.modules.pop(old, None)
            else:
                sys.modules[old] = prior
def torch_safe_load(weight):
    """
    Load a '.pt' checkpoint with torch.load(), retrying once after installing a missing module.

    Legacy 'ultralytics.yolo.*' module paths are aliased to their current locations while loading. If a
    ModuleNotFoundError names the module 'models', a TypeError describing the YOLOv5 incompatibility is raised;
    for any other missing module a warning is logged, check_requirements() is asked to install it, and the load
    is attempted again.

    Args:
        weight (str): The file path of the PyTorch model.

    Returns:
        (tuple): The loaded checkpoint and the path it was loaded from.
    """
    from ultralytics.utils.downloads import attempt_download_asset

    check_suffix(file=weight, suffix='.pt')
    file = attempt_download_asset(weight)  # search online if missing locally

    # Old module path -> current module path, for legacy 8.0 Classify and Pose models
    legacy_aliases = {
        'ultralytics.yolo.utils': 'ultralytics.utils',
        'ultralytics.yolo.v8': 'ultralytics.models.yolo',
        'ultralytics.yolo.data': 'ultralytics.data'}
    try:
        with temporary_modules(legacy_aliases):
            # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
            return torch.load(file, map_location='cpu'), file
    except ModuleNotFoundError as e:  # e.name is missing module name
        missing = e.name
        if missing == 'models':
            raise TypeError(
                emojis(f'ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained '
                       f'with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with '
                       f'YOLOv8 at https://github.com/ultralytics/ultralytics.'
                       f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
                       f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")) from e
        LOGGER.warning(f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
                       f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
                       f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
                       f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")
        check_requirements(missing)  # install missing module
        return torch.load(file, map_location='cpu'), file
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
    """Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
    ensemble = Ensemble()
    weight_list = weights if isinstance(weights, list) else [weights]
    for w in weight_list:
        ckpt, w = torch_safe_load(w)  # checkpoint and the path it came from
        if 'train_args' in ckpt:
            args = {**DEFAULT_CFG_DICT, **ckpt['train_args']}  # defaults overridden by training args
        else:
            args = None
        model = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model

        # Model compatibility updates
        model.args = args
        model.pt_path = w
        model.task = guess_model_task(model)
        if not hasattr(model, 'stride'):
            model.stride = torch.tensor([32.])

        # Append in eval mode, fused first when requested and supported
        if fuse and hasattr(model, 'fuse'):
            model = model.fuse()
        ensemble.append(model.eval())

    # Module updates
    inplace_types = (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment)
    for module in ensemble.modules():
        kind = type(module)
        if kind in inplace_types:
            module.inplace = inplace
        elif kind is nn.Upsample and not hasattr(module, 'recompute_scale_factor'):
            module.recompute_scale_factor = None  # torch 1.11.0 compatibility

    # A single model is returned directly rather than wrapped
    if len(ensemble) == 1:
        return ensemble[-1]

    # Return ensemble
    LOGGER.info(f'Ensemble created with {weights}\n')
    for attr in ('names', 'nc', 'yaml'):
        setattr(ensemble, attr, getattr(ensemble[0], attr))
    max_strides = torch.tensor([m.stride.max() for m in ensemble])
    ensemble.stride = ensemble[torch.argmax(max_strides).int()].stride  # stride of the largest-stride member
    assert all(ensemble[0].nc == m.nc for m in ensemble), f'Models differ in class counts {[m.nc for m in ensemble]}'
    return ensemble
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
    """Loads a single model weights."""
    ckpt, weight = torch_safe_load(weight)  # checkpoint and the path it came from
    # Defaults overridden by the checkpoint's training args, when it has any
    merged_args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))}
    model = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model

    # Model compatibility updates
    model.args = {key: value for key, value in merged_args.items() if key in DEFAULT_CFG_KEYS}
    model.pt_path = weight
    model.task = guess_model_task(model)
    if not hasattr(model, 'stride'):
        model.stride = torch.tensor([32.])

    # Eval mode, fused first when requested and supported
    if fuse and hasattr(model, 'fuse'):
        model = model.fuse()
    model = model.eval()

    # Module updates
    inplace_types = (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment)
    for module in model.modules():
        kind = type(module)
        if kind in inplace_types:
            module.inplace = inplace
        elif kind is nn.Upsample and not hasattr(module, 'recompute_scale_factor'):
            module.recompute_scale_factor = None  # torch 1.11.0 compatibility

    # Return model and ckpt
    return model, ckpt
def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
    """
    Parse a YOLO model.yaml dictionary into a PyTorch model.

    Args:
        d (dict): Model dictionary with 'backbone' and 'head' layer lists and optional 'nc', 'activation',
            'scales', 'scale', 'depth_multiple', 'width_multiple' and 'kpt_shape' entries.
        ch (int): Number of input channels.
        verbose (bool): Log a table of the parsed layers when True.

    Returns:
        (tuple): The layers as an `nn.Sequential` and the sorted list of layer indices whose outputs are reused.
    """
    import ast

    # Args
    # NOTE: the local names below (nc, act, scales, depth, width, kpt_shape, ...) must not be renamed;
    # string arguments from the YAML are resolved against locals() further down.
    max_channels = float('inf')
    nc, act, scales = (d.get(x) for x in ('nc', 'activation', 'scales'))
    depth, width, kpt_shape = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple', 'kpt_shape'))
    if scales:
        scale = d.get('scale')
        if not scale:
            scale = tuple(scales.keys())[0]
            LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
        depth, width, max_channels = scales[scale]

    if act:
        # NOTE(review): eval() runs the YAML's 'activation' string as code; only parse trusted model files.
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        if verbose:
            LOGGER.info(f"{colorstr('activation:')} {act}")  # print

    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m]  # get module
        for j, a in enumerate(args):
            if isinstance(a, str):
                # Strings naming a local are replaced by its value, others are parsed as literals;
                # strings that are not valid literal values are left as they are.
                with contextlib.suppress(ValueError):
                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)

        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
                 BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3):
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)

            args = [c1, c2, *args[1:]]
            if m in (BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x, RepC3):
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is AIFI:
            args = [ch[f], *args]
        elif m in (HGStem, HGBlock):
            c1, cm, c2 = ch[f], args[0], args[1]
            args = [c1, cm, c2, *args[2:]]
            if m is HGBlock:
                args.insert(4, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m in (Detect, Segment, Pose):
            args.append([ch[x] for x in f])
            if m is Segment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
        elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
            args.insert(1, [ch[x] for x in f])
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        # Store the count on the built layer itself: setting it on the class `m` would make every layer of
        # that type report the most recently parsed count, and repeated layers would have none of their own.
        m_.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
        if verbose:
            LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f} {t:<45}{str(args):<30}')  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
def yaml_model_load(path):
    """Load a YOLOv8 model from a YAML file, adding 'scale' and 'yaml_file' entries to the returned dict."""
    import re

    path = Path(path)
    legacy_p6_stems = {f'yolov{d}{x}6' for x in 'nsmlx' for d in (5, 8)}
    if path.stem in legacy_p6_stems:
        new_stem = re.sub(r'(\d+)([nslmx])6(.+)?$', r'\1\2-p6\3', path.stem)
        LOGGER.warning(f'WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.')
        path = path.with_name(new_stem + path.suffix)

    # Drop the scale letter to get the shared config name, i.e. yolov8x.yaml -> yolov8.yaml
    unified_path = re.sub(r'(\d+)([nslmx])(.+)?$', r'\1\3', str(path))
    yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
    model_dict = yaml_load(yaml_file)
    model_dict['scale'] = guess_model_scale(path)
    model_dict['yaml_file'] = str(path)
    return model_dict
def guess_model_scale(model_path):
    """
    Extract the scale character (n, s, m, l, or x) from a YOLO model YAML file name.

    The stem of the path is searched for 'yolov<digits>' followed by one of the scale letters.

    Args:
        model_path (str | Path): The path to the YOLO model's YAML file.

    Returns:
        (str): The scale character, or an empty string when the stem does not match.
    """
    import re

    try:
        # .group() on a failed search (None) raises AttributeError, handled below
        return re.search(r'yolov\d+([nslmx])', Path(model_path).stem).group(1)  # n, s, m, l, or x
    except AttributeError:
        return ''
def guess_model_task(model):
    """
    Guess the task of a PyTorch model from its architecture or configuration.

    Sources are tried in order: a config dict, then (for nn.Module inputs) stored args, stored YAML and the
    module types, then (for str / Path inputs) the file name. A source that cannot name a task is skipped
    rather than returned as None. If nothing matches, a warning is logged and 'detect' is assumed.

    Args:
        model (nn.Module | dict | str | Path): PyTorch model, model configuration dict, or model file path.

    Returns:
        (str): Task of the model ('detect', 'segment', 'classify', 'pose').
    """

    def cfg2task(cfg):
        """Guess from YAML dictionary; returns None when the output module name is not recognised."""
        m = cfg['head'][-1][-2].lower()  # output module name
        if m in ('classify', 'classifier', 'cls', 'fc'):
            return 'classify'
        if m in ('detect', 'segment', 'pose'):
            return m
        return None

    def unwrap(obj, depth):
        """Follow the `.model` attribute `depth` times."""
        for _ in range(depth):
            obj = obj.model
        return obj

    # Guess from model cfg
    if isinstance(model, dict):
        with contextlib.suppress(Exception):
            task = cfg2task(model)
            if task:
                return task

    # Guess from PyTorch model
    if isinstance(model, nn.Module):  # PyTorch model
        # Look at model.args, model.model.args and model.model.model.args in turn
        for depth in range(3):
            with contextlib.suppress(Exception):
                task = unwrap(model, depth).args['task']
                if task:
                    return task
        # Then model.yaml, model.model.yaml and model.model.model.yaml
        for depth in range(3):
            with contextlib.suppress(Exception):
                task = cfg2task(unwrap(model, depth).yaml)
                if task:
                    return task

        # Specialised heads are tested before Detect so they still win if they derive from Detect
        for m in model.modules():
            if isinstance(m, Segment):
                return 'segment'
            elif isinstance(m, Pose):
                return 'pose'
            elif isinstance(m, Classify):
                return 'classify'
            elif isinstance(m, Detect):
                return 'detect'

    # Guess from model filename
    if isinstance(model, (str, Path)):
        model = Path(model)
        if '-seg' in model.stem or 'segment' in model.parts:
            return 'segment'
        elif '-cls' in model.stem or 'classify' in model.parts:
            return 'classify'
        elif '-pose' in model.stem or 'pose' in model.parts:
            return 'pose'
        elif 'detect' in model.parts:
            return 'detect'

    # Unable to determine task from model
    LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
                   "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify', or 'pose'.")
    return 'detect'  # assume detect