model.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. """
  3. SAM model interface
  4. """
  5. from pathlib import Path
  6. from ultralytics.engine.model import Model
  7. from ultralytics.utils.torch_utils import model_info
  8. from .build import build_sam
  9. from .predict import Predictor
  10. class SAM(Model):
  11. """
  12. SAM model interface.
  13. """
  14. def __init__(self, model='sam_b.pt') -> None:
  15. if model and Path(model).suffix not in ('.pt', '.pth'):
  16. raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.')
  17. super().__init__(model=model, task='segment')
  18. def _load(self, weights: str, task=None):
  19. self.model = build_sam(weights)
  20. def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
  21. """Predicts and returns segmentation masks for given image or video source."""
  22. overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
  23. kwargs.update(overrides)
  24. prompts = dict(bboxes=bboxes, points=points, labels=labels)
  25. return super().predict(source, stream, prompts=prompts, **kwargs)
  26. def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
  27. """Calls the 'predict' function with given arguments to perform object detection."""
  28. return self.predict(source, stream, bboxes, points, labels, **kwargs)
  29. def info(self, detailed=False, verbose=True):
  30. """
  31. Logs model info.
  32. Args:
  33. detailed (bool): Show detailed information about model.
  34. verbose (bool): Controls verbosity.
  35. """
  36. return model_info(self.model, detailed=detailed, verbose=verbose)
  37. @property
  38. def task_map(self):
  39. return {'segment': {'predictor': Predictor}}