places365.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. import os
  2. from os import path
  3. from typing import Any, Callable, Dict, List, Optional, Tuple
  4. from urllib.parse import urljoin
  5. from .folder import default_loader
  6. from .utils import check_integrity, download_and_extract_archive, verify_str_arg
  7. from .vision import VisionDataset
  8. class Places365(VisionDataset):
  9. r"""`Places365 <http://places2.csail.mit.edu/index.html>`_ classification dataset.
  10. Args:
  11. root (string): Root directory of the Places365 dataset.
  12. split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challenge``,
  13. ``val``.
  14. small (bool, optional): If ``True``, uses the small images, i.e. resized to 256 x 256 pixels, instead of the
  15. high resolution ones.
  16. download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
  17. downloaded archives are not downloaded again.
  18. transform (callable, optional): A function/transform that takes in an PIL image
  19. and returns a transformed version. E.g, ``transforms.RandomCrop``
  20. target_transform (callable, optional): A function/transform that takes in the
  21. target and transforms it.
  22. loader (callable, optional): A function to load an image given its path.
  23. Attributes:
  24. classes (list): List of the class names.
  25. class_to_idx (dict): Dict with items (class_name, class_index).
  26. imgs (list): List of (image path, class_index) tuples
  27. targets (list): The class_index value for each image in the dataset
  28. Raises:
  29. RuntimeError: If ``download is False`` and the meta files, i.e. the devkit, are not present or corrupted.
  30. RuntimeError: If ``download is True`` and the image archive is already extracted.
  31. """
  32. _SPLITS = ("train-standard", "train-challenge", "val")
  33. _BASE_URL = "http://data.csail.mit.edu/places/places365/"
  34. # {variant: (archive, md5)}
  35. _DEVKIT_META = {
  36. "standard": ("filelist_places365-standard.tar", "35a0585fee1fa656440f3ab298f8479c"),
  37. "challenge": ("filelist_places365-challenge.tar", "70a8307e459c3de41690a7c76c931734"),
  38. }
  39. # (file, md5)
  40. _CATEGORIES_META = ("categories_places365.txt", "06c963b85866bd0649f97cb43dd16673")
  41. # {split: (file, md5)}
  42. _FILE_LIST_META = {
  43. "train-standard": ("places365_train_standard.txt", "30f37515461640559006b8329efbed1a"),
  44. "train-challenge": ("places365_train_challenge.txt", "b2931dc997b8c33c27e7329c073a6b57"),
  45. "val": ("places365_val.txt", "e9f2fd57bfd9d07630173f4e8708e4b1"),
  46. }
  47. # {(split, small): (file, md5)}
  48. _IMAGES_META = {
  49. ("train-standard", False): ("train_large_places365standard.tar", "67e186b496a84c929568076ed01a8aa1"),
  50. ("train-challenge", False): ("train_large_places365challenge.tar", "605f18e68e510c82b958664ea134545f"),
  51. ("val", False): ("val_large.tar", "9b71c4993ad89d2d8bcbdc4aef38042f"),
  52. ("train-standard", True): ("train_256_places365standard.tar", "53ca1c756c3d1e7809517cc47c5561c5"),
  53. ("train-challenge", True): ("train_256_places365challenge.tar", "741915038a5e3471ec7332404dfb64ef"),
  54. ("val", True): ("val_256.tar", "e27b17d8d44f4af9a78502beb927f808"),
  55. }
  56. def __init__(
  57. self,
  58. root: str,
  59. split: str = "train-standard",
  60. small: bool = False,
  61. download: bool = False,
  62. transform: Optional[Callable] = None,
  63. target_transform: Optional[Callable] = None,
  64. loader: Callable[[str], Any] = default_loader,
  65. ) -> None:
  66. super().__init__(root, transform=transform, target_transform=target_transform)
  67. self.split = self._verify_split(split)
  68. self.small = small
  69. self.loader = loader
  70. self.classes, self.class_to_idx = self.load_categories(download)
  71. self.imgs, self.targets = self.load_file_list(download)
  72. if download:
  73. self.download_images()
  74. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  75. file, target = self.imgs[index]
  76. image = self.loader(file)
  77. if self.transforms is not None:
  78. image, target = self.transforms(image, target)
  79. return image, target
  80. def __len__(self) -> int:
  81. return len(self.imgs)
  82. @property
  83. def variant(self) -> str:
  84. return "challenge" if "challenge" in self.split else "standard"
  85. @property
  86. def images_dir(self) -> str:
  87. size = "256" if self.small else "large"
  88. if self.split.startswith("train"):
  89. dir = f"data_{size}_{self.variant}"
  90. else:
  91. dir = f"{self.split}_{size}"
  92. return path.join(self.root, dir)
  93. def load_categories(self, download: bool = True) -> Tuple[List[str], Dict[str, int]]:
  94. def process(line: str) -> Tuple[str, int]:
  95. cls, idx = line.split()
  96. return cls, int(idx)
  97. file, md5 = self._CATEGORIES_META
  98. file = path.join(self.root, file)
  99. if not self._check_integrity(file, md5, download):
  100. self.download_devkit()
  101. with open(file) as fh:
  102. class_to_idx = dict(process(line) for line in fh)
  103. return sorted(class_to_idx.keys()), class_to_idx
  104. def load_file_list(self, download: bool = True) -> Tuple[List[Tuple[str, int]], List[int]]:
  105. def process(line: str, sep="/") -> Tuple[str, int]:
  106. image, idx = line.split()
  107. return path.join(self.images_dir, image.lstrip(sep).replace(sep, os.sep)), int(idx)
  108. file, md5 = self._FILE_LIST_META[self.split]
  109. file = path.join(self.root, file)
  110. if not self._check_integrity(file, md5, download):
  111. self.download_devkit()
  112. with open(file) as fh:
  113. images = [process(line) for line in fh]
  114. _, targets = zip(*images)
  115. return images, list(targets)
  116. def download_devkit(self) -> None:
  117. file, md5 = self._DEVKIT_META[self.variant]
  118. download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)
  119. def download_images(self) -> None:
  120. if path.exists(self.images_dir):
  121. raise RuntimeError(
  122. f"The directory {self.images_dir} already exists. If you want to re-download or re-extract the images, "
  123. f"delete the directory."
  124. )
  125. file, md5 = self._IMAGES_META[(self.split, self.small)]
  126. download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)
  127. if self.split.startswith("train"):
  128. os.rename(self.images_dir.rsplit("_", 1)[0], self.images_dir)
  129. def extra_repr(self) -> str:
  130. return "\n".join(("Split: {split}", "Small: {small}")).format(**self.__dict__)
  131. def _verify_split(self, split: str) -> str:
  132. return verify_str_arg(split, "split", self._SPLITS)
  133. def _check_integrity(self, file: str, md5: str, download: bool) -> bool:
  134. integrity = check_integrity(file, md5=md5)
  135. if not integrity and not download:
  136. raise RuntimeError(
  137. f"The file {file} does not exist or is corrupted. You can set download=True to download it."
  138. )
  139. return integrity