model.py 886 B

123456789101112131415161718192021222324252627282930
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. """
  3. RT-DETR model interface
  4. """
  5. from ultralytics.engine.model import Model
  6. from ultralytics.nn.tasks import RTDETRDetectionModel
  7. from .predict import RTDETRPredictor
  8. from .train import RTDETRTrainer
  9. from .val import RTDETRValidator
  10. class RTDETR(Model):
  11. """
  12. RTDETR model interface.
  13. """
  14. def __init__(self, model='rtdetr-l.pt') -> None:
  15. if model and model.split('.')[-1] not in ('pt', 'yaml', 'yml'):
  16. raise NotImplementedError('RT-DETR only supports creating from *.pt file or *.yaml file.')
  17. super().__init__(model=model, task='detect')
  18. @property
  19. def task_map(self):
  20. return {
  21. 'detect': {
  22. 'predictor': RTDETRPredictor,
  23. 'validator': RTDETRValidator,
  24. 'trainer': RTDETRTrainer,
  25. 'model': RTDETRDetectionModel}}