sun397.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from pathlib import Path
  2. from typing import Any, Callable, Optional, Tuple
  3. import PIL.Image
  4. from .utils import download_and_extract_archive
  5. from .vision import VisionDataset
  6. class SUN397(VisionDataset):
  7. """`The SUN397 Data Set <https://vision.princeton.edu/projects/2010/SUN/>`_.
  8. The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of
  9. 397 categories with 108'754 images.
  10. Args:
  11. root (string): Root directory of the dataset.
  12. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
  13. version. E.g, ``transforms.RandomCrop``.
  14. target_transform (callable, optional): A function/transform that takes in the target and transforms it.
  15. download (bool, optional): If true, downloads the dataset from the internet and
  16. puts it in root directory. If dataset is already downloaded, it is not
  17. downloaded again.
  18. """
  19. _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz"
  20. _DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a"
  21. def __init__(
  22. self,
  23. root: str,
  24. transform: Optional[Callable] = None,
  25. target_transform: Optional[Callable] = None,
  26. download: bool = False,
  27. ) -> None:
  28. super().__init__(root, transform=transform, target_transform=target_transform)
  29. self._data_dir = Path(self.root) / "SUN397"
  30. if download:
  31. self._download()
  32. if not self._check_exists():
  33. raise RuntimeError("Dataset not found. You can use download=True to download it")
  34. with open(self._data_dir / "ClassName.txt") as f:
  35. self.classes = [c[3:].strip() for c in f]
  36. self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
  37. self._image_files = list(self._data_dir.rglob("sun_*.jpg"))
  38. self._labels = [
  39. self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files
  40. ]
  41. def __len__(self) -> int:
  42. return len(self._image_files)
  43. def __getitem__(self, idx: int) -> Tuple[Any, Any]:
  44. image_file, label = self._image_files[idx], self._labels[idx]
  45. image = PIL.Image.open(image_file).convert("RGB")
  46. if self.transform:
  47. image = self.transform(image)
  48. if self.target_transform:
  49. label = self.target_transform(label)
  50. return image, label
  51. def _check_exists(self) -> bool:
  52. return self._data_dir.is_dir()
  53. def _download(self) -> None:
  54. if self._check_exists():
  55. return
  56. download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5)