coco_utils.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. import os
  2. import torch
  3. import torch.utils.data
  4. import torchvision
  5. import transforms as T
  6. from pycocotools import mask as coco_mask
  7. from pycocotools.coco import COCO
  def convert_coco_poly_to_mask(segmentations, height, width):
      """Rasterize per-object COCO segmentations into a stack of binary masks.

      Args:
          segmentations: one entry per object; each non-empty entry is passed to
              ``pycocotools.mask.frPyObjects`` (e.g. a list of polygons). An
              empty list yields an all-zero mask instead of raising, so the
              output stays index-aligned with the input objects.
          height: image height in pixels.
          width: image width in pixels.

      Returns:
          A ``torch.uint8`` tensor of shape ``(len(segmentations), height, width)``;
          shape ``(0, height, width)`` when there are no objects.
      """
      masks = []
      for polygons in segmentations:
          if isinstance(polygons, list) and not polygons:
              # frPyObjects inspects polygons[0] and fails on an empty list
              # (bbox-only annotations); emit an empty mask for this object.
              masks.append(torch.zeros((height, width), dtype=torch.uint8))
              continue
          rles = coco_mask.frPyObjects(polygons, height, width)
          mask = coco_mask.decode(rles)
          if len(mask.shape) < 3:
              mask = mask[..., None]
          # One channel per polygon of the object: merge them into a single mask.
          # any() on a uint8 tensor returns uint8, matching the empty-case dtype.
          mask = torch.as_tensor(mask, dtype=torch.uint8).any(dim=2)
          masks.append(mask)
      if masks:
          return torch.stack(masks, dim=0)
      return torch.zeros((0, height, width), dtype=torch.uint8)
  class ConvertCocoPolysToMask:
      """Turn a raw COCO annotation list into tensor targets for detection models.

      Called as ``image, target = ConvertCocoPolysToMask()(image, target)`` where
      ``target`` is ``{"image_id": ..., "annotations": [coco_ann, ...]}``. Crowd
      annotations are dropped, boxes are converted from ``[x, y, w, h]`` to
      ``[x1, y1, x2, y2]`` clamped to the image, and objects whose box collapses
      to zero width or height are removed from every per-object field.

      The returned target holds ``boxes`` (float32, N x 4), ``labels`` (int64),
      ``masks`` (N x H x W), ``image_id``, ``keypoints`` (N x K x 3, only when
      the annotations carry keypoints), ``area`` (float32) and ``iscrowd``
      (int64); all per-object fields share the same first dimension N.
      """

      def __call__(self, image, target):
          # assumes a PIL-style image whose .size is (width, height)
          w, h = image.size
          image_id = target["image_id"]
          anno = [obj for obj in target["annotations"] if obj["iscrowd"] == 0]

          # guard against no boxes via resizing
          boxes = torch.as_tensor([obj["bbox"] for obj in anno], dtype=torch.float32).reshape(-1, 4)
          boxes[:, 2:] += boxes[:, :2]  # [x, y, w, h] -> [x1, y1, x2, y2]
          boxes[:, 0::2].clamp_(min=0, max=w)
          boxes[:, 1::2].clamp_(min=0, max=h)

          classes = torch.tensor([obj["category_id"] for obj in anno], dtype=torch.int64)
          masks = convert_coco_poly_to_mask([obj["segmentation"] for obj in anno], h, w)

          keypoints = None
          if anno and "keypoints" in anno[0]:
              keypoints = torch.as_tensor([obj["keypoints"] for obj in anno], dtype=torch.float32)
              num_objects = keypoints.shape[0]
              if num_objects:
                  # flat (x, y, v) triplets -> (num_objects, num_keypoints, 3)
                  keypoints = keypoints.view(num_objects, -1, 3)

          # for conversion to coco api
          area = torch.tensor([obj["area"] for obj in anno], dtype=torch.float32)
          iscrowd = torch.tensor([obj["iscrowd"] for obj in anno], dtype=torch.int64)

          # Drop degenerate boxes. Every per-object field must be filtered with the
          # same mask; otherwise area/iscrowd fall out of step with boxes/labels.
          keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])

          target = {
              "boxes": boxes[keep],
              "labels": classes[keep],
              "masks": masks[keep],
              "image_id": image_id,
          }
          if keypoints is not None:
              target["keypoints"] = keypoints[keep]
          target["area"] = area[keep]
          target["iscrowd"] = iscrowd[keep]
          return image, target
  def _coco_remove_images_without_annotations(dataset, cat_list=None):
      """Return a ``Subset`` of *dataset* restricted to images with usable annotations.

      After optional category filtering, an image is kept when it has at least
      one annotation, not every box has a width or height <= 1, and, if the
      annotations carry keypoints, at least 10 keypoints are visible in total.

      Args:
          dataset: object exposing ``ids`` (image ids in index order) and
              ``coco`` (with ``getAnnIds`` / ``loadAnns``).
          cat_list: optional collection of category ids; annotations from other
              categories are ignored when deciding whether to keep an image.

      Returns:
          ``torch.utils.data.Subset`` over the kept dataset indices.
      """
      min_keypoints_per_image = 10

      def _is_degenerate(obj):
          # bbox is [x, y, w, h]; a side of at most one pixel counts as empty
          return any(side <= 1 for side in obj["bbox"][2:])

      def _visible_keypoints(anno):
          # keypoints are flat (x, y, v) triplets; v > 0 counts as visible
          return sum(1 for obj in anno for v in obj["keypoints"][2::3] if v > 0)

      def _has_valid_annotation(anno):
          if not anno:
              return False
          if all(_is_degenerate(obj) for obj in anno):
              return False
          if "keypoints" in anno[0]:
              # keypoint tasks additionally require enough labelled keypoints
              return _visible_keypoints(anno) >= min_keypoints_per_image
          return True

      kept_indices = []
      for position, image_id in enumerate(dataset.ids):
          ann_ids = dataset.coco.getAnnIds(imgIds=image_id, iscrowd=None)
          anno = dataset.coco.loadAnns(ann_ids)
          if cat_list:
              anno = [obj for obj in anno if obj["category_id"] in cat_list]
          if _has_valid_annotation(anno):
              kept_indices.append(position)
      return torch.utils.data.Subset(dataset, kept_indices)
  def convert_to_coco_api(ds):
      """Build an in-memory pycocotools ``COCO`` index from a detection dataset.

      Every sample is loaded via ``ds[idx] -> (img, targets)`` and its targets are
      turned back into COCO-style image and annotation records. Boxes go from
      ``[x1, y1, x2, y2]`` back to ``[x, y, w, h]``, masks are RLE-encoded with
      ``pycocotools.mask.encode``, and keypoints are flattened per object. The
      categories are the sorted set of labels encountered.

      ``targets`` must contain ``image_id``, ``boxes``, ``labels``, ``area`` and
      ``iscrowd``; ``masks`` and ``keypoints`` are optional. ``img`` must expose
      ``.shape`` with height and width as its last two dimensions.
      """
      images = []
      annotations = []
      seen_labels = set()
      # annotation IDs need to start at 1, not 0, see torchvision issue #1530
      next_ann_id = 1
      for idx in range(len(ds)):
          # TODO: find a cheaper way to get the targets than loading every image
          img, targets = ds[idx]
          image_id = targets["image_id"]
          images.append({"id": image_id, "height": img.shape[-2], "width": img.shape[-1]})

          boxes_xywh = targets["boxes"].clone()
          boxes_xywh[:, 2:] -= boxes_xywh[:, :2]
          boxes_xywh = boxes_xywh.tolist()
          labels = targets["labels"].tolist()
          areas = targets["area"].tolist()
          crowd_flags = targets["iscrowd"].tolist()

          has_masks = "masks" in targets
          has_keypoints = "keypoints" in targets
          if has_masks:
              # make masks Fortran contiguous for coco_mask
              masks = targets["masks"].permute(0, 2, 1).contiguous().permute(0, 2, 1)
          if has_keypoints:
              kps = targets["keypoints"]
              kps = kps.reshape(kps.shape[0], -1).tolist()

          for i in range(len(boxes_xywh)):
              label = labels[i]
              seen_labels.add(label)
              ann = {
                  "image_id": image_id,
                  "bbox": boxes_xywh[i],
                  "category_id": label,
                  "area": areas[i],
                  "iscrowd": crowd_flags[i],
                  "id": next_ann_id,
              }
              if has_masks:
                  ann["segmentation"] = coco_mask.encode(masks[i].numpy())
              if has_keypoints:
                  ann["keypoints"] = kps[i]
                  ann["num_keypoints"] = sum(k != 0 for k in kps[i][2::3])
              annotations.append(ann)
              next_ann_id += 1

      coco_ds = COCO()
      coco_ds.dataset = {
          "images": images,
          "categories": [{"id": label} for label in sorted(seen_labels)],
          "annotations": annotations,
      }
      coco_ds.createIndex()
      return coco_ds
  def get_coco_api_from_dataset(dataset):
      """Return a pycocotools ``COCO`` object describing *dataset*'s ground truth.

      If *dataset* is a ``torchvision.datasets.CocoDetection``, or one wrapped in
      any number of ``torch.utils.data.Subset`` layers, that dataset's existing
      ``.coco`` index is reused (it covers every image of the base dataset).
      Otherwise the targets of *dataset* itself are converted with
      ``convert_to_coco_api``. Only the samples *dataset* exposes are loaded,
      not every sample of whatever it wraps.
      """
      base = dataset
      # Peel off Subset wrappers (no fixed depth limit) to look for CocoDetection.
      while isinstance(base, torch.utils.data.Subset):
          base = base.dataset
      if isinstance(base, torchvision.datasets.CocoDetection):
          return base.coco
      # No ready-made COCO index: build one from the samples actually requested.
      return convert_to_coco_api(dataset)
  class CocoDetection(torchvision.datasets.CocoDetection):
      """``torchvision.datasets.CocoDetection`` that pairs annotations with their image id.

      Each sample's target is wrapped as ``{"image_id": ..., "annotations": [...]}``
      and, if a joint ``transforms(img, target)`` callable was given, passed
      through it.
      """

      def __init__(self, img_folder, ann_file, transforms):
          super().__init__(img_folder, ann_file)
          # kept here and applied in __getitem__ rather than given to the base class
          self._transforms = transforms

      def __getitem__(self, idx):
          img, annotations = super().__getitem__(idx)
          target = {"image_id": self.ids[idx], "annotations": annotations}
          transforms = self._transforms
          if transforms is None:
              return img, target
          img, target = transforms(img, target)
          return img, target
  def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_masks=False):
      """Build the COCO 2017 ``"train"`` or ``"val"`` split located under *root*.

      Images are read from ``<root>/<image_set>2017`` and annotations from
      ``<root>/annotations/<mode>_<image_set>2017.json``. An unknown *image_set*
      raises ``KeyError``.

      With ``use_v2`` the stock ``torchvision`` dataset is wrapped for transforms
      v2, with targets ``boxes``/``labels``/``image_id`` (plus ``masks`` when
      *with_masks*). Otherwise ``ConvertCocoPolysToMask`` runs before
      *transforms*. On that path ``with_masks`` is ignored and masks are always
      produced. For ``"train"``, images without usable annotations are dropped.
      """
      anno_file_template = "{}_{}2017.json"
      paths = {
          split: (f"{split}2017", os.path.join("annotations", anno_file_template.format(mode, split)))
          for split in ("train", "val")
      }
      img_subdir, ann_relpath = paths[image_set]
      img_folder = os.path.join(root, img_subdir)
      ann_file = os.path.join(root, ann_relpath)

      if use_v2:
          from torchvision.datasets import wrap_dataset_for_transforms_v2

          base = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
          target_keys = ["boxes", "labels", "image_id"] + (["masks"] if with_masks else [])
          dataset = wrap_dataset_for_transforms_v2(base, target_keys=target_keys)
      else:
          # TODO: handle with_masks for V1?
          pipeline = [ConvertCocoPolysToMask()]
          if transforms is not None:
              pipeline.append(transforms)
          dataset = CocoDetection(img_folder, ann_file, transforms=T.Compose(pipeline))

      if image_set == "train":
          dataset = _coco_remove_images_without_annotations(dataset)
      return dataset