coco_utils.py

import copy
import os

import torch
import torch.utils.data
import torchvision
from PIL import Image
from pycocotools import mask as coco_mask
from transforms import Compose


class FilterAndRemapCocoCategories:
    def __init__(self, categories, remap=True):
        self.categories = categories
        self.remap = remap

    def __call__(self, image, anno):
        anno = [obj for obj in anno if obj["category_id"] in self.categories]
        if not self.remap:
            return image, anno
        anno = copy.deepcopy(anno)
        for obj in anno:
            obj["category_id"] = self.categories.index(obj["category_id"])
        return image, anno
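
# Illustrative sketch (not part of the reference script): with made-up
# annotations, the transform above keeps only the listed categories and, when
# remap=True, rewrites each id to its index in `categories`. The image (any
# PIL image) is passed through untouched:
#
#   t = FilterAndRemapCocoCategories(categories=[0, 5, 2], remap=True)
#   _, anno = t(pil_image, [{"category_id": 5}, {"category_id": 7}])
#   # the category-7 object is dropped; category 5 becomes label 1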


def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        # convert the polygon(s) of one instance to COCO run-length encoding,
        # then decode them into an H x W x P binary array (one channel per polygon)
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        # collapse the per-polygon channels into a single instance mask
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks
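
# Quick sketch of the helper above, with hypothetical coordinates: a single
# square polygon decodes to one instance mask of the requested size.
#
#   square = [[10.0, 10.0, 40.0, 10.0, 40.0, 40.0, 10.0, 40.0]]
#   m = convert_coco_poly_to_mask([square], height=50, width=50)
#   # m.shape == (1, 50, 50); m[0, 20, 20] is inside, m[0, 5, 5] is outside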


class ConvertCocoPolysToMask:
    def __call__(self, image, anno):
        w, h = image.size
        segmentations = [obj["segmentation"] for obj in anno]
        cats = [obj["category_id"] for obj in anno]
        if segmentations:
            masks = convert_coco_poly_to_mask(segmentations, h, w)
            cats = torch.as_tensor(cats, dtype=masks.dtype)
            # merge all instance masks into a single segmentation map
            # with its corresponding categories
            target, _ = (masks * cats[:, None, None]).max(dim=0)
            # discard overlapping instances
            target[masks.sum(0) > 1] = 255
        else:
            target = torch.zeros((h, w), dtype=torch.uint8)
        target = Image.fromarray(target.numpy())
        return image, target
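
# Worked example of the merge in ConvertCocoPolysToMask, with two made-up
# 2 x 2 instance masks:
#
#   masks = torch.tensor([[[1, 0], [0, 0]],    # instance of category 3
#                         [[1, 1], [0, 0]]])   # instance of category 7
#   cats = torch.tensor([3, 7])
#   (masks * cats[:, None, None]).max(dim=0).values  # -> [[7, 7], [0, 0]]
#   # pixel (0, 0) is covered by both instances (masks.sum(0) > 1 there),
#   # so it is then overwritten with the ignore value 255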


def _coco_remove_images_without_annotations(dataset, cat_list=None):
    def _has_valid_annotation(anno):
        # if it's empty, there is no annotation
        if len(anno) == 0:
            return False
        # keep the image only if its annotations cover more than 1000 pixels in total
        return sum(obj["area"] for obj in anno) > 1000

    ids = []
    for ds_idx, img_id in enumerate(dataset.ids):
        ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
        anno = dataset.coco.loadAnns(ann_ids)
        if cat_list:
            anno = [obj for obj in anno if obj["category_id"] in cat_list]
        if _has_valid_annotation(anno):
            ids.append(ds_idx)

    dataset = torch.utils.data.Subset(dataset, ids)
    return dataset
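
# Hypothetical usage of the filter above: wrap a raw CocoDetection dataset
# (paths as resolved in get_coco below) and keep only the images whose
# annotations, optionally restricted to `cat_list`, pass the pixel threshold.
#
#   ds = torchvision.datasets.CocoDetection(img_folder, ann_file)
#   ds = _coco_remove_images_without_annotations(ds, cat_list=[1, 2, 3])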


def get_coco(root, image_set, transforms, use_v2=False):
    PATHS = {
        "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
        "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
        # "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
    }
    # background + the 20 PASCAL VOC classes, expressed as COCO category ids
    # (the position in this list becomes the VOC-style training label)
    CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

    img_folder, ann_file = PATHS[image_set]
    img_folder = os.path.join(root, img_folder)
    ann_file = os.path.join(root, ann_file)

    # The two "Compose" pipelines below achieve the same thing: converting COCO
    # detection samples into segmentation-compatible samples. They just do it
    # with slightly different implementations. We could refactor and unify them,
    # but keeping them separate keeps the v2 version clean.
    if use_v2:
        import v2_extras
        from torchvision.datasets import wrap_dataset_for_transforms_v2

        transforms = Compose([v2_extras.CocoDetectionToVOCSegmentation(), transforms])
        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"masks", "labels"})
    else:
        transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])
        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
        if image_set == "train":
            dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)

    return dataset
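

if __name__ == "__main__":
    # Smoke-test sketch, not part of the original script. It assumes a standard
    # COCO 2017 layout under COCO_ROOT (val2017/ plus annotations/); the
    # identity transform below is only a stand-in for the real training presets.
    COCO_ROOT = os.environ.get("COCO_ROOT", "/data/coco")

    def identity(image, target):
        return image, target

    dataset = get_coco(COCO_ROOT, "val", Compose([identity]))
    image, target = dataset[0]
    print(len(dataset), image.size, target.size)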