io.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. import os
  2. from typing import Dict
  3. import numpy as np
  4. # from utils.log import get_logger
  5. def write_results(filename, results, data_type):
  6. if data_type == 'mot':
  7. save_format = '{frame},{id},{x1},{y1},{w},{h},-1,-1,-1,-1\n'
  8. elif data_type == 'kitti':
  9. save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
  10. else:
  11. raise ValueError(data_type)
  12. with open(filename, 'w') as f:
  13. for frame_id, tlwhs, track_ids in results:
  14. if data_type == 'kitti':
  15. frame_id -= 1
  16. for tlwh, track_id in zip(tlwhs, track_ids):
  17. if track_id < 0:
  18. continue
  19. x1, y1, w, h = tlwh
  20. x2, y2 = x1 + w, y1 + h
  21. line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h)
  22. f.write(line)
  23. # def write_results(filename, results_dict: Dict, data_type: str):
  24. # if not filename:
  25. # return
  26. # path = os.path.dirname(filename)
  27. # if not os.path.exists(path):
  28. # os.makedirs(path)
  29. # if data_type in ('mot', 'mcmot', 'lab'):
  30. # save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
  31. # elif data_type == 'kitti':
  32. # save_format = '{frame} {id} pedestrian -1 -1 -10 {x1} {y1} {x2} {y2} -1 -1 -1 -1000 -1000 -1000 -10 {score}\n'
  33. # else:
  34. # raise ValueError(data_type)
  35. # with open(filename, 'w') as f:
  36. # for frame_id, frame_data in results_dict.items():
  37. # if data_type == 'kitti':
  38. # frame_id -= 1
  39. # for tlwh, track_id in frame_data:
  40. # if track_id < 0:
  41. # continue
  42. # x1, y1, w, h = tlwh
  43. # x2, y2 = x1 + w, y1 + h
  44. # line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h, score=1.0)
  45. # f.write(line)
  46. # logger.info('Save results to {}'.format(filename))
  47. def read_results(filename, data_type: str, is_gt=False, is_ignore=False):
  48. if data_type in ('mot', 'lab'):
  49. read_fun = read_mot_results
  50. else:
  51. raise ValueError('Unknown data type: {}'.format(data_type))
  52. return read_fun(filename, is_gt, is_ignore)
  53. """
  54. labels={'ped', ... % 1
  55. 'person_on_vhcl', ... % 2
  56. 'car', ... % 3
  57. 'bicycle', ... % 4
  58. 'mbike', ... % 5
  59. 'non_mot_vhcl', ... % 6
  60. 'static_person', ... % 7
  61. 'distractor', ... % 8
  62. 'occluder', ... % 9
  63. 'occluder_on_grnd', ... %10
  64. 'occluder_full', ... % 11
  65. 'reflection', ... % 12
  66. 'crowd' ... % 13
  67. };
  68. """
  69. def read_mot_results(filename, is_gt, is_ignore):
  70. valid_labels = {1}
  71. ignore_labels = {2, 7, 8, 12}
  72. results_dict = dict()
  73. if os.path.isfile(filename):
  74. with open(filename, 'r') as f:
  75. for line in f.readlines():
  76. linelist = line.split(',')
  77. if len(linelist) < 7:
  78. continue
  79. fid = int(linelist[0])
  80. if fid < 1:
  81. continue
  82. results_dict.setdefault(fid, list())
  83. if is_gt:
  84. if 'MOT16-' in filename or 'MOT17-' in filename:
  85. label = int(float(linelist[7]))
  86. mark = int(float(linelist[6]))
  87. if mark == 0 or label not in valid_labels:
  88. continue
  89. score = 1
  90. elif is_ignore:
  91. if 'MOT16-' in filename or 'MOT17-' in filename:
  92. label = int(float(linelist[7]))
  93. vis_ratio = float(linelist[8])
  94. if label not in ignore_labels and vis_ratio >= 0:
  95. continue
  96. else:
  97. continue
  98. score = 1
  99. else:
  100. score = float(linelist[6])
  101. tlwh = tuple(map(float, linelist[2:6]))
  102. target_id = int(linelist[1])
  103. results_dict[fid].append((tlwh, target_id, score))
  104. return results_dict
  105. def unzip_objs(objs):
  106. if len(objs) > 0:
  107. tlwhs, ids, scores = zip(*objs)
  108. else:
  109. tlwhs, ids, scores = [], [], []
  110. tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4)
  111. return tlwhs, ids, scores