123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123 |
- import os
- import shutil
- from typing import Any, Callable, Optional, Tuple
- import numpy as np
- from PIL import Image
- from .utils import download_and_extract_archive, download_url, verify_str_arg
- from .vision import VisionDataset
class SBDataset(VisionDataset):
    """`Semantic Boundaries Dataset <http://home.bharathh.info/pubs/codes/SBD/download.html>`_

    The SBD currently contains annotations from 11355 images taken from the PASCAL VOC 2011 dataset.

    .. note ::

        Please note that the train and val splits included with this dataset are different from
        the splits in the PASCAL VOC dataset. In particular some "train" images might be part of
        VOC2012 val.
        If you are interested in testing on VOC 2012 val, then use `image_set='train_noval'`,
        which excludes all val images.

    .. warning::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

    Args:
        root (string): Root directory of the Semantic Boundaries Dataset
        image_set (string, optional): Select the image_set to use, ``train``, ``val`` or ``train_noval``.
            Image set ``train_noval`` excludes VOC 2012 val images.
        mode (string, optional): Select target type. Possible values 'boundaries' or 'segmentation'.
            In case of 'boundaries', the target is an array of shape `[num_classes, H, W]`,
            where `num_classes=20`.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version. Input sample is PIL image and target is a numpy array
            if `mode='boundaries'` or PIL image if `mode='segmentation'`.

    Raises:
        RuntimeError: If scipy cannot be imported, or if ``root`` is not an existing directory
            once the optional download step has run.
    """

    url = "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz"
    md5 = "82b4d87ceb2ed10f6038a1cba92111cb"
    filename = "benchmark.tgz"

    voc_train_url = "http://home.bharathh.info/pubs/codes/SBD/train_noval.txt"
    voc_split_filename = "train_noval.txt"
    voc_split_md5 = "79bff800c5f0b1ec6b21080a3c066722"

    def __init__(
        self,
        root: str,
        image_set: str = "train",
        mode: str = "boundaries",
        download: bool = False,
        transforms: Optional[Callable] = None,
    ) -> None:
        # Imported here rather than at module level, so scipy is only required
        # when this dataset is actually instantiated.
        try:
            from scipy.io import loadmat
        except ImportError as err:
            raise RuntimeError(
                "Scipy is not found. This dataset needs to have scipy installed: pip install scipy"
            ) from err
        self._loadmat = loadmat

        super().__init__(root, transforms)
        self.image_set = verify_str_arg(image_set, "image_set", ("train", "val", "train_noval"))
        self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries"))
        self.num_classes = 20

        sbd_root = self.root
        image_dir = os.path.join(sbd_root, "img")
        mask_dir = os.path.join(sbd_root, "cls")

        if download:
            archive_entries = ("cls", "img", "inst", "train.txt", "val.txt")
            # Only fetch/extract the archive when something is missing. Moving an
            # entry onto an existing destination makes shutil.move raise, which
            # previously broke every repeated call with download=True.
            if not all(os.path.exists(os.path.join(sbd_root, entry)) for entry in archive_entries):
                download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
                extracted_ds_root = os.path.join(self.root, "benchmark_RELEASE", "dataset")
                for entry in archive_entries:
                    if os.path.exists(os.path.join(sbd_root, entry)):
                        # Already in place from an earlier run; leave it untouched.
                        continue
                    shutil.move(os.path.join(extracted_ds_root, entry), sbd_root)
            download_url(self.voc_train_url, sbd_root, self.voc_split_filename, self.voc_split_md5)

        if not os.path.isdir(sbd_root):
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # The split file lists one sample name per line; blank lines are ignored
        # so they do not turn into bogus "<dir>/.jpg" entries.
        split_f = os.path.join(sbd_root, self.image_set + ".txt")
        with open(split_f) as fh:
            stripped = (line.strip() for line in fh)
            file_names = [name for name in stripped if name]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names]

        # Bind the target loader once so __getitem__ does not branch on the mode.
        self._get_target = self._get_segmentation_target if self.mode == "segmentation" else self._get_boundaries_target

    def _get_segmentation_target(self, filepath: str) -> Image.Image:
        """Load the ``Segmentation`` field of the ``GTcls`` record in a ``.mat`` file as a PIL image."""
        mat = self._loadmat(filepath)
        return Image.fromarray(mat["GTcls"][0]["Segmentation"][0])

    def _get_boundaries_target(self, filepath: str) -> np.ndarray:
        """Load the per-class ``Boundaries`` entries of a ``.mat`` file, stacked along a new first axis.

        Each of the ``num_classes`` entries is converted with ``.toarray()`` before stacking,
        so the result has ``num_classes`` as its leading dimension.
        """
        mat = self._loadmat(filepath)
        boundaries = mat["GTcls"][0]["Boundaries"][0]
        return np.stack([boundaries[i][0].toarray() for i in range(self.num_classes)], axis=0)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the ``(image, target)`` pair at ``index``.

        The image is opened and converted to RGB; the target is produced by the loader
        selected through ``mode``. If ``transforms`` is set, it is applied to both jointly.
        """
        img = Image.open(self.images[index]).convert("RGB")
        target = self._get_target(self.masks[index])

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self) -> int:
        """Return the number of samples listed in the selected split file."""
        return len(self.images)

    def extra_repr(self) -> str:
        """Return the image-set and mode lines describing this dataset instance."""
        return f"Image set: {self.image_set}\nMode: {self.mode}"
|