# Ultralytics YOLO 🚀, AGPL-3.0 license

from copy import copy

import torch  # noqa: F401  (file-level import retained; no longer referenced by RTDETRTrainer)

from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import RTDETRDetectionModel
from ultralytics.utils import RANK, colorstr

from .val import RTDETRDataset, RTDETRValidator


class RTDETRTrainer(DetectionTrainer):
    """
    A class extending the DetectionTrainer class for training based on an RT-DETR detection model.

    Notes:
        - F.grid_sample used in rt-detr does not support the `deterministic=True` argument.
        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.

    Example:
        ```python
        from ultralytics.models.rtdetr.train import RTDETRTrainer

        args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
        trainer = RTDETRTrainer(overrides=args)
        trainer.train()
        ```
    """

    def get_model(self, cfg=None, weights=None, verbose: bool = True):
        """
        Build an RT-DETR detection model and optionally load weights into it.

        Args:
            cfg: Model configuration forwarded to `RTDETRDetectionModel`.
            weights: If truthy, passed to `model.load()` after construction.
            verbose (bool): Request verbose model construction; only honoured when `RANK == -1`.

        Returns:
            (RTDETRDetectionModel): The constructed model, with weights loaded when provided.
        """
        model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def build_dataset(self, img_path, mode: str = 'val', batch=None):
        """
        Build an RTDETRDataset for the given image path and mode.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, forwarded as `batch_size`. Defaults to None.

        Returns:
            (RTDETRDataset): Dataset configured from `self.args` and `self.data`.
        """
        return RTDETRDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch,
            augment=mode == 'train',  # augmentation is enabled only for the training split
            hyp=self.args,
            rect=False,  # rectangular batching is always disabled here
            cache=self.args.cache or None,
            prefix=colorstr(f'{mode}: '),
            data=self.data)

    def get_validator(self):
        """
        Return an RTDETRValidator for validating the RT-DETR model.

        Also sets `self.loss_names` to the three loss components reported during training.

        Returns:
            (RTDETRValidator): Validator bound to `self.test_loader`, `self.save_dir` and a shallow copy of the args.
        """
        self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
        return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

    def preprocess_batch(self, batch):
        """
        Preprocess a batch by delegating to `DetectionTrainer.preprocess_batch`.

        Args:
            batch (dict): Batch as produced by the dataloader.

        Returns:
            (dict): The batch returned by the parent class's preprocessing.
        """
        # The previous implementation additionally built per-image `gt_bbox` / `gt_class` lists via
        # boolean-mask indexing and then discarded them. They never influenced the returned batch,
        # so that per-batch work has been removed.
        return super().preprocess_batch(batch)