train.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. from copy import copy
  3. import torch
  4. from ultralytics.models.yolo.detect import DetectionTrainer
  5. from ultralytics.nn.tasks import RTDETRDetectionModel
  6. from ultralytics.utils import RANK, colorstr
  7. from .val import RTDETRDataset, RTDETRValidator
  8. class RTDETRTrainer(DetectionTrainer):
  9. """
  10. A class extending the DetectionTrainer class for training based on an RT-DETR detection model.
  11. Notes:
  12. - F.grid_sample used in rt-detr does not support the `deterministic=True` argument.
  13. - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
  14. Example:
  15. ```python
  16. from ultralytics.models.rtdetr.train import RTDETRTrainer
  17. args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
  18. trainer = RTDETRTrainer(overrides=args)
  19. trainer.train()
  20. ```
  21. """
  22. def get_model(self, cfg=None, weights=None, verbose=True):
  23. """Return a YOLO detection model."""
  24. model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
  25. if weights:
  26. model.load(weights)
  27. return model
  28. def build_dataset(self, img_path, mode='val', batch=None):
  29. """Build RTDETR Dataset
  30. Args:
  31. img_path (str): Path to the folder containing images.
  32. mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
  33. batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
  34. """
  35. return RTDETRDataset(
  36. img_path=img_path,
  37. imgsz=self.args.imgsz,
  38. batch_size=batch,
  39. augment=mode == 'train', # no augmentation
  40. hyp=self.args,
  41. rect=False, # no rect
  42. cache=self.args.cache or None,
  43. prefix=colorstr(f'{mode}: '),
  44. data=self.data)
  45. def get_validator(self):
  46. """Returns a DetectionValidator for RTDETR model validation."""
  47. self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
  48. return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
  49. def preprocess_batch(self, batch):
  50. """Preprocesses a batch of images by scaling and converting to float."""
  51. batch = super().preprocess_batch(batch)
  52. bs = len(batch['img'])
  53. batch_idx = batch['batch_idx']
  54. gt_bbox, gt_class = [], []
  55. for i in range(bs):
  56. gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
  57. gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
  58. return batch