# sbu.py
import os
from typing import Any, Callable, Optional, Tuple

from PIL import Image

from .utils import check_integrity, download_and_extract_archive, download_url
from .vision import VisionDataset
  6. class SBU(VisionDataset):
  7. """`SBU Captioned Photo <http://www.cs.virginia.edu/~vicente/sbucaptions/>`_ Dataset.
  8. Args:
  9. root (string): Root directory of dataset where tarball
  10. ``SBUCaptionedPhotoDataset.tar.gz`` exists.
  11. transform (callable, optional): A function/transform that takes in a PIL image
  12. and returns a transformed version. E.g, ``transforms.RandomCrop``
  13. target_transform (callable, optional): A function/transform that takes in the
  14. target and transforms it.
  15. download (bool, optional): If True, downloads the dataset from the internet and
  16. puts it in root directory. If dataset is already downloaded, it is not
  17. downloaded again.
  18. """
  19. url = "https://www.cs.rice.edu/~vo9/sbucaptions/SBUCaptionedPhotoDataset.tar.gz"
  20. filename = "SBUCaptionedPhotoDataset.tar.gz"
  21. md5_checksum = "9aec147b3488753cf758b4d493422285"
  22. def __init__(
  23. self,
  24. root: str,
  25. transform: Optional[Callable] = None,
  26. target_transform: Optional[Callable] = None,
  27. download: bool = True,
  28. ) -> None:
  29. super().__init__(root, transform=transform, target_transform=target_transform)
  30. if download:
  31. self.download()
  32. if not self._check_integrity():
  33. raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
  34. # Read the caption for each photo
  35. self.photos = []
  36. self.captions = []
  37. file1 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")
  38. file2 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_captions.txt")
  39. for line1, line2 in zip(open(file1), open(file2)):
  40. url = line1.rstrip()
  41. photo = os.path.basename(url)
  42. filename = os.path.join(self.root, "dataset", photo)
  43. if os.path.exists(filename):
  44. caption = line2.rstrip()
  45. self.photos.append(photo)
  46. self.captions.append(caption)
  47. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  48. """
  49. Args:
  50. index (int): Index
  51. Returns:
  52. tuple: (image, target) where target is a caption for the photo.
  53. """
  54. filename = os.path.join(self.root, "dataset", self.photos[index])
  55. img = Image.open(filename).convert("RGB")
  56. if self.transform is not None:
  57. img = self.transform(img)
  58. target = self.captions[index]
  59. if self.target_transform is not None:
  60. target = self.target_transform(target)
  61. return img, target
  62. def __len__(self) -> int:
  63. """The number of photos in the dataset."""
  64. return len(self.photos)
  65. def _check_integrity(self) -> bool:
  66. """Check the md5 checksum of the downloaded tarball."""
  67. root = self.root
  68. fpath = os.path.join(root, self.filename)
  69. if not check_integrity(fpath, self.md5_checksum):
  70. return False
  71. return True
  72. def download(self) -> None:
  73. """Download and extract the tarball, and download each individual photo."""
  74. if self._check_integrity():
  75. print("Files already downloaded and verified")
  76. return
  77. download_and_extract_archive(self.url, self.root, self.root, self.filename, self.md5_checksum)
  78. # Download individual photos
  79. with open(os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")) as fh:
  80. for line in fh:
  81. url = line.rstrip()
  82. try:
  83. download_url(url, os.path.join(self.root, "dataset"))
  84. except OSError:
  85. # The images point to public images on Flickr.
  86. # Note: Images might be removed by users at anytime.
  87. pass