voc.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. import collections
  2. import os
  3. from xml.etree.ElementTree import Element as ET_Element
  4. from .vision import VisionDataset
  5. try:
  6. from defusedxml.ElementTree import parse as ET_parse
  7. except ImportError:
  8. from xml.etree.ElementTree import parse as ET_parse
  9. from typing import Any, Callable, Dict, List, Optional, Tuple
  10. from PIL import Image
  11. from .utils import download_and_extract_archive, verify_str_arg
  12. DATASET_YEAR_DICT = {
  13. "2012": {
  14. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
  15. "filename": "VOCtrainval_11-May-2012.tar",
  16. "md5": "6cd6e144f989b92b3379bac3b3de84fd",
  17. "base_dir": os.path.join("VOCdevkit", "VOC2012"),
  18. },
  19. "2011": {
  20. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
  21. "filename": "VOCtrainval_25-May-2011.tar",
  22. "md5": "6c3384ef61512963050cb5d687e5bf1e",
  23. "base_dir": os.path.join("TrainVal", "VOCdevkit", "VOC2011"),
  24. },
  25. "2010": {
  26. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
  27. "filename": "VOCtrainval_03-May-2010.tar",
  28. "md5": "da459979d0c395079b5c75ee67908abb",
  29. "base_dir": os.path.join("VOCdevkit", "VOC2010"),
  30. },
  31. "2009": {
  32. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
  33. "filename": "VOCtrainval_11-May-2009.tar",
  34. "md5": "a3e00b113cfcfebf17e343f59da3caa1",
  35. "base_dir": os.path.join("VOCdevkit", "VOC2009"),
  36. },
  37. "2008": {
  38. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
  39. "filename": "VOCtrainval_11-May-2012.tar",
  40. "md5": "2629fa636546599198acfcfbfcf1904a",
  41. "base_dir": os.path.join("VOCdevkit", "VOC2008"),
  42. },
  43. "2007": {
  44. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
  45. "filename": "VOCtrainval_06-Nov-2007.tar",
  46. "md5": "c52e279531787c972589f7e41ab4ae64",
  47. "base_dir": os.path.join("VOCdevkit", "VOC2007"),
  48. },
  49. "2007-test": {
  50. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar",
  51. "filename": "VOCtest_06-Nov-2007.tar",
  52. "md5": "b6e924de25625d8de591ea690078ad9f",
  53. "base_dir": os.path.join("VOCdevkit", "VOC2007"),
  54. },
  55. }
  56. class _VOCBase(VisionDataset):
  57. _SPLITS_DIR: str
  58. _TARGET_DIR: str
  59. _TARGET_FILE_EXT: str
  60. def __init__(
  61. self,
  62. root: str,
  63. year: str = "2012",
  64. image_set: str = "train",
  65. download: bool = False,
  66. transform: Optional[Callable] = None,
  67. target_transform: Optional[Callable] = None,
  68. transforms: Optional[Callable] = None,
  69. ):
  70. super().__init__(root, transforms, transform, target_transform)
  71. self.year = verify_str_arg(year, "year", valid_values=[str(yr) for yr in range(2007, 2013)])
  72. valid_image_sets = ["train", "trainval", "val"]
  73. if year == "2007":
  74. valid_image_sets.append("test")
  75. self.image_set = verify_str_arg(image_set, "image_set", valid_image_sets)
  76. key = "2007-test" if year == "2007" and image_set == "test" else year
  77. dataset_year_dict = DATASET_YEAR_DICT[key]
  78. self.url = dataset_year_dict["url"]
  79. self.filename = dataset_year_dict["filename"]
  80. self.md5 = dataset_year_dict["md5"]
  81. base_dir = dataset_year_dict["base_dir"]
  82. voc_root = os.path.join(self.root, base_dir)
  83. if download:
  84. download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
  85. if not os.path.isdir(voc_root):
  86. raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
  87. splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR)
  88. split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
  89. with open(os.path.join(split_f)) as f:
  90. file_names = [x.strip() for x in f.readlines()]
  91. image_dir = os.path.join(voc_root, "JPEGImages")
  92. self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
  93. target_dir = os.path.join(voc_root, self._TARGET_DIR)
  94. self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names]
  95. assert len(self.images) == len(self.targets)
  96. def __len__(self) -> int:
  97. return len(self.images)
  98. class VOCSegmentation(_VOCBase):
  99. """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
  100. Args:
  101. root (string): Root directory of the VOC Dataset.
  102. year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
  103. image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
  104. ``year=="2007"``, can also be ``"test"``.
  105. download (bool, optional): If true, downloads the dataset from the internet and
  106. puts it in root directory. If dataset is already downloaded, it is not
  107. downloaded again.
  108. transform (callable, optional): A function/transform that takes in an PIL image
  109. and returns a transformed version. E.g, ``transforms.RandomCrop``
  110. target_transform (callable, optional): A function/transform that takes in the
  111. target and transforms it.
  112. transforms (callable, optional): A function/transform that takes input sample and its target as entry
  113. and returns a transformed version.
  114. """
  115. _SPLITS_DIR = "Segmentation"
  116. _TARGET_DIR = "SegmentationClass"
  117. _TARGET_FILE_EXT = ".png"
  118. @property
  119. def masks(self) -> List[str]:
  120. return self.targets
  121. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  122. """
  123. Args:
  124. index (int): Index
  125. Returns:
  126. tuple: (image, target) where target is the image segmentation.
  127. """
  128. img = Image.open(self.images[index]).convert("RGB")
  129. target = Image.open(self.masks[index])
  130. if self.transforms is not None:
  131. img, target = self.transforms(img, target)
  132. return img, target
  133. class VOCDetection(_VOCBase):
  134. """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.
  135. Args:
  136. root (string): Root directory of the VOC Dataset.
  137. year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
  138. image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
  139. ``year=="2007"``, can also be ``"test"``.
  140. download (bool, optional): If true, downloads the dataset from the internet and
  141. puts it in root directory. If dataset is already downloaded, it is not
  142. downloaded again.
  143. (default: alphabetic indexing of VOC's 20 classes).
  144. transform (callable, optional): A function/transform that takes in an PIL image
  145. and returns a transformed version. E.g, ``transforms.RandomCrop``
  146. target_transform (callable, required): A function/transform that takes in the
  147. target and transforms it.
  148. transforms (callable, optional): A function/transform that takes input sample and its target as entry
  149. and returns a transformed version.
  150. """
  151. _SPLITS_DIR = "Main"
  152. _TARGET_DIR = "Annotations"
  153. _TARGET_FILE_EXT = ".xml"
  154. @property
  155. def annotations(self) -> List[str]:
  156. return self.targets
  157. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  158. """
  159. Args:
  160. index (int): Index
  161. Returns:
  162. tuple: (image, target) where target is a dictionary of the XML tree.
  163. """
  164. img = Image.open(self.images[index]).convert("RGB")
  165. target = self.parse_voc_xml(ET_parse(self.annotations[index]).getroot())
  166. if self.transforms is not None:
  167. img, target = self.transforms(img, target)
  168. return img, target
  169. @staticmethod
  170. def parse_voc_xml(node: ET_Element) -> Dict[str, Any]:
  171. voc_dict: Dict[str, Any] = {}
  172. children = list(node)
  173. if children:
  174. def_dic: Dict[str, Any] = collections.defaultdict(list)
  175. for dc in map(VOCDetection.parse_voc_xml, children):
  176. for ind, v in dc.items():
  177. def_dic[ind].append(v)
  178. if node.tag == "annotation":
  179. def_dic["object"] = [def_dic["object"]]
  180. voc_dict = {node.tag: {ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()}}
  181. if node.text:
  182. text = node.text.strip()
  183. if not children:
  184. voc_dict[node.tag] = text
  185. return voc_dict