track.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. from functools import partial
  3. import torch
  4. from ultralytics.utils import IterableSimpleNamespace, yaml_load
  5. from ultralytics.utils.checks import check_yaml
  6. from .bot_sort import BOTSORT
  7. from .byte_tracker import BYTETracker
  8. TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT}
  9. def on_predict_start(predictor, persist=False):
  10. """
  11. Initialize trackers for object tracking during prediction.
  12. Args:
  13. predictor (object): The predictor object to initialize trackers for.
  14. persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
  15. Raises:
  16. AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
  17. """
  18. if hasattr(predictor, 'trackers') and persist:
  19. return
  20. tracker = check_yaml(predictor.args.tracker)
  21. cfg = IterableSimpleNamespace(**yaml_load(tracker))
  22. assert cfg.tracker_type in ['bytetrack', 'botsort'], \
  23. f"Only support 'bytetrack' and 'botsort' for now, but got '{cfg.tracker_type}'"
  24. trackers = []
  25. for _ in range(predictor.dataset.bs):
  26. tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)
  27. trackers.append(tracker)
  28. predictor.trackers = trackers
  29. def on_predict_postprocess_end(predictor):
  30. """Postprocess detected boxes and update with object tracking."""
  31. bs = predictor.dataset.bs
  32. im0s = predictor.batch[1]
  33. for i in range(bs):
  34. det = predictor.results[i].boxes.cpu().numpy()
  35. if len(det) == 0:
  36. continue
  37. tracks = predictor.trackers[i].update(det, im0s[i])
  38. if len(tracks) == 0:
  39. continue
  40. idx = tracks[:, -1].astype(int)
  41. predictor.results[i] = predictor.results[i][idx]
  42. predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1]))
  43. def register_tracker(model, persist):
  44. """
  45. Register tracking callbacks to the model for object tracking during prediction.
  46. Args:
  47. model (object): The model object to register tracking callbacks for.
  48. persist (bool): Whether to persist the trackers if they already exist.
  49. """
  50. model.add_callback('on_predict_start', partial(on_predict_start, persist=persist))
  51. model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end)