123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317 |
- import os
- import os.path
- from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
- from PIL import Image
- from .vision import VisionDataset
def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
    """Tell whether ``filename`` ends with one of the given extensions.

    The filename is lower-cased before matching; the extensions are used
    as given, so they are expected to be lowercase already.

    Args:
        filename (string): path to a file
        extensions (string or tuple of strings): extension(s) to accept (lowercase)

    Returns:
        bool: True if the lower-cased filename ends with one of the extensions
    """
    # ``str.endswith`` accepts a single string or a tuple of strings, so any
    # other iterable of extensions is converted to a tuple first.
    suffixes = extensions if isinstance(extensions, str) else tuple(extensions)
    lowered = filename.lower()
    return lowered.endswith(suffixes)
def is_image_file(filename: str) -> bool:
    """Tell whether ``filename`` carries one of the extensions in ``IMG_EXTENSIONS``.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    return has_file_allowed_extension(filename, extensions=IMG_EXTENSIONS)
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    Every immediate sub-directory of ``directory`` is treated as one class.
    Class names are sorted, and each class is mapped to its position in that
    sorted list.

    See :class:`DatasetFolder` for details.

    Args:
        directory (str): Directory whose immediate sub-directories are the classes.

    Raises:
        FileNotFoundError: If ``directory`` contains no sub-directory.

    Returns:
        (Tuple[List[str], Dict[str, int]]): Sorted list of class names and the
        dictionary mapping each class name to its index.
    """
    # ``os.scandir`` owns an open directory handle. The ``with`` block releases
    # it deterministically, even if ``is_dir`` raises part-way through the scan.
    with os.scandir(directory) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx
def make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Union[str, Tuple[str, ...]]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.

    Args:
        directory (str): Root dataset directory; a leading ``~`` is expanded.
        class_to_idx (Dict[str, int], optional): Mapping from class name to class index.
            When None, it is computed with ``find_classes(directory)``.
        extensions (str or tuple of str, optional): Allowed file extensions.
            Exactly one of ``extensions`` and ``is_valid_file`` must be given.
        is_valid_file (callable, optional): Predicate receiving a file path and
            returning whether the file should be collected.

    Raises:
        ValueError: If ``class_to_idx`` is an empty mapping.
        ValueError: If ``extensions`` and ``is_valid_file`` are both None or both given.
        FileNotFoundError: If at least one class in ``class_to_idx`` has no valid file.

    Returns:
        List[Tuple[str, int]]: samples of a form (path_to_sample, class), ordered by
        class name, then by walked directory, then by file name.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_idx' must have at least one entry to collect any samples.")

    # Exactly one of the two filters has to be supplied: the comparison is True
    # both when neither is given and when both are given.
    if (extensions is None) == (is_valid_file is None):
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        # A class without a directory is not an error here; it is reported
        # below together with every other class that yielded no sample.
        if not os.path.isdir(target_dir):
            continue
        # Sorting both the walk and the file names keeps the sample order
        # independent of the order in which the file system lists entries.
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    instances.append((path, class_index))
                    available_classes.add(target_class)

    empty_classes = set(class_to_idx) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances
class DatasetFolder(VisionDataset):
    """A generic data loader for samples stored in one folder per class.

    The expected directory layout is described in :meth:`find_classes` and can
    be customized by overriding that method.

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): Allowed file extensions.
            ``extensions`` and ``is_valid_file`` must not both be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes the path of a
            file and checks whether it is a valid file (e.g. to skip corrupt files).
            ``extensions`` and ``is_valid_file`` must not both be passed.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each sample in the dataset
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        # Both lookups go through the (overridable) methods and run before any
        # of the attributes below are assigned.
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        # Second field of every sample, i.e. its class index.
        self.targets = [entry[1] for entry in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        """Build the list of (path_to_sample, class) samples.

        This can be overridden to e.g. read files from a compressed zip file instead of from the disk.

        Args:
            directory (str): root dataset directory, corresponding to ``self.root``.
            class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
            extensions (optional): Allowed file extensions.
                Either extensions or is_valid_file should be passed. Defaults to None.
            is_valid_file (optional): A function that takes the path of a file
                and checks whether it is a valid file (e.g. to skip corrupt files).
                extensions and is_valid_file should not both be passed. Defaults to None.

        Raises:
            ValueError: In case ``class_to_idx`` is None or empty.
            ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
            FileNotFoundError: In case no valid file was found for any class.

        Returns:
            List[Tuple[str, int]]: samples of a form (path_to_sample, class)
        """
        # The module-level make_dataset() would silently fall back to the
        # module-level find_classes() for a missing mapping, bypassing a
        # potentially overridden find_classes() method; reject None up front.
        if class_to_idx is None:
            raise ValueError("The class_to_idx parameter cannot be None.")
        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders in a dataset structured as follows::

            directory/
            ├── class_x
            │   ├── xxx.ext
            │   ├── xxy.ext
            │   └── ...
            │       └── xxz.ext
            └── class_y
                ├── 123.ext
                ├── nsdf3.ext
                └── ...
                └── asd932_.ext

        This method can be overridden to only consider
        a subset of classes, or to adapt to a different dataset directory structure.

        Args:
            directory(str): Root directory path, corresponding to ``self.root``

        Raises:
            FileNotFoundError: If ``directory`` has no class folders.

        Returns:
            (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
        """
        return find_classes(directory)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the sample stored at ``index`` and apply the configured transforms.

        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        sample_path, target = self.samples[index]
        sample = self.loader(sample_path)

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        """Return the number of collected samples."""
        return len(self.samples)
# Lowercase file extensions accepted as images by ``is_image_file`` and,
# when no ``is_valid_file`` callback is given, by ``ImageFolder``.
IMG_EXTENSIONS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".ppm",
    ".bmp",
    ".pgm",
    ".tif",
    ".tiff",
    ".webp",
)
def pil_loader(path: str) -> Image.Image:
    """Open the file at ``path`` with PIL and return it converted to RGB mode."""
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    # The conversion is performed while the file object is still open.
    with open(path, "rb") as stream:
        return Image.open(stream).convert("RGB")
# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    """Load ``path`` with ``accimage``, falling back to :func:`pil_loader` on ``OSError``."""
    # Imported here rather than at module level, so the module itself can be
    # imported without accimage being present.
    import accimage

    try:
        image = accimage.Image(path)
    except OSError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)
    return image
def default_loader(path: str) -> Any:
    """Load ``path`` with the loader matching torchvision's current image backend.

    ``accimage_loader`` is used when the backend is ``"accimage"``;
    ``pil_loader`` is used for every other backend.
    """
    from torchvision import get_image_backend

    load = accimage_loader if get_image_backend() == "accimage" else pil_loader
    return load(path)
class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way by default: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
    the same methods can be overridden to customize the dataset.

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes the path of an image file
            and checks whether it is a valid file (e.g. to skip corrupt files)

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        # The extension filter and the is_valid_file callback are mutually
        # exclusive, so the default extensions apply only without a callback.
        extensions = None if is_valid_file is not None else IMG_EXTENSIONS
        super().__init__(
            root,
            loader,
            extensions,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        # ``imgs`` is a second name for the very same list object as ``samples``.
        self.imgs = self.samples
|