bot_sort.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. from collections import deque
  3. import numpy as np
  4. from .basetrack import TrackState
  5. from .byte_tracker import BYTETracker, STrack
  6. from .utils import matching
  7. from .utils.gmc import GMC
  8. from .utils.kalman_filter import KalmanFilterXYWH
  class BOTrack(STrack):
      """Track for a single object that adds appearance features on top of STrack.

      The Kalman state is handled in centre-based form: the `tlwh` property
      reads the first four state entries as `(center x, center y, width,
      height)` and converts them back to top-left form.

      Attributes:
          shared_kalman: Class-level KalmanFilterXYWH used by `multi_predict`.
          smooth_feat: Exponentially smoothed, re-normalized feature vector, or
              None until the first feature arrives.
          curr_feat: Most recent normalized feature vector, or None.
          features: Bounded deque of the normalized feature vectors seen so far.
          alpha: Weight given to the previous smoothed feature in the moving
              average.
      """

      # One filter instance shared by all tracks for batched prediction.
      shared_kalman = KalmanFilterXYWH()

      def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
          """Create a track, optionally seeded with an initial feature vector.

          Args:
              tlwh: Bounding box forwarded unchanged to `STrack.__init__`.
              score: Detection score forwarded to `STrack.__init__`.
              cls: Class label forwarded to `STrack.__init__`.
              feat: Optional appearance feature vector for this detection.
              feat_history: Maximum number of feature vectors kept in `features`.
          """
          super().__init__(tlwh, score, cls)

          self.smooth_feat = None
          self.curr_feat = None
          # The history buffer and smoothing factor must exist BEFORE the first
          # update_features() call, because that method appends to
          # `self.features`. Creating them afterwards either fails or silently
          # drops the initial feature from this track's history.
          self.features = deque([], maxlen=feat_history)
          self.alpha = 0.9
          if feat is not None:
              self.update_features(feat)

      def update_features(self, feat):
          """Normalize `feat`, record it, and fold it into the smoothed feature.

          The caller's array is left untouched: normalization produces new
          arrays, so `curr_feat`, `smooth_feat` and the stored history entries
          never share a buffer that is later modified in place.

          Args:
              feat: Appearance feature vector (array-like).
          """
          feat = np.asarray(feat)
          feat = feat / np.linalg.norm(feat)
          self.curr_feat = feat
          if self.smooth_feat is None:
              blended = feat
          else:
              # Exponential moving average weighted towards the history.
              blended = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
          self.features.append(feat)
          # Re-normalize so the smoothed feature stays unit length.
          self.smooth_feat = blended / np.linalg.norm(blended)

      def predict(self):
          """Advance this track's mean and covariance one step with its Kalman filter."""
          mean_state = self.mean.copy()
          if self.state != TrackState.Tracked:
              # Zero state entries 6 and 7 for tracks that are not currently
              # tracked (presumably the width/height velocities -- confirm
              # against KalmanFilterXYWH's state layout).
              mean_state[6] = 0
              mean_state[7] = 0

          self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

      def re_activate(self, new_track, frame_id, new_id=False):
          """Re-activate the track, absorbing `new_track`'s current feature if it has one.

          Args:
              new_track: Track carrying the new observation.
              frame_id: Frame index forwarded to `STrack.re_activate`.
              new_id: Forwarded to `STrack.re_activate`.
          """
          if new_track.curr_feat is not None:
              self.update_features(new_track.curr_feat)
          super().re_activate(new_track, frame_id, new_id)

      def update(self, new_track, frame_id):
          """Update the track from `new_track`, absorbing its current feature if present.

          Args:
              new_track: Track carrying the matched observation.
              frame_id: Frame index forwarded to `STrack.update`.
          """
          if new_track.curr_feat is not None:
              self.update_features(new_track.curr_feat)
          super().update(new_track, frame_id)

      @property
      def tlwh(self):
          """Return the current box as `(top left x, top left y, width, height)`.

          Falls back to a copy of the stored `_tlwh` while no Kalman mean exists.
          """
          if self.mean is None:
              return self._tlwh.copy()
          ret = self.mean[:4].copy()
          # State holds the box centre; shift by half the size to get the corner.
          ret[:2] -= ret[2:] / 2
          return ret

      @staticmethod
      def multi_predict(stracks):
          """Run one Kalman prediction step for all `stracks` using the shared filter.

          Each track's `mean` and `covariance` are replaced in place. Nothing is
          done for an empty sequence.

          Args:
              stracks: Sequence of tracks exposing `mean`, `covariance` and `state`.
          """
          if len(stracks) <= 0:
              return
          multi_mean = np.asarray([st.mean.copy() for st in stracks])
          multi_covariance = np.asarray([st.covariance for st in stracks])
          for i, st in enumerate(stracks):
              if st.state != TrackState.Tracked:
                  # Same zeroing of state entries 6 and 7 as in `predict`.
                  multi_mean[i][6] = 0
                  multi_mean[i][7] = 0
          multi_mean, multi_covariance = BOTrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
          for st, mean, cov in zip(stracks, multi_mean, multi_covariance):
              st.mean = mean
              st.covariance = cov

      def convert_coords(self, tlwh):
          """Convert a top-left/width/height box to centre/width/height form."""
          return self.tlwh_to_xywh(tlwh)

      @staticmethod
      def tlwh_to_xywh(tlwh):
          """Convert a box to `(center x, center y, width, height)`.

          Args:
              tlwh: Box given as `(top left x, top left y, width, height)`.

          Returns:
              A new array; the input is not modified.
          """
          ret = np.asarray(tlwh).copy()
          ret[:2] += ret[2:] / 2
          return ret
  class BOTSORT(BYTETracker):
      """BYTETracker variant that creates BOTrack objects and can gate matches by appearance.

      Appearance (ReID) matching is used only when `args.with_reid` is truthy
      and an encoder is set; this class itself only ever sets the encoder to
      None.
      """

      def __init__(self, args, frame_rate=30):
          """Set up matching thresholds, the (unset) ReID encoder and the GMC helper.

          Args:
              args: Configuration object providing `proximity_thresh`,
                  `appearance_thresh`, `with_reid` and `gmc_method`.
              frame_rate: Forwarded to `BYTETracker.__init__`.
          """
          super().__init__(args, frame_rate)

          # Thresholds used by get_dists() for IoU and embedding gating.
          self.proximity_thresh = args.proximity_thresh
          self.appearance_thresh = args.appearance_thresh

          if args.with_reid:
              # BoT-SORT with ReID is not supported yet, so no encoder is built.
              self.encoder = None

          self.gmc = GMC(method=args.gmc_method)

      def get_kalmanfilter(self):
          """Return a fresh KalmanFilterXYWH instance."""
          return KalmanFilterXYWH()

      def init_track(self, dets, scores, cls, img=None):
          """Build one BOTrack per detection.

          Args:
              dets: Detection boxes, one per track to create.
              scores: Scores aligned with `dets`.
              cls: Class labels aligned with `dets`.
              img: Image handed to the encoder when ReID is active.

          Returns:
              A list of BOTrack objects; empty when `dets` is empty.
          """
          if len(dets) == 0:
              return []

          reid_active = self.args.with_reid and self.encoder is not None
          if not reid_active:
              return [BOTrack(box, conf, label) for box, conf, label in zip(dets, scores, cls)]

          embeddings = self.encoder.inference(img, dets)
          return [
              BOTrack(box, conf, label, emb)
              for box, conf, label, emb in zip(dets, scores, cls, embeddings)
          ]

      def get_dists(self, tracks, detections):
          """Compute the track/detection cost matrix from IoU and, optionally, embeddings.

          Args:
              tracks: Existing tracks.
              detections: Candidate detections.

          Returns:
              The score-fused IoU distance matrix, or its element-wise minimum
              with the gated embedding distances when ReID is active.
          """
          iou_dists = matching.iou_distance(tracks, detections)
          # Pairs whose raw IoU distance exceeds the proximity threshold.
          too_far = iou_dists > self.proximity_thresh

          # TODO: mot20 -- fusion is meant to be skipped when `args.mot20` is set.
          fused = matching.fuse_score(iou_dists, detections)

          if not (self.args.with_reid and self.encoder is not None):
              return fused

          appearance = matching.embedding_distance(tracks, detections) / 2.0
          # Reject pairs that look too different or are too far apart.
          appearance[appearance > self.appearance_thresh] = 1.0
          appearance[too_far] = 1.0
          return np.minimum(fused, appearance)

      def multi_predict(self, tracks):
          """Run the batched Kalman prediction step for `tracks` via BOTrack."""
          BOTrack.multi_predict(tracks)
  119. def multi_predict(self, tracks):
  120. """Predict and track multiple objects with YOLOv8 model."""
  121. BOTrack.multi_predict(tracks)