123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237 |
- import os
- import os.path
- from typing import Any, Callable, List, Optional, Tuple, Union
- from PIL import Image
- from .utils import download_and_extract_archive, verify_str_arg
- from .vision import VisionDataset
class Caltech101(VisionDataset):
    """`Caltech 101 <https://data.caltech.edu/records/20086>`_ Dataset.

    .. warning::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
        scipy is imported only when the ``annotation`` target type is actually read.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech101`` exists or will be saved to if download is set to True.
        target_type (string or list, optional): Type of target to use, ``category`` or
            ``annotation``. Can also be a list to output a tuple with all specified
            target types. ``category`` represents the target class, and
            ``annotation`` is a list of points from a hand-generated outline.
            Defaults to ``category``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If ``101_ObjectCategories`` is not present under the dataset
            directory after the optional download step.
    """

    def __init__(
        self,
        root: str,
        target_type: Union[List[str], str] = "category",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)

        # Normalise to a list so __getitem__ can always iterate over the target types.
        if isinstance(target_type, str):
            target_type = [target_type]
        self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        images_dir = os.path.join(self.root, "101_ObjectCategories")
        # "BACKGROUND_Google" is not a real class. Filter it out instead of calling
        # list.remove() so a copy of the dataset without that folder still loads.
        self.categories = sorted(entry for entry in os.listdir(images_dir) if entry != "BACKGROUND_Google")

        # For some reason, the category names in "101_ObjectCategories" and
        # "Annotations" do not always match. This is a manual map between the
        # two. Defaults to using same name, since most names are fine.
        name_map = {
            "Faces": "Faces_2",
            "Faces_easy": "Faces_3",
            "Motorbikes": "Motorbikes_16",
            "airplanes": "Airplanes_Side_2",
        }
        self.annotation_categories = [name_map.get(c, c) for c in self.categories]

        # self.index[k] is the 1-based image number inside its category folder and
        # self.y[k] is the position of that category in self.categories.
        self.index: List[int] = []
        self.y = []
        for i, c in enumerate(self.categories):
            # Count only ".jpg" entries: __getitem__ opens "image_XXXX.jpg", so any
            # other stray file in the folder must not create a sample.
            n = sum(1 for item in os.listdir(os.path.join(images_dir, c)) if item.endswith(".jpg"))
            self.index.extend(range(1, n + 1))
            self.y.extend(n * [i])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where the type of target specified by target_type.
            With a single target type the target is that value; with several it is a
            tuple in the order given by ``target_type``.
        """
        label = self.y[index]
        image_number = self.index[index]

        img = Image.open(
            os.path.join(
                self.root,
                "101_ObjectCategories",
                self.categories[label],
                f"image_{image_number:04d}.jpg",
            )
        )

        target: Any = []
        for t in self.target_type:
            if t == "category":
                target.append(label)
            elif t == "annotation":
                # Imported here so that category-only use does not require scipy.
                import scipy.io

                data = scipy.io.loadmat(
                    os.path.join(
                        self.root,
                        "Annotations",
                        self.annotation_categories[label],
                        f"annotation_{image_number:04d}.mat",
                    )
                )
                target.append(data["obj_contour"])
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        """Return True if the ``101_ObjectCategories`` directory exists under ``self.root``."""
        # can be more robust and check hash of files
        # NOTE: only the image folder is checked; "Annotations" is not verified here.
        return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.index)

    def download(self) -> None:
        """Fetch the image and annotation archives into ``self.root`` unless already present."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp",
            self.root,
            filename="101_ObjectCategories.tar.gz",
            md5="b224c7392d521a49829488ab0f1120d9",
        )
        download_and_extract_archive(
            "https://drive.google.com/file/d/175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m",
            self.root,
            filename="Annotations.tar",
            md5="6f83eeb1f24d99cab4eb377263132c91",
        )

    def extra_repr(self) -> str:
        """Return the extra line describing the configured target types."""
        return f"Target type: {self.target_type}"
class Caltech256(VisionDataset):
    """`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech256`` exists or will be saved to if download is set to True.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If ``256_ObjectCategories`` is not present under the dataset
            directory after the optional download step.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        dataset_dir = os.path.join(root, "caltech256")
        super().__init__(dataset_dir, transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        images_dir = os.path.join(self.root, "256_ObjectCategories")
        self.categories = sorted(os.listdir(images_dir))

        # self.index holds the 1-based image number within its category folder;
        # self.y holds the matching position in self.categories.
        self.index: List[int] = []
        self.y = []
        for label, category in enumerate(self.categories):
            entries = os.listdir(os.path.join(images_dir, category))
            # Only ".jpg" entries count as images.
            image_count = sum(1 for entry in entries if entry.endswith(".jpg"))
            self.index.extend(range(1, image_count + 1))
            self.y.extend([label] * image_count)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        label = self.y[index]
        # File names are "<category number, 1-based>_<image number>.jpg".
        file_name = f"{label + 1:03d}_{self.index[index]:04d}.jpg"
        image_path = os.path.join(self.root, "256_ObjectCategories", self.categories[label], file_name)
        img = Image.open(image_path)

        target = label

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        """Return True if the ``256_ObjectCategories`` directory exists under ``self.root``."""
        # can be more robust and check hash of files
        images_dir = os.path.join(self.root, "256_ObjectCategories")
        return os.path.exists(images_dir)

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.index)

    def download(self) -> None:
        """Fetch the image archive into ``self.root`` unless it is already present."""
        already_present = self._check_integrity()
        if already_present:
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            "https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
            self.root,
            filename="256_ObjectCategories.tar",
            md5="67b4f42ca05d46448c6bb8ecd2220f6d",
        )
|