123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224 |
- import collections
- import os
- from xml.etree.ElementTree import Element as ET_Element
- from .vision import VisionDataset
- try:
- from defusedxml.ElementTree import parse as ET_parse
- except ImportError:
- from xml.etree.ElementTree import parse as ET_parse
- from typing import Any, Callable, Dict, List, Optional, Tuple
- from PIL import Image
- from .utils import download_and_extract_archive, verify_str_arg
# Download metadata for every supported Pascal VOC release, keyed by year.
# Each entry provides:
#   url      - location of the tar archive
#   filename - name the archive is stored under inside the dataset root
#   md5      - checksum passed to the download helper
#   base_dir - directory, relative to the dataset root, that the archive
#              extracts to
# "2007-test" is a pseudo-year that selects the separate VOC2007 test archive.
DATASET_YEAR_DICT = {
    "2012": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
        "filename": "VOCtrainval_11-May-2012.tar",
        "md5": "6cd6e144f989b92b3379bac3b3de84fd",
        "base_dir": os.path.join("VOCdevkit", "VOC2012"),
    },
    "2011": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
        "filename": "VOCtrainval_25-May-2011.tar",
        "md5": "6c3384ef61512963050cb5d687e5bf1e",
        # The 2011 archive nests the devkit one level deeper than other years.
        "base_dir": os.path.join("TrainVal", "VOCdevkit", "VOC2011"),
    },
    "2010": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
        "filename": "VOCtrainval_03-May-2010.tar",
        "md5": "da459979d0c395079b5c75ee67908abb",
        "base_dir": os.path.join("VOCdevkit", "VOC2010"),
    },
    "2009": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
        "filename": "VOCtrainval_11-May-2009.tar",
        "md5": "a3e00b113cfcfebf17e343f59da3caa1",
        "base_dir": os.path.join("VOCdevkit", "VOC2009"),
    },
    "2008": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
        # Must match the archive in the URL; reusing the 2012 file name here
        # would make the 2008 and 2012 downloads overwrite each other when
        # they share a root directory.
        "filename": "VOCtrainval_14-Jul-2008.tar",
        "md5": "2629fa636546599198acfcfbfcf1904a",
        "base_dir": os.path.join("VOCdevkit", "VOC2008"),
    },
    "2007": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
        "filename": "VOCtrainval_06-Nov-2007.tar",
        "md5": "c52e279531787c972589f7e41ab4ae64",
        "base_dir": os.path.join("VOCdevkit", "VOC2007"),
    },
    "2007-test": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar",
        "filename": "VOCtest_06-Nov-2007.tar",
        "md5": "b6e924de25625d8de591ea690078ad9f",
        "base_dir": os.path.join("VOCdevkit", "VOC2007"),
    },
}
class _VOCBase(VisionDataset):
    """Shared path-resolution logic for the Pascal VOC datasets.

    Subclasses set three class attributes that select which part of the
    devkit is indexed. ``__init__`` then reads the split file for the
    requested ``image_set`` and builds two parallel lists of file paths,
    ``self.images`` and ``self.targets``.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, ``"2007"`` to ``"2012"``.
        image_set (string, optional): ``"train"``, ``"trainval"`` or ``"val"``.
            ``"test"`` is additionally accepted when ``year == "2007"``.
        download (bool, optional): If true, the archive for the selected year
            is fetched and extracted into ``root`` before the directory check.
        transform (callable, optional): Forwarded to ``VisionDataset``.
        target_transform (callable, optional): Forwarded to ``VisionDataset``.
        transforms (callable, optional): Forwarded to ``VisionDataset``.

    Raises:
        RuntimeError: If the devkit directory for the selected year does not
            exist under ``root``.
    """

    # Sub-directory of ``ImageSets`` that holds the ``<image_set>.txt`` files.
    _SPLITS_DIR: str
    # Directory, relative to the VOC root, that holds the target files.
    _TARGET_DIR: str
    # Extension (including the dot) appended to each name to form a target path.
    _TARGET_FILE_EXT: str

    def __init__(
        self,
        root: str,
        year: str = "2012",
        image_set: str = "train",
        download: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)

        self.year = verify_str_arg(year, "year", valid_values=[str(yr) for yr in range(2007, 2013)])

        # Only the 2007 release has a test archive listed in DATASET_YEAR_DICT.
        valid_image_sets = ["train", "trainval", "val"]
        if year == "2007":
            valid_image_sets.append("test")
        self.image_set = verify_str_arg(image_set, "image_set", valid_image_sets)

        # The 2007 test split is a separate archive with its own table entry.
        key = "2007-test" if year == "2007" and image_set == "test" else year
        dataset_year_dict = DATASET_YEAR_DICT[key]

        self.url = dataset_year_dict["url"]
        self.filename = dataset_year_dict["filename"]
        self.md5 = dataset_year_dict["md5"]

        voc_root = os.path.join(self.root, dataset_year_dict["base_dir"])

        if download:
            download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        split_f = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR, self.image_set + ".txt")
        with open(split_f) as f:
            # One name per line. Blank lines (e.g. a trailing newline at the
            # end of the file) are skipped so they do not turn into bogus
            # "<dir>/.jpg" entries.
            file_names = [name for name in (line.strip() for line in f) if name]

        image_dir = os.path.join(voc_root, "JPEGImages")
        self.images = [os.path.join(image_dir, name + ".jpg") for name in file_names]

        # Built from the same names as ``self.images``, so the two lists are
        # always index-aligned and of equal length.
        target_dir = os.path.join(voc_root, self._TARGET_DIR)
        self.targets = [os.path.join(target_dir, name + self._TARGET_FILE_EXT) for name in file_names]

    def __len__(self) -> int:
        """Return the number of samples listed in the selected split."""
        return len(self.images)
class VOCSegmentation(_VOCBase):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Each sample pairs a JPEG image with its class segmentation mask (a PNG
    from the ``SegmentationClass`` directory).

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
        image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
            ``year=="2007"``, can also be ``"test"``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    _SPLITS_DIR = "Segmentation"
    _TARGET_DIR = "SegmentationClass"
    _TARGET_FILE_EXT = ".png"

    @property
    def masks(self) -> List[str]:
        """Paths of the segmentation masks (alias of ``self.targets``)."""
        return self.targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the image/mask pair at ``index``.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        # The photo is forced to RGB; the mask is opened without conversion.
        image = Image.open(self.images[index]).convert("RGB")
        mask = Image.open(self.masks[index])

        if self.transforms is None:
            return image, mask

        image, mask = self.transforms(image, mask)
        return image, mask
class VOCDetection(_VOCBase):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.

    Each sample pairs a JPEG image with its XML annotation, returned as a
    nested dictionary (see :meth:`parse_voc_xml`).

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
        image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
            ``year=="2007"``, can also be ``"test"``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    _SPLITS_DIR = "Main"
    _TARGET_DIR = "Annotations"
    _TARGET_FILE_EXT = ".xml"

    @property
    def annotations(self) -> List[str]:
        """Paths of the XML annotation files (alias of ``self.targets``)."""
        return self.targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the image/annotation pair at ``index``.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dictionary of the XML tree.
        """
        image = Image.open(self.images[index]).convert("RGB")
        xml_root = ET_parse(self.annotations[index]).getroot()
        annotation = self.parse_voc_xml(xml_root)

        if self.transforms is None:
            return image, annotation

        image, annotation = self.transforms(image, annotation)
        return image, annotation

    @staticmethod
    def parse_voc_xml(node: ET_Element) -> Dict[str, Any]:
        """Recursively convert an XML element into a nested dictionary.

        * A leaf element with text becomes ``{tag: stripped_text}``.
        * A leaf element without text becomes ``{}``, so it is absent from
          its parent's dictionary.
        * An element with children becomes ``{tag: {child_tag: value, ...}}``.
          A child tag that occurs once maps to its single value; a tag that
          occurs several times maps to the list of its values.
        * Under an ``annotation`` element the ``object`` entry is always a
          list, including when there is one object or none.

        Args:
            node: The element to convert.

        Returns:
            The dictionary representation of ``node``.
        """
        sub_nodes = list(node)

        # Leaf element: only its text (if any) contributes.
        if not sub_nodes:
            if node.text:
                return {node.tag: node.text.strip()}
            return {}

        # Group the parsed children by tag, keeping document order.
        grouped: Dict[str, Any] = collections.defaultdict(list)
        for sub_node in sub_nodes:
            for tag, value in VOCDetection.parse_voc_xml(sub_node).items():
                grouped[tag].append(value)

        if node.tag == "annotation":
            # Wrapping the list once more means the single-value unwrapping
            # below hands the full list back, whatever its length.
            grouped["object"] = [grouped["object"]]

        collapsed: Dict[str, Any] = {}
        for tag, values in grouped.items():
            collapsed[tag] = values[0] if len(values) == 1 else values

        return {node.tag: collapsed}
|