tracker.py

from deep_sort.utils.parser import get_config
from deep_sort.deep_sort import DeepSort
import torch
import cv2

palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)

cfg = get_config()
cfg.merge_from_file("deep_sort/configs/deep_sort.yaml")
deepsort = DeepSort(cfg.DEEPSORT.REID_CKPT,
                    max_dist=cfg.DEEPSORT.MAX_DIST, min_confidence=cfg.DEEPSORT.MIN_CONFIDENCE,
                    nms_max_overlap=cfg.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg.DEEPSORT.MAX_IOU_DISTANCE,
                    max_age=cfg.DEEPSORT.MAX_AGE, n_init=cfg.DEEPSORT.N_INIT, nn_budget=cfg.DEEPSORT.NN_BUDGET,
                    use_cuda=True)
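
# Module-level tracker state: a single DeepSort instance is created once at import
# time and re-used for every frame passed to update_tracker() below.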


def plot_bboxes(image, bboxes, line_thickness=None):
    # Plot each bounding box with its class label and track id on the image
    tl = line_thickness or round(
        0.002 * (image.shape[0] + image.shape[1]) / 2) + 1  # line/font thickness
    for (x1, y1, x2, y2, cls_id, pos_id) in bboxes:
        if cls_id in ['person']:
            color = (0, 0, 255)
        else:
            color = (0, 255, 0)
        c1, c2 = (x1, y1), (x2, y2)
        cv2.rectangle(image, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(cls_id, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(image, c1, c2, color, -1, cv2.LINE_AA)  # filled label background
        cv2.putText(image, '{} ID-{}'.format(cls_id, pos_id), (c1[0], c1[1] - 2), 0, tl / 3,
                    [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
    return image
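

# update_tracker() runs one detection + tracking step on a single frame: it converts
# the detector's boxes to (center-x, center-y, w, h), feeds them to the shared
# DeepSort instance, draws the tracked boxes, crops any newly seen faces, and prunes
# face track ids once their miss counter drops below -5.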
def update_tracker(target_detector, image):
    new_faces = []
    _, bboxes = target_detector.detect(image)

    bbox_xywh = []
    confs = []
    clss = []
    for x1, y1, x2, y2, cls_id, conf in bboxes:
        obj = [
            int((x1 + x2) / 2), int((y1 + y2) / 2),
            x2 - x1, y2 - y1
        ]
        bbox_xywh.append(obj)
        confs.append(conf)
        clss.append(cls_id)

    xywhs = torch.Tensor(bbox_xywh)
    confss = torch.Tensor(confs)
    outputs = deepsort.update(xywhs, confss, clss, image)

    bboxes2draw = []
    face_bboxes = []
    current_ids = []
    for value in list(outputs):
        x1, y1, x2, y2, cls_, track_id = value
        bboxes2draw.append(
            (x1, y1, x2, y2, cls_, track_id)
        )
        current_ids.append(track_id)
        if cls_ == 'face':
            if track_id not in target_detector.faceTracker:
                target_detector.faceTracker[track_id] = 0
                face = image[y1:y2, x1:x2]
                new_faces.append((face, track_id))
            face_bboxes.append(
                (x1, y1, x2, y2)
            )

    ids2delete = []
    for history_id in target_detector.faceTracker:
        if history_id not in current_ids:
            target_detector.faceTracker[history_id] -= 1
            if target_detector.faceTracker[history_id] < -5:
                ids2delete.append(history_id)

    for ids in ids2delete:
        target_detector.faceTracker.pop(ids)
        print('-[INFO] Delete track id:', ids)

    image = plot_bboxes(image, bboxes2draw)

    return image, new_faces, face_bboxes
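

# --- Usage sketch (not part of the original file) ---
# A minimal, hypothetical example of how update_tracker() could be driven over a video.
# It assumes a detector object that exposes:
#   * detect(image) -> (annotated_image, bboxes), where each bbox is
#     (x1, y1, x2, y2, cls_id, conf), matching the unpacking above, and
#   * a faceTracker dict that update_tracker() uses to remember face track ids.
# The Detector import and the video path are placeholders, not guaranteed names.
if __name__ == '__main__':
    from detector import Detector  # hypothetical YOLOv5-style detector wrapper

    det = Detector()
    cap = cv2.VideoCapture('demo.mp4')  # placeholder input video
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        frame, new_faces, face_bboxes = update_tracker(det, frame)
        cv2.imshow('tracking', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cap.release()
    cv2.destroyAllWindows()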