gmc.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import copy
  3. import cv2
  4. import numpy as np
  5. from ultralytics.utils import LOGGER
  6. class GMC:
  7. def __init__(self, method='sparseOptFlow', downscale=2):
  8. """Initialize a video tracker with specified parameters."""
  9. super().__init__()
  10. self.method = method
  11. self.downscale = max(1, int(downscale))
  12. if self.method == 'orb':
  13. self.detector = cv2.FastFeatureDetector_create(20)
  14. self.extractor = cv2.ORB_create()
  15. self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
  16. elif self.method == 'sift':
  17. self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
  18. self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
  19. self.matcher = cv2.BFMatcher(cv2.NORM_L2)
  20. elif self.method == 'ecc':
  21. number_of_iterations = 5000
  22. termination_eps = 1e-6
  23. self.warp_mode = cv2.MOTION_EUCLIDEAN
  24. self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)
  25. elif self.method == 'sparseOptFlow':
  26. self.feature_params = dict(maxCorners=1000,
  27. qualityLevel=0.01,
  28. minDistance=1,
  29. blockSize=3,
  30. useHarrisDetector=False,
  31. k=0.04)
  32. elif self.method in ['none', 'None', None]:
  33. self.method = None
  34. else:
  35. raise ValueError(f'Error: Unknown GMC method:{method}')
  36. self.prevFrame = None
  37. self.prevKeyPoints = None
  38. self.prevDescriptors = None
  39. self.initializedFirstFrame = False
  40. def apply(self, raw_frame, detections=None):
  41. """Apply object detection on a raw frame using specified method."""
  42. if self.method in ['orb', 'sift']:
  43. return self.applyFeatures(raw_frame, detections)
  44. elif self.method == 'ecc':
  45. return self.applyEcc(raw_frame, detections)
  46. elif self.method == 'sparseOptFlow':
  47. return self.applySparseOptFlow(raw_frame, detections)
  48. else:
  49. return np.eye(2, 3)
  50. def applyEcc(self, raw_frame, detections=None):
  51. """Initialize."""
  52. height, width, _ = raw_frame.shape
  53. frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
  54. H = np.eye(2, 3, dtype=np.float32)
  55. # Downscale image (TODO: consider using pyramids)
  56. if self.downscale > 1.0:
  57. frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
  58. frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
  59. width = width // self.downscale
  60. height = height // self.downscale
  61. # Handle first frame
  62. if not self.initializedFirstFrame:
  63. # Initialize data
  64. self.prevFrame = frame.copy()
  65. # Initialization done
  66. self.initializedFirstFrame = True
  67. return H
  68. # Run the ECC algorithm. The results are stored in warp_matrix.
  69. # (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
  70. try:
  71. (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
  72. except Exception as e:
  73. LOGGER.warning(f'WARNING: find transform failed. Set warp as identity {e}')
  74. return H
  75. def applyFeatures(self, raw_frame, detections=None):
  76. """Initialize."""
  77. height, width, _ = raw_frame.shape
  78. frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
  79. H = np.eye(2, 3)
  80. # Downscale image (TODO: consider using pyramids)
  81. if self.downscale > 1.0:
  82. # frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
  83. frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
  84. width = width // self.downscale
  85. height = height // self.downscale
  86. # Find the keypoints
  87. mask = np.zeros_like(frame)
  88. # mask[int(0.05 * height): int(0.95 * height), int(0.05 * width): int(0.95 * width)] = 255
  89. mask[int(0.02 * height):int(0.98 * height), int(0.02 * width):int(0.98 * width)] = 255
  90. if detections is not None:
  91. for det in detections:
  92. tlbr = (det[:4] / self.downscale).astype(np.int_)
  93. mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0
  94. keypoints = self.detector.detect(frame, mask)
  95. # Compute the descriptors
  96. keypoints, descriptors = self.extractor.compute(frame, keypoints)
  97. # Handle first frame
  98. if not self.initializedFirstFrame:
  99. # Initialize data
  100. self.prevFrame = frame.copy()
  101. self.prevKeyPoints = copy.copy(keypoints)
  102. self.prevDescriptors = copy.copy(descriptors)
  103. # Initialization done
  104. self.initializedFirstFrame = True
  105. return H
  106. # Match descriptors.
  107. knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)
  108. # Filtered matches based on smallest spatial distance
  109. matches = []
  110. spatialDistances = []
  111. maxSpatialDistance = 0.25 * np.array([width, height])
  112. # Handle empty matches case
  113. if len(knnMatches) == 0:
  114. # Store to next iteration
  115. self.prevFrame = frame.copy()
  116. self.prevKeyPoints = copy.copy(keypoints)
  117. self.prevDescriptors = copy.copy(descriptors)
  118. return H
  119. for m, n in knnMatches:
  120. if m.distance < 0.9 * n.distance:
  121. prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
  122. currKeyPointLocation = keypoints[m.trainIdx].pt
  123. spatialDistance = (prevKeyPointLocation[0] - currKeyPointLocation[0],
  124. prevKeyPointLocation[1] - currKeyPointLocation[1])
  125. if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \
  126. (np.abs(spatialDistance[1]) < maxSpatialDistance[1]):
  127. spatialDistances.append(spatialDistance)
  128. matches.append(m)
  129. meanSpatialDistances = np.mean(spatialDistances, 0)
  130. stdSpatialDistances = np.std(spatialDistances, 0)
  131. inliers = (spatialDistances - meanSpatialDistances) < 2.5 * stdSpatialDistances
  132. goodMatches = []
  133. prevPoints = []
  134. currPoints = []
  135. for i in range(len(matches)):
  136. if inliers[i, 0] and inliers[i, 1]:
  137. goodMatches.append(matches[i])
  138. prevPoints.append(self.prevKeyPoints[matches[i].queryIdx].pt)
  139. currPoints.append(keypoints[matches[i].trainIdx].pt)
  140. prevPoints = np.array(prevPoints)
  141. currPoints = np.array(currPoints)
  142. # Draw the keypoint matches on the output image
  143. # if False:
  144. # import matplotlib.pyplot as plt
  145. # matches_img = np.hstack((self.prevFrame, frame))
  146. # matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR)
  147. # W = np.size(self.prevFrame, 1)
  148. # for m in goodMatches:
  149. # prev_pt = np.array(self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_)
  150. # curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_)
  151. # curr_pt[0] += W
  152. # color = np.random.randint(0, 255, 3)
  153. # color = (int(color[0]), int(color[1]), int(color[2]))
  154. #
  155. # matches_img = cv2.line(matches_img, prev_pt, curr_pt, tuple(color), 1, cv2.LINE_AA)
  156. # matches_img = cv2.circle(matches_img, prev_pt, 2, tuple(color), -1)
  157. # matches_img = cv2.circle(matches_img, curr_pt, 2, tuple(color), -1)
  158. #
  159. # plt.figure()
  160. # plt.imshow(matches_img)
  161. # plt.show()
  162. # Find rigid matrix
  163. if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)):
  164. H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
  165. # Handle downscale
  166. if self.downscale > 1.0:
  167. H[0, 2] *= self.downscale
  168. H[1, 2] *= self.downscale
  169. else:
  170. LOGGER.warning('WARNING: not enough matching points')
  171. # Store to next iteration
  172. self.prevFrame = frame.copy()
  173. self.prevKeyPoints = copy.copy(keypoints)
  174. self.prevDescriptors = copy.copy(descriptors)
  175. return H
  176. def applySparseOptFlow(self, raw_frame, detections=None):
  177. """Initialize."""
  178. height, width, _ = raw_frame.shape
  179. frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
  180. H = np.eye(2, 3)
  181. # Downscale image
  182. if self.downscale > 1.0:
  183. # frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
  184. frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
  185. # Find the keypoints
  186. keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params)
  187. # Handle first frame
  188. if not self.initializedFirstFrame:
  189. # Initialize data
  190. self.prevFrame = frame.copy()
  191. self.prevKeyPoints = copy.copy(keypoints)
  192. # Initialization done
  193. self.initializedFirstFrame = True
  194. return H
  195. # Find correspondences
  196. matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
  197. # Leave good correspondences only
  198. prevPoints = []
  199. currPoints = []
  200. for i in range(len(status)):
  201. if status[i]:
  202. prevPoints.append(self.prevKeyPoints[i])
  203. currPoints.append(matchedKeypoints[i])
  204. prevPoints = np.array(prevPoints)
  205. currPoints = np.array(currPoints)
  206. # Find rigid matrix
  207. if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)):
  208. H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
  209. # Handle downscale
  210. if self.downscale > 1.0:
  211. H[0, 2] *= self.downscale
  212. H[1, 2] *= self.downscale
  213. else:
  214. LOGGER.warning('WARNING: not enough matching points')
  215. # Store to next iteration
  216. self.prevFrame = frame.copy()
  217. self.prevKeyPoints = copy.copy(keypoints)
  218. return H