stanford_cars.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import pathlib
  2. from typing import Any, Callable, Optional, Tuple
  3. from PIL import Image
  4. from .utils import download_and_extract_archive, download_url, verify_str_arg
  5. from .vision import VisionDataset
  6. class StanfordCars(VisionDataset):
  7. """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset
  8. The Cars dataset contains 16,185 images of 196 classes of cars. The data is
  9. split into 8,144 training images and 8,041 testing images, where each class
  10. has been split roughly in a 50-50 split
  11. .. note::
  12. This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
  13. Args:
  14. root (string): Root directory of dataset
  15. split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
  16. transform (callable, optional): A function/transform that takes in an PIL image
  17. and returns a transformed version. E.g, ``transforms.RandomCrop``
  18. target_transform (callable, optional): A function/transform that takes in the
  19. target and transforms it.
  20. download (bool, optional): If True, downloads the dataset from the internet and
  21. puts it in root directory. If dataset is already downloaded, it is not
  22. downloaded again."""
  23. def __init__(
  24. self,
  25. root: str,
  26. split: str = "train",
  27. transform: Optional[Callable] = None,
  28. target_transform: Optional[Callable] = None,
  29. download: bool = False,
  30. ) -> None:
  31. try:
  32. import scipy.io as sio
  33. except ImportError:
  34. raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
  35. super().__init__(root, transform=transform, target_transform=target_transform)
  36. self._split = verify_str_arg(split, "split", ("train", "test"))
  37. self._base_folder = pathlib.Path(root) / "stanford_cars"
  38. devkit = self._base_folder / "devkit"
  39. if self._split == "train":
  40. self._annotations_mat_path = devkit / "cars_train_annos.mat"
  41. self._images_base_path = self._base_folder / "cars_train"
  42. else:
  43. self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
  44. self._images_base_path = self._base_folder / "cars_test"
  45. if download:
  46. self.download()
  47. if not self._check_exists():
  48. raise RuntimeError("Dataset not found. You can use download=True to download it")
  49. self._samples = [
  50. (
  51. str(self._images_base_path / annotation["fname"]),
  52. annotation["class"] - 1, # Original target mapping starts from 1, hence -1
  53. )
  54. for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
  55. ]
  56. self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
  57. self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
  58. def __len__(self) -> int:
  59. return len(self._samples)
  60. def __getitem__(self, idx: int) -> Tuple[Any, Any]:
  61. """Returns pil_image and class_id for given index"""
  62. image_path, target = self._samples[idx]
  63. pil_image = Image.open(image_path).convert("RGB")
  64. if self.transform is not None:
  65. pil_image = self.transform(pil_image)
  66. if self.target_transform is not None:
  67. target = self.target_transform(target)
  68. return pil_image, target
  69. def download(self) -> None:
  70. if self._check_exists():
  71. return
  72. download_and_extract_archive(
  73. url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
  74. download_root=str(self._base_folder),
  75. md5="c3b158d763b6e2245038c8ad08e45376",
  76. )
  77. if self._split == "train":
  78. download_and_extract_archive(
  79. url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
  80. download_root=str(self._base_folder),
  81. md5="065e5b463ae28d29e77c1b4b166cfe61",
  82. )
  83. else:
  84. download_and_extract_archive(
  85. url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
  86. download_root=str(self._base_folder),
  87. md5="4ce7ebf6a94d07f1952d94dd34c4d501",
  88. )
  89. download_url(
  90. url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
  91. root=str(self._base_folder),
  92. md5="b0a2b23655a3edd16d84508592a98d10",
  93. )
  94. def _check_exists(self) -> bool:
  95. if not (self._base_folder / "devkit").is_dir():
  96. return False
  97. return self._annotations_mat_path.exists() and self._images_base_path.is_dir()