123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221 |
- import json
- import os
- from collections import namedtuple
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
- from PIL import Image
- from .utils import extract_archive, iterable_to_str, verify_str_arg
- from .vision import VisionDataset
class Cityscapes(VisionDataset):
    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory ``leftImg8bit``
            and ``gtFine`` or ``gtCoarse`` are located.
        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
            otherwise ``train``, ``train_extra`` or ``val``
        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
            or ``color``. Can also be a list to output a tuple with all specified target types.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Examples:

        Get semantic segmentation target

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type='semantic')

            img, smnt = dataset[0]

        Get multiple targets

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type=['instance', 'color', 'polygon'])

            img, (inst, col, poly) = dataset[0]

        Validate on the "coarse" set

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
                                 target_type='semantic')

            img, smnt = dataset[0]
    """

    # Based on https://github.com/mcordts/cityscapesScripts
    # One record per label; the field order below is the positional order used
    # by every entry of ``classes``.
    CityscapesClass = namedtuple(
        "CityscapesClass",
        ["name", "id", "train_id", "category", "category_id", "has_instances", "ignore_in_eval", "color"],
    )

    classes = [
        CityscapesClass("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("ego vehicle", 1, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("rectification border", 2, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("out of roi", 3, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("static", 4, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("dynamic", 5, 255, "void", 0, False, True, (111, 74, 0)),
        CityscapesClass("ground", 6, 255, "void", 0, False, True, (81, 0, 81)),
        CityscapesClass("road", 7, 0, "flat", 1, False, False, (128, 64, 128)),
        CityscapesClass("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232)),
        CityscapesClass("parking", 9, 255, "flat", 1, False, True, (250, 170, 160)),
        CityscapesClass("rail track", 10, 255, "flat", 1, False, True, (230, 150, 140)),
        CityscapesClass("building", 11, 2, "construction", 2, False, False, (70, 70, 70)),
        CityscapesClass("wall", 12, 3, "construction", 2, False, False, (102, 102, 156)),
        CityscapesClass("fence", 13, 4, "construction", 2, False, False, (190, 153, 153)),
        CityscapesClass("guard rail", 14, 255, "construction", 2, False, True, (180, 165, 180)),
        CityscapesClass("bridge", 15, 255, "construction", 2, False, True, (150, 100, 100)),
        CityscapesClass("tunnel", 16, 255, "construction", 2, False, True, (150, 120, 90)),
        CityscapesClass("pole", 17, 5, "object", 3, False, False, (153, 153, 153)),
        CityscapesClass("polegroup", 18, 255, "object", 3, False, True, (153, 153, 153)),
        CityscapesClass("traffic light", 19, 6, "object", 3, False, False, (250, 170, 30)),
        CityscapesClass("traffic sign", 20, 7, "object", 3, False, False, (220, 220, 0)),
        CityscapesClass("vegetation", 21, 8, "nature", 4, False, False, (107, 142, 35)),
        CityscapesClass("terrain", 22, 9, "nature", 4, False, False, (152, 251, 152)),
        CityscapesClass("sky", 23, 10, "sky", 5, False, False, (70, 130, 180)),
        CityscapesClass("person", 24, 11, "human", 6, True, False, (220, 20, 60)),
        CityscapesClass("rider", 25, 12, "human", 6, True, False, (255, 0, 0)),
        CityscapesClass("car", 26, 13, "vehicle", 7, True, False, (0, 0, 142)),
        CityscapesClass("truck", 27, 14, "vehicle", 7, True, False, (0, 0, 70)),
        CityscapesClass("bus", 28, 15, "vehicle", 7, True, False, (0, 60, 100)),
        CityscapesClass("caravan", 29, 255, "vehicle", 7, True, True, (0, 0, 90)),
        CityscapesClass("trailer", 30, 255, "vehicle", 7, True, True, (0, 0, 110)),
        CityscapesClass("train", 31, 16, "vehicle", 7, True, False, (0, 80, 100)),
        CityscapesClass("motorcycle", 32, 17, "vehicle", 7, True, False, (0, 0, 230)),
        CityscapesClass("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32)),
        CityscapesClass("license plate", -1, -1, "vehicle", 7, False, True, (0, 0, 142)),
    ]

    def __init__(
        self,
        root: str,
        split: str = "train",
        mode: str = "fine",
        target_type: Union[List[str], str] = "instance",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        """Validate the arguments, extract archives if needed and index the files.

        Raises:
            RuntimeError: If the image or target directory for the requested
                ``split``/``mode`` is missing and the matching zip archives are
                not both present in ``root``.

        Invalid ``mode``, ``split`` or ``target_type`` values are rejected by
        ``verify_str_arg``.
        """
        super().__init__(root, transforms, transform, target_transform)

        # Validate the arguments before deriving any path from them.
        verify_str_arg(mode, "mode", ("fine", "coarse"))
        if mode == "fine":
            valid_splits = ("train", "test", "val")
        else:
            valid_splits = ("train", "train_extra", "val")
        msg = "Unknown value '{}' for argument split if mode is '{}'. Valid values are {{{}}}."
        msg = msg.format(split, mode, iterable_to_str(valid_splits))
        verify_str_arg(split, "split", valid_splits, msg)

        # A single target type is normalised to a one-element list.
        self.target_type = target_type if isinstance(target_type, list) else [target_type]
        for value in self.target_type:
            verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color"))

        # ``self.mode`` holds the on-disk folder name, not the user-facing mode.
        self.mode = "gtFine" if mode == "fine" else "gtCoarse"
        self.images_dir = os.path.join(self.root, "leftImg8bit", split)
        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.split = split
        self.images: List[str] = []
        self.targets: List[List[str]] = []

        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
            if split == "train_extra":
                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainextra.zip")
            else:
                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainvaltest.zip")

            if self.mode == "gtFine":
                target_dir_zip = os.path.join(self.root, f"{self.mode}_trainvaltest.zip")
            else:
                target_dir_zip = os.path.join(self.root, f"{self.mode}.zip")

            if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
                extract_archive(from_path=image_dir_zip, to_path=self.root)
                extract_archive(from_path=target_dir_zip, to_path=self.root)
            else:
                raise RuntimeError(
                    "Dataset not found or incomplete. Please make sure all required folders for the"
                    ' specified "split" and "mode" are inside the "root" directory'
                )

        # The suffix depends only on the mode and target type, so compute it
        # once instead of once per image file.
        suffixes = [self._get_target_suffix(self.mode, t) for t in self.target_type]

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping reproducible across machines and filesystems.
        for city in sorted(os.listdir(self.images_dir)):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in sorted(os.listdir(img_dir)):
                stem = file_name.split("_leftImg8bit")[0]
                target_paths = [os.path.join(target_dir, f"{stem}_{suffix}") for suffix in suffixes]
                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(target_paths)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
            than one item. Otherwise, target is a json object if target_type="polygon", else the image segmentation.
        """
        image = Image.open(self.images[index]).convert("RGB")

        targets: Any = []
        for i, t in enumerate(self.target_type):
            # Polygon targets are JSON documents; every other type is an image.
            if t == "polygon":
                target = self._load_json(self.targets[index][i])
            else:
                target = Image.open(self.targets[index][i])
            targets.append(target)

        # A single requested type is returned bare rather than as a 1-tuple.
        target = tuple(targets) if len(targets) > 1 else targets[0]

        # ``self.transforms`` is not assigned in this class; presumably it is
        # set by ``VisionDataset.__init__`` -- confirm against the base class.
        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.images)

    def extra_repr(self) -> str:
        """Return the split, mode folder name and target types, one per line."""
        return "\n".join(
            (
                f"Split: {self.split}",
                f"Mode: {self.mode}",
                f"Type: {self.target_type}",
            )
        )

    def _load_json(self, path: str) -> Dict[str, Any]:
        """Parse and return the JSON document stored at ``path``."""
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's default locale encoding.
        with open(path, encoding="utf-8") as file:
            return json.load(file)

    def _get_target_suffix(self, mode: str, target_type: str) -> str:
        """Return the file-name suffix of a target file for ``mode`` and ``target_type``.

        Any ``target_type`` other than ``instance``, ``semantic`` or ``color``
        yields the polygon suffix.
        """
        if target_type == "instance":
            return f"{mode}_instanceIds.png"
        if target_type == "semantic":
            return f"{mode}_labelIds.png"
        if target_type == "color":
            return f"{mode}_color.png"
        return f"{mode}_polygons.json"
|