model.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. """
  3. YOLO-NAS model interface.
  4. Example:
  5. ```python
  6. from ultralytics import NAS
  7. model = NAS('yolo_nas_s')
  8. results = model.predict('ultralytics/assets/bus.jpg')
  9. ```
  10. """
  11. from pathlib import Path
  12. import torch
  13. from ultralytics.engine.model import Model
  14. from ultralytics.utils.torch_utils import model_info, smart_inference_mode
  15. from .predict import NASPredictor
  16. from .val import NASValidator
  17. class NAS(Model):
  18. def __init__(self, model='yolo_nas_s.pt') -> None:
  19. assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
  20. super().__init__(model, task='detect')
  21. @smart_inference_mode()
  22. def _load(self, weights: str, task: str):
  23. # Load or create new NAS model
  24. import super_gradients
  25. suffix = Path(weights).suffix
  26. if suffix == '.pt':
  27. self.model = torch.load(weights)
  28. elif suffix == '':
  29. self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
  30. # Standardize model
  31. self.model.fuse = lambda verbose=True: self.model
  32. self.model.stride = torch.tensor([32])
  33. self.model.names = dict(enumerate(self.model._class_names))
  34. self.model.is_fused = lambda: False # for info()
  35. self.model.yaml = {} # for info()
  36. self.model.pt_path = weights # for export()
  37. self.model.task = 'detect' # for export()
  38. def info(self, detailed=False, verbose=True):
  39. """
  40. Logs model info.
  41. Args:
  42. detailed (bool): Show detailed information about model.
  43. verbose (bool): Controls verbosity.
  44. """
  45. return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
  46. @property
  47. def task_map(self):
  48. return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}