coco_eval.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. import copy
  2. import io
  3. from contextlib import redirect_stdout
  4. import numpy as np
  5. import pycocotools.mask as mask_util
  6. import torch
  7. import utils
  8. from pycocotools.coco import COCO
  9. from pycocotools.cocoeval import COCOeval
  class CocoEvaluator:
      """Collect model predictions and score them with pycocotools' ``COCOeval``.

      One ``COCOeval`` object is kept per requested IoU type. ``update`` runs
      per-image evaluation on each batch of predictions and stores the result.
      ``synchronize_between_processes`` concatenates the stored batches and
      loads them back into each ``COCOeval``. After that, call ``accumulate``
      and ``summarize``.
      """

      def __init__(self, coco_gt, iou_types):
          """Create one ``COCOeval`` per IoU type.

          Args:
              coco_gt: ground-truth ``pycocotools.coco.COCO`` object. It is
                  deep-copied, so the caller's object is never mutated.
              iou_types: list or tuple of IoU type names. ``prepare`` accepts
                  "bbox", "segm" and "keypoints".

          Raises:
              TypeError: if ``iou_types`` is not a list or tuple.
          """
          if not isinstance(iou_types, (list, tuple)):
              raise TypeError(f"This constructor expects iou_types of type list or tuple, instead got {type(iou_types)}")
          coco_gt = copy.deepcopy(coco_gt)
          self.coco_gt = coco_gt
          self.iou_types = iou_types
          self.coco_eval = {iou_type: COCOeval(coco_gt, iouType=iou_type) for iou_type in iou_types}
          # Every image id passed to ``update``, in call order (may repeat across calls).
          self.img_ids = []
          # Per IoU type: one evalImgs array per ``update`` call.
          self.eval_imgs = {k: [] for k in iou_types}

      def update(self, predictions):
          """Evaluate one batch of predictions for every configured IoU type.

          Args:
              predictions: mapping from image id to a dict of per-image
                  prediction tensors. See the ``prepare_for_coco_*`` methods for
                  the keys each IoU type reads.
          """
          img_ids = list(np.unique(list(predictions.keys())))
          self.img_ids.extend(img_ids)

          for iou_type in self.iou_types:
              results = self.prepare(predictions, iou_type)
              # Suppress anything COCO/loadRes prints to stdout.
              with redirect_stdout(io.StringIO()):
                  coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
              coco_eval = self.coco_eval[iou_type]
              coco_eval.cocoDt = coco_dt
              coco_eval.params.imgIds = list(img_ids)
              # evaluate() also returns coco_eval.params.imgIds as read after
              # evaluation. It is not needed here. Rebinding ``img_ids`` with it
              # would leak that value into the next IoU type's params.
              _, eval_imgs = evaluate(coco_eval)
              self.eval_imgs[iou_type].append(eval_imgs)

      def synchronize_between_processes(self):
          """Merge the per-batch results and load them into each ``COCOeval``.

          The arrays collected by ``update`` are concatenated along axis 2 (the
          image axis of the arrays built by ``evaluate``). They are then passed
          to ``create_common_coco_eval``, which combines them via
          ``utils.all_gather`` and de-duplicates the image ids.
          """
          for iou_type in self.iou_types:
              self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
              create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])

      def accumulate(self):
          """Call ``accumulate()`` on every ``COCOeval``."""
          for coco_eval in self.coco_eval.values():
              coco_eval.accumulate()

      def summarize(self):
          """Print the IoU type, then call ``summarize()`` on its ``COCOeval``."""
          for iou_type, coco_eval in self.coco_eval.items():
              print(f"IoU metric: {iou_type}")
              coco_eval.summarize()

      def prepare(self, predictions, iou_type):
          """Convert ``predictions`` into a list of COCO result dicts for ``iou_type``.

          Raises:
              ValueError: if ``iou_type`` is not "bbox", "segm" or "keypoints".
          """
          if iou_type == "bbox":
              return self.prepare_for_coco_detection(predictions)
          if iou_type == "segm":
              return self.prepare_for_coco_segmentation(predictions)
          if iou_type == "keypoints":
              return self.prepare_for_coco_keypoint(predictions)
          raise ValueError(f"Unknown iou type {iou_type}")

      def prepare_for_coco_detection(self, predictions):
          """Build "bbox" result dicts.

          Reads the "boxes", "scores" and "labels" tensors of each prediction.
          Boxes are given as (xmin, ymin, xmax, ymax) and emitted as
          (x, y, w, h). Empty prediction dicts are skipped.
          """
          coco_results = []
          for original_id, prediction in predictions.items():
              if len(prediction) == 0:
                  continue
              boxes = convert_to_xywh(prediction["boxes"]).tolist()
              scores = prediction["scores"].tolist()
              labels = prediction["labels"].tolist()
              coco_results.extend(
                  {
                      "image_id": original_id,
                      "category_id": labels[k],
                      "bbox": box,
                      "score": scores[k],
                  }
                  for k, box in enumerate(boxes)
              )
          return coco_results

      def prepare_for_coco_segmentation(self, predictions):
          """Build "segm" result dicts with RLE-encoded masks.

          Reads the "masks", "scores" and "labels" tensors of each prediction.
          Masks are binarised at 0.5 and each instance is encoded from
          ``mask[0]``. This presumably means masks are (N, 1, H, W); TODO confirm
          against the model output. Empty prediction dicts are skipped.
          """
          coco_results = []
          for original_id, prediction in predictions.items():
              if len(prediction) == 0:
                  continue
              scores = prediction["scores"].tolist()
              labels = prediction["labels"].tolist()
              masks = prediction["masks"] > 0.5
              rles = [
                  mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
                  for mask in masks
              ]
              for rle in rles:
                  # encode() yields bytes counts; store them as str.
                  rle["counts"] = rle["counts"].decode("utf-8")
              coco_results.extend(
                  {
                      "image_id": original_id,
                      "category_id": labels[k],
                      "segmentation": rle,
                      "score": scores[k],
                  }
                  for k, rle in enumerate(rles)
              )
          return coco_results

      def prepare_for_coco_keypoint(self, predictions):
          """Build "keypoints" result dicts.

          Reads the "keypoints", "scores" and "labels" tensors of each
          prediction. Each instance's keypoints are flattened into one list.
          Empty prediction dicts are skipped.
          """
          coco_results = []
          for original_id, prediction in predictions.items():
              if len(prediction) == 0:
                  continue
              scores = prediction["scores"].tolist()
              labels = prediction["labels"].tolist()
              keypoints = prediction["keypoints"].flatten(start_dim=1).tolist()
              coco_results.extend(
                  {
                      "image_id": original_id,
                      "category_id": labels[k],
                      "keypoints": keypoint,
                      "score": scores[k],
                  }
                  for k, keypoint in enumerate(keypoints)
              )
          return coco_results
  def convert_to_xywh(boxes):
      """Convert boxes from (xmin, ymin, xmax, ymax) to (x, y, width, height).

      Args:
          boxes: tensor whose last dimension has size 4. Any leading dimensions
              (e.g. ``(N, 4)``) are preserved.

      Returns:
          Tensor of the same shape with components
          (xmin, ymin, xmax - xmin, ymax - ymin) along the last dimension.
      """
      # Operate on the last axis so (N, 4), (4,) and batched (B, N, 4) all work;
      # for (N, 4) this is identical to unbinding/stacking on dim 1.
      xmin, ymin, xmax, ymax = boxes.unbind(-1)
      return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=-1)
  def merge(img_ids, eval_imgs):
      """Combine image ids and evaluation arrays via ``utils.all_gather`` and de-duplicate.

      Args:
          img_ids: this process's image ids.
          eval_imgs: this process's evaluation array. The gathered arrays are
              concatenated along axis 2 (the image axis).

      Returns:
          ``(merged_img_ids, merged_eval_imgs)``: the unique image ids in sorted
          order, and the concatenated evaluation array restricted to the first
          occurrence of each id.
      """
      # presumably one entry per distributed process; verify against utils.all_gather
      all_img_ids = utils.all_gather(img_ids)
      all_eval_imgs = utils.all_gather(eval_imgs)

      merged_img_ids = np.array([img_id for process_ids in all_img_ids for img_id in process_ids])
      merged_eval_imgs = np.concatenate(list(all_eval_imgs), 2)

      # keep only unique (and in sorted order) images; return_index gives the
      # position of each id's first occurrence in the concatenated order
      merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
      merged_eval_imgs = merged_eval_imgs[..., idx]
      return merged_img_ids, merged_eval_imgs
  def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
      """Install merged per-image evaluation results into ``coco_eval``.

      The ids and arrays are merged and de-duplicated via ``merge``. This
      function then sets ``coco_eval.evalImgs`` to the flattened array and
      ``coco_eval.params.imgIds`` to the unique ids. Finally it stores a deep
      copy of the params as ``coco_eval._paramsEval``.
      """
      unique_ids, merged = merge(img_ids, eval_imgs)
      coco_eval.evalImgs = list(merged.flatten())
      coco_eval.params.imgIds = list(unique_ids)
      # Snapshot the params that the flattened evalImgs correspond to.
      coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
  def evaluate(imgs):
      """Run ``imgs.evaluate()`` with stdout suppressed and return its results.

      Args:
          imgs: the ``COCOeval`` object to evaluate.

      Returns:
          ``(img_ids, eval_imgs)``: ``imgs.params.imgIds`` as read after
          evaluation, and ``imgs.evalImgs`` reshaped to
          ``(-1, len(params.areaRng), len(params.imgIds))``.
      """
      with redirect_stdout(io.StringIO()):
          imgs.evaluate()
      params = imgs.params
      n_area = len(params.areaRng)
      n_img = len(params.imgIds)
      per_image = np.asarray(imgs.evalImgs).reshape(-1, n_area, n_img)
      return params.imgIds, per_image