folder.py 12 KB


  1. import os
  2. import os.path
  3. from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
  4. from PIL import Image
  5. from .vision import VisionDataset
  def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
      """Return whether ``filename`` ends with one of the allowed ``extensions``.

      Only ``filename`` is lower-cased before the comparison, so ``extensions``
      must already be given in lowercase.

      Args:
          filename (string): path to a file
          extensions (str or tuple of strings): a single extension, or several
              extensions, to consider (lowercase)

      Returns:
          bool: True if the filename ends with one of the given extensions
      """
      lowered = filename.lower()
      suffixes = extensions
      if not isinstance(suffixes, str):
          # str.endswith only accepts a str or a tuple, so normalise other iterables.
          suffixes = tuple(suffixes)
      return lowered.endswith(suffixes)
  def is_image_file(filename: str) -> bool:
      """Return whether ``filename`` carries one of the known image extensions.

      The check is case-insensitive on ``filename`` and uses the module-level
      ``IMG_EXTENSIONS`` tuple.

      Args:
          filename (string): path to a file

      Returns:
          bool: True if the filename ends with a known image extension
      """
      known_suffixes = IMG_EXTENSIONS
      return has_file_allowed_extension(filename, known_suffixes)
  def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
      """Find the class folders in a dataset.

      Every immediate entry of ``directory`` for which ``is_dir()`` is true is
      treated as one class. Class names are sorted, so the index assigned to each
      class is deterministic. See :class:`DatasetFolder` for details.

      Args:
          directory (str): root directory to scan.

      Returns:
          (Tuple[List[str], Dict[str, int]]): the sorted class names, and a mapping
          from each class name to its position in that list.

      Raises:
          FileNotFoundError: If ``directory`` contains no class folders (``os.scandir``
              itself also raises it when ``directory`` does not exist).
      """
      # Use the context manager so the directory handle is released deterministically,
      # as the os.scandir documentation recommends.
      with os.scandir(directory) as entries:
          classes = sorted(entry.name for entry in entries if entry.is_dir())
      if not classes:
          raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
      class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
      return classes, class_to_idx
  def make_dataset(
      directory: str,
      class_to_idx: Optional[Dict[str, int]] = None,
      extensions: Optional[Union[str, Tuple[str, ...]]] = None,
      is_valid_file: Optional[Callable[[str], bool]] = None,
  ) -> List[Tuple[str, int]]:
      """Generate a list of samples of the form (path_to_sample, class_index).

      See :class:`DatasetFolder` for details.

      Note: ``class_to_idx`` is optional here. When it is ``None``, the classes are
      discovered with the module-level ``find_classes`` function.

      Args:
          directory (str): root dataset directory; a leading ``~`` is expanded.
          class_to_idx (Dict[str, int], optional): mapping from class
              (sub-directory) name to class index.
          extensions (str or tuple of str, optional): allowed file extensions
              (lowercase). Exactly one of ``extensions`` and ``is_valid_file``
              must be given.
          is_valid_file (callable, optional): predicate that receives a file path
              and decides whether it is a valid sample. Exactly one of
              ``extensions`` and ``is_valid_file`` must be given.

      Returns:
          List[Tuple[str, int]]: samples ordered by class name, then by walked
          directory path, then by file name. Symbolic links to directories are
          followed while walking.

      Raises:
          ValueError: If ``class_to_idx`` is empty, or if ``extensions`` and
              ``is_valid_file`` are both None or both not None.
          FileNotFoundError: If some class in ``class_to_idx`` yields no valid
              file (this includes classes whose directory does not exist).
      """
      directory = os.path.expanduser(directory)

      if class_to_idx is None:
          _, class_to_idx = find_classes(directory)
      elif not class_to_idx:
          raise ValueError("'class_to_idx' must have at least one entry to collect any samples.")

      # Exactly one filtering strategy must be supplied.
      if (extensions is None) == (is_valid_file is None):
          raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

      if extensions is not None:

          def is_valid_file(x: str) -> bool:
              return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

      is_valid_file = cast(Callable[[str], bool], is_valid_file)

      instances: List[Tuple[str, int]] = []
      available_classes = set()
      # Sort classes, walked roots and file names so the sample order does not
      # depend on the order in which the file system lists entries.
      for target_class in sorted(class_to_idx.keys()):
          class_index = class_to_idx[target_class]
          target_dir = os.path.join(directory, target_class)
          if not os.path.isdir(target_dir):
              continue
          for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
              for fname in sorted(fnames):
                  path = os.path.join(root, fname)
                  if is_valid_file(path):
                      instances.append((path, class_index))
                      # set.add is idempotent, so no membership test is needed.
                      available_classes.add(target_class)

      empty_classes = set(class_to_idx.keys()) - available_classes
      if empty_classes:
          msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
          if extensions is not None:
              msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
          raise FileNotFoundError(msg)

      return instances
  class DatasetFolder(VisionDataset):
      """A generic data loader for samples stored in one sub-directory per class.

      How classes are discovered can be customised by overriding
      :meth:`find_classes`. How samples are collected can be customised by
      overriding :meth:`make_dataset`.

      Args:
          root (string): Root directory path.
          loader (callable): A function to load a sample given its path.
          extensions (tuple[string]): A tuple of allowed extensions.
              The default :meth:`make_dataset` requires exactly one of
              ``extensions`` and ``is_valid_file``.
          transform (callable, optional): A function/transform that takes in
              a sample and returns a transformed version.
              E.g, ``transforms.RandomCrop`` for images.
          target_transform (callable, optional): A function/transform that takes
              in the target and transforms it.
          is_valid_file (callable, optional): A function that takes the path of a
              file and checks whether it is a valid file (e.g. to skip corrupt
              files). The default :meth:`make_dataset` requires exactly one of
              ``extensions`` and ``is_valid_file``.

      Attributes:
          classes (list): List of the class names, as returned by
              :meth:`find_classes` (sorted alphabetically by default).
          class_to_idx (dict): Dict with items (class_name, class_index).
          samples (list): List of (sample path, class_index) tuples
          targets (list): The class_index value for each sample in the dataset
      """

      def __init__(
          self,
          root: str,
          loader: Callable[[str], Any],
          extensions: Optional[Tuple[str, ...]] = None,
          transform: Optional[Callable] = None,
          target_transform: Optional[Callable] = None,
          is_valid_file: Optional[Callable[[str], bool]] = None,
      ) -> None:
          super().__init__(root, transform=transform, target_transform=target_transform)
          # Run the (possibly overridden) discovery hooks before any other
          # attribute of this class is assigned.
          found_classes, mapping = self.find_classes(self.root)
          collected = self.make_dataset(self.root, mapping, extensions, is_valid_file)

          self.loader = loader
          self.extensions = extensions
          self.classes = found_classes
          self.class_to_idx = mapping
          self.samples = collected
          self.targets = [entry[1] for entry in collected]

      @staticmethod
      def make_dataset(
          directory: str,
          class_to_idx: Dict[str, int],
          extensions: Optional[Tuple[str, ...]] = None,
          is_valid_file: Optional[Callable[[str], bool]] = None,
      ) -> List[Tuple[str, int]]:
          """Build the list of samples of the form (path_to_sample, class).

          Override this to gather samples differently, e.g. reading files from a
          compressed zip archive instead of from disk.

          Args:
              directory (str): root dataset directory, corresponding to ``self.root``.
              class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
              extensions (optional): A tuple of allowed extensions.
                  Either extensions or is_valid_file should be passed. Defaults to None.
              is_valid_file (optional): A function that takes the path of a file
                  and checks whether it is a valid file (e.g. to skip corrupt
                  files). Exactly one of extensions and is_valid_file should be
                  passed. Defaults to None.

          Raises:
              ValueError: In case ``class_to_idx`` is None or empty.
              ValueError: In case ``extensions`` and ``is_valid_file`` are both None or both not None.
              FileNotFoundError: In case no valid file was found for some class.

          Returns:
              List[Tuple[str, int]]: samples of the form (path_to_sample, class)
          """
          if class_to_idx is not None:
              return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)
          # Refuse None: the module-level make_dataset() would otherwise fall back to
          # the module-level find_classes() function instead of this class's
          # find_classes() method, which may be overridden with different logic.
          raise ValueError("The class_to_idx parameter cannot be None.")

      def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
          """Find the class folders in a dataset laid out as follows::

              directory/
              ├── class_x
              │   ├── xxx.ext
              │   ├── xxy.ext
              │   └── ...
              │       └── xxz.ext
              └── class_y
                  ├── 123.ext
                  ├── nsdf3.ext
                  └── ...
                      └── asd932_.ext

          Override this to consider only a subset of the classes, or to support a
          different directory layout.

          Args:
              directory(str): Root directory path, corresponding to ``self.root``

          Raises:
              FileNotFoundError: If ``directory`` has no class folders.

          Returns:
              (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
          """
          return find_classes(directory)

      def __getitem__(self, index: int) -> Tuple[Any, Any]:
          """Load the sample stored at ``index`` and return it with its target.

          Args:
              index (int): Index

          Returns:
              tuple: (sample, target) where target is class_index of the target class.
          """
          entry = self.samples[index]
          sample_path, target = entry
          sample = self.loader(sample_path)
          if self.transform is not None:
              sample = self.transform(sample)
          if self.target_transform is not None:
              target = self.target_transform(target)
          return sample, target

      def __len__(self) -> int:
          """Return the number of collected samples."""
          return len(self.samples)
  # Lowercase file-name suffixes treated as images: checked by is_image_file() and
  # used by ImageFolder when no custom is_valid_file is given.
  IMG_EXTENSIONS = (
      ".jpg",
      ".jpeg",
      ".png",
      ".ppm",
      ".bmp",
      ".pgm",
      ".tif",
      ".tiff",
      ".webp",
  )
  def pil_loader(path: str) -> Image.Image:
      """Load the image at ``path`` with PIL and return it converted to RGB.

      The file is opened explicitly and the conversion happens while it is still
      open, so the handle is closed right afterwards. This avoids a
      ResourceWarning (https://github.com/python-pillow/Pillow/issues/835).
      """
      with open(path, "rb") as handle:
          return Image.open(handle).convert("RGB")
  # TODO: specify the return type
  def accimage_loader(path: str) -> Any:
      """Load the image at ``path`` with the optional ``accimage`` backend.

      If accimage raises ``OSError``, the image is loaded with :func:`pil_loader`
      instead.
      """
      import accimage

      try:
          loaded = accimage.Image(path)
      except OSError:
          # Potentially a decoding problem, fall back to PIL.Image
          loaded = pil_loader(path)
      return loaded
  def default_loader(path: str) -> Any:
      """Load ``path`` with the image backend currently selected in torchvision.

      Uses :func:`accimage_loader` when ``torchvision.get_image_backend()`` returns
      ``"accimage"``, and :func:`pil_loader` for any other backend.
      """
      from torchvision import get_image_backend

      use_accimage = get_image_backend() == "accimage"
      return accimage_loader(path) if use_accimage else pil_loader(path)
  class ImageFolder(DatasetFolder):
      """A generic data loader where the images are arranged in this way by default: ::

          root/dog/xxx.png
          root/dog/xxy.png
          root/dog/[...]/xxz.png

          root/cat/123.png
          root/cat/nsdf3.png
          root/cat/[...]/asd932_.png

      This class inherits from :class:`~torchvision.datasets.DatasetFolder`, so
      the same methods can be overridden to customize the dataset.

      Args:
          root (string): Root directory path.
          transform (callable, optional): A function/transform that takes in the
              loaded image (a PIL image with the default PIL backend) and returns
              a transformed version. E.g, ``transforms.RandomCrop``
          target_transform (callable, optional): A function/transform that takes in the
              target and transforms it.
          loader (callable, optional): A function to load an image given its path.
          is_valid_file (callable, optional): A function that takes the path of an
              image file and checks whether it is a valid file (e.g. to skip corrupt
              files). When given, ``IMG_EXTENSIONS`` filtering is not applied.

      Attributes:
          classes (list): List of the class names sorted alphabetically.
          class_to_idx (dict): Dict with items (class_name, class_index).
          imgs (list): List of (image path, class_index) tuples; the same list
              object as ``samples``.
      """

      def __init__(
          self,
          root: str,
          transform: Optional[Callable] = None,
          target_transform: Optional[Callable] = None,
          loader: Callable[[str], Any] = default_loader,
          is_valid_file: Optional[Callable[[str], bool]] = None,
      ):
          # Filter by IMG_EXTENSIONS only when no custom validity check is supplied,
          # because exactly one of the two may be handed to DatasetFolder.
          allowed_extensions = None if is_valid_file is not None else IMG_EXTENSIONS
          super().__init__(
              root,
              loader,
              allowed_extensions,
              transform=transform,
              target_transform=target_transform,
              is_valid_file=is_valid_file,
          )
          # Alias (not a copy) of ``samples``, exposed under an image-specific name.
          self.imgs = self.samples