# inaturalist.py
import os
import os.path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from PIL import Image

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
  7. CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"]
  8. DATASET_URLS = {
  9. "2017": "https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz",
  10. "2018": "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz",
  11. "2019": "https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz",
  12. "2021_train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz",
  13. "2021_train_mini": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz",
  14. "2021_valid": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz",
  15. }
  16. DATASET_MD5 = {
  17. "2017": "7c784ea5e424efaec655bd392f87301f",
  18. "2018": "b1c6952ce38f31868cc50ea72d066cc3",
  19. "2019": "c60a6e2962c9b8ccbd458d12c8582644",
  20. "2021_train": "e0526d53c7f7b2e3167b2b43bb2690ed",
  21. "2021_train_mini": "db6ed8330e634445efc8fec83ae81442",
  22. "2021_valid": "f6f6e0e242e3d4c9569ba56400938afc",
  23. }
  24. class INaturalist(VisionDataset):
  25. """`iNaturalist <https://github.com/visipedia/inat_comp>`_ Dataset.
  26. Args:
  27. root (string): Root directory of dataset where the image files are stored.
  28. This class does not require/use annotation files.
  29. version (string, optional): Which version of the dataset to download/use. One of
  30. '2017', '2018', '2019', '2021_train', '2021_train_mini', '2021_valid'.
  31. Default: `2021_train`.
  32. target_type (string or list, optional): Type of target to use, for 2021 versions, one of:
  33. - ``full``: the full category (species)
  34. - ``kingdom``: e.g. "Animalia"
  35. - ``phylum``: e.g. "Arthropoda"
  36. - ``class``: e.g. "Insecta"
  37. - ``order``: e.g. "Coleoptera"
  38. - ``family``: e.g. "Cleridae"
  39. - ``genus``: e.g. "Trichodes"
  40. for 2017-2019 versions, one of:
  41. - ``full``: the full (numeric) category
  42. - ``super``: the super category, e.g. "Amphibians"
  43. Can also be a list to output a tuple with all specified target types.
  44. Defaults to ``full``.
  45. transform (callable, optional): A function/transform that takes in an PIL image
  46. and returns a transformed version. E.g, ``transforms.RandomCrop``
  47. target_transform (callable, optional): A function/transform that takes in the
  48. target and transforms it.
  49. download (bool, optional): If true, downloads the dataset from the internet and
  50. puts it in root directory. If dataset is already downloaded, it is not
  51. downloaded again.
  52. """
  53. def __init__(
  54. self,
  55. root: str,
  56. version: str = "2021_train",
  57. target_type: Union[List[str], str] = "full",
  58. transform: Optional[Callable] = None,
  59. target_transform: Optional[Callable] = None,
  60. download: bool = False,
  61. ) -> None:
  62. self.version = verify_str_arg(version, "version", DATASET_URLS.keys())
  63. super().__init__(os.path.join(root, version), transform=transform, target_transform=target_transform)
  64. os.makedirs(root, exist_ok=True)
  65. if download:
  66. self.download()
  67. if not self._check_integrity():
  68. raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
  69. self.all_categories: List[str] = []
  70. # map: category type -> name of category -> index
  71. self.categories_index: Dict[str, Dict[str, int]] = {}
  72. # list indexed by category id, containing mapping from category type -> index
  73. self.categories_map: List[Dict[str, int]] = []
  74. if not isinstance(target_type, list):
  75. target_type = [target_type]
  76. if self.version[:4] == "2021":
  77. self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) for t in target_type]
  78. self._init_2021()
  79. else:
  80. self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) for t in target_type]
  81. self._init_pre2021()
  82. # index of all files: (full category id, filename)
  83. self.index: List[Tuple[int, str]] = []
  84. for dir_index, dir_name in enumerate(self.all_categories):
  85. files = os.listdir(os.path.join(self.root, dir_name))
  86. for fname in files:
  87. self.index.append((dir_index, fname))
  88. def _init_2021(self) -> None:
  89. """Initialize based on 2021 layout"""
  90. self.all_categories = sorted(os.listdir(self.root))
  91. # map: category type -> name of category -> index
  92. self.categories_index = {k: {} for k in CATEGORIES_2021}
  93. for dir_index, dir_name in enumerate(self.all_categories):
  94. pieces = dir_name.split("_")
  95. if len(pieces) != 8:
  96. raise RuntimeError(f"Unexpected category name {dir_name}, wrong number of pieces")
  97. if pieces[0] != f"{dir_index:05d}":
  98. raise RuntimeError(f"Unexpected category id {pieces[0]}, expecting {dir_index:05d}")
  99. cat_map = {}
  100. for cat, name in zip(CATEGORIES_2021, pieces[1:7]):
  101. if name in self.categories_index[cat]:
  102. cat_id = self.categories_index[cat][name]
  103. else:
  104. cat_id = len(self.categories_index[cat])
  105. self.categories_index[cat][name] = cat_id
  106. cat_map[cat] = cat_id
  107. self.categories_map.append(cat_map)
  108. def _init_pre2021(self) -> None:
  109. """Initialize based on 2017-2019 layout"""
  110. # map: category type -> name of category -> index
  111. self.categories_index = {"super": {}}
  112. cat_index = 0
  113. super_categories = sorted(os.listdir(self.root))
  114. for sindex, scat in enumerate(super_categories):
  115. self.categories_index["super"][scat] = sindex
  116. subcategories = sorted(os.listdir(os.path.join(self.root, scat)))
  117. for subcat in subcategories:
  118. if self.version == "2017":
  119. # this version does not use ids as directory names
  120. subcat_i = cat_index
  121. cat_index += 1
  122. else:
  123. try:
  124. subcat_i = int(subcat)
  125. except ValueError:
  126. raise RuntimeError(f"Unexpected non-numeric dir name: {subcat}")
  127. if subcat_i >= len(self.categories_map):
  128. old_len = len(self.categories_map)
  129. self.categories_map.extend([{}] * (subcat_i - old_len + 1))
  130. self.all_categories.extend([""] * (subcat_i - old_len + 1))
  131. if self.categories_map[subcat_i]:
  132. raise RuntimeError(f"Duplicate category {subcat}")
  133. self.categories_map[subcat_i] = {"super": sindex}
  134. self.all_categories[subcat_i] = os.path.join(scat, subcat)
  135. # validate the dictionary
  136. for cindex, c in enumerate(self.categories_map):
  137. if not c:
  138. raise RuntimeError(f"Missing category {cindex}")
  139. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  140. """
  141. Args:
  142. index (int): Index
  143. Returns:
  144. tuple: (image, target) where the type of target specified by target_type.
  145. """
  146. cat_id, fname = self.index[index]
  147. img = Image.open(os.path.join(self.root, self.all_categories[cat_id], fname))
  148. target: Any = []
  149. for t in self.target_type:
  150. if t == "full":
  151. target.append(cat_id)
  152. else:
  153. target.append(self.categories_map[cat_id][t])
  154. target = tuple(target) if len(target) > 1 else target[0]
  155. if self.transform is not None:
  156. img = self.transform(img)
  157. if self.target_transform is not None:
  158. target = self.target_transform(target)
  159. return img, target
  160. def __len__(self) -> int:
  161. return len(self.index)
  162. def category_name(self, category_type: str, category_id: int) -> str:
  163. """
  164. Args:
  165. category_type(str): one of "full", "kingdom", "phylum", "class", "order", "family", "genus" or "super"
  166. category_id(int): an index (class id) from this category
  167. Returns:
  168. the name of the category
  169. """
  170. if category_type == "full":
  171. return self.all_categories[category_id]
  172. else:
  173. if category_type not in self.categories_index:
  174. raise ValueError(f"Invalid category type '{category_type}'")
  175. else:
  176. for name, id in self.categories_index[category_type].items():
  177. if id == category_id:
  178. return name
  179. raise ValueError(f"Invalid category id {category_id} for {category_type}")
  180. def _check_integrity(self) -> bool:
  181. return os.path.exists(self.root) and len(os.listdir(self.root)) > 0
  182. def download(self) -> None:
  183. if self._check_integrity():
  184. raise RuntimeError(
  185. f"The directory {self.root} already exists. "
  186. f"If you want to re-download or re-extract the images, delete the directory."
  187. )
  188. base_root = os.path.dirname(self.root)
  189. download_and_extract_archive(
  190. DATASET_URLS[self.version], base_root, filename=f"{self.version}.tgz", md5=DATASET_MD5[self.version]
  191. )
  192. orig_dir_name = os.path.join(base_root, os.path.basename(DATASET_URLS[self.version]).rstrip(".tar.gz"))
  193. if not os.path.exists(orig_dir_name):
  194. raise RuntimeError(f"Unable to find downloaded files at {orig_dir_name}")
  195. os.rename(orig_dir_name, self.root)
  196. print(f"Dataset version '{self.version}' has been downloaded and prepared for use")