# Ultralytics YOLO 🚀, AGPL-3.0 license

from copy import copy

import torch

from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import RTDETRDetectionModel
from ultralytics.utils import RANK, colorstr

from .val import RTDETRDataset, RTDETRValidator


class RTDETRTrainer(DetectionTrainer):
    """
    A class extending the DetectionTrainer class for training based on an RT-DETR detection model.

    Notes:
        - `F.grid_sample`, as used in RT-DETR, does not support the `deterministic=True` argument.
        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.

    Example:
        ```python
        from ultralytics.models.rtdetr.train import RTDETRTrainer

        args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
        trainer = RTDETRTrainer(overrides=args)
        trainer.train()
        ```
    """

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return an RT-DETR detection model initialized with the given config and, optionally, pretrained weights."""
        model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def build_dataset(self, img_path, mode='val', batch=None):
        """
        Build an RT-DETR dataset.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` or `val` mode; each mode can apply different augmentations.
            batch (int, optional): Batch size, used for `rect` mode. Defaults to None.
        """
        return RTDETRDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch,
            augment=mode == 'train',  # augment only in train mode
            hyp=self.args,
            rect=False,  # no rectangular batching
            cache=self.args.cache or None,
            prefix=colorstr(f'{mode}: '),
            data=self.data)
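
    # A minimal usage sketch for build_dataset (assumes a constructed trainer whose args and dataset
    # dict have already been resolved; 'path/to/images' is a placeholder):
    #
    #     dataset = trainer.build_dataset(img_path='path/to/images', mode='train', batch=16)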

    def get_validator(self):
        """Return an RTDETRValidator for RT-DETR model validation."""
        self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
        return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

    def preprocess_batch(self, batch):
        """Preprocess a batch of images by scaling and converting to float, then group ground truths per image."""
        batch = super().preprocess_batch(batch)
        bs = len(batch['img'])
        batch_idx = batch['batch_idx']
        gt_bbox, gt_class = [], []
        for i in range(bs):
            # gather the boxes and class labels that belong to image `i` of the batch
            gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
            gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
        return batch
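

# A rough CLI equivalent of the docstring example above, assuming the standard Ultralytics `yolo` CLI:
#
#     yolo detect train model=rtdetr-l.yaml data=coco8.yaml imgsz=640 epochs=3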