flickr.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import glob
  2. import os
  3. from collections import defaultdict
  4. from html.parser import HTMLParser
  5. from typing import Any, Callable, Dict, List, Optional, Tuple
  6. from PIL import Image
  7. from .vision import VisionDataset
  8. class Flickr8kParser(HTMLParser):
  9. """Parser for extracting captions from the Flickr8k dataset web page."""
  10. def __init__(self, root: str) -> None:
  11. super().__init__()
  12. self.root = root
  13. # Data structure to store captions
  14. self.annotations: Dict[str, List[str]] = {}
  15. # State variables
  16. self.in_table = False
  17. self.current_tag: Optional[str] = None
  18. self.current_img: Optional[str] = None
  19. def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
  20. self.current_tag = tag
  21. if tag == "table":
  22. self.in_table = True
  23. def handle_endtag(self, tag: str) -> None:
  24. self.current_tag = None
  25. if tag == "table":
  26. self.in_table = False
  27. def handle_data(self, data: str) -> None:
  28. if self.in_table:
  29. if data == "Image Not Found":
  30. self.current_img = None
  31. elif self.current_tag == "a":
  32. img_id = data.split("/")[-2]
  33. img_id = os.path.join(self.root, img_id + "_*.jpg")
  34. img_id = glob.glob(img_id)[0]
  35. self.current_img = img_id
  36. self.annotations[img_id] = []
  37. elif self.current_tag == "li" and self.current_img:
  38. img_id = self.current_img
  39. self.annotations[img_id].append(data.strip())
  40. class Flickr8k(VisionDataset):
  41. """`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset.
  42. Args:
  43. root (string): Root directory where images are downloaded to.
  44. ann_file (string): Path to annotation file.
  45. transform (callable, optional): A function/transform that takes in a PIL image
  46. and returns a transformed version. E.g, ``transforms.PILToTensor``
  47. target_transform (callable, optional): A function/transform that takes in the
  48. target and transforms it.
  49. """
  50. def __init__(
  51. self,
  52. root: str,
  53. ann_file: str,
  54. transform: Optional[Callable] = None,
  55. target_transform: Optional[Callable] = None,
  56. ) -> None:
  57. super().__init__(root, transform=transform, target_transform=target_transform)
  58. self.ann_file = os.path.expanduser(ann_file)
  59. # Read annotations and store in a dict
  60. parser = Flickr8kParser(self.root)
  61. with open(self.ann_file) as fh:
  62. parser.feed(fh.read())
  63. self.annotations = parser.annotations
  64. self.ids = list(sorted(self.annotations.keys()))
  65. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  66. """
  67. Args:
  68. index (int): Index
  69. Returns:
  70. tuple: Tuple (image, target). target is a list of captions for the image.
  71. """
  72. img_id = self.ids[index]
  73. # Image
  74. img = Image.open(img_id).convert("RGB")
  75. if self.transform is not None:
  76. img = self.transform(img)
  77. # Captions
  78. target = self.annotations[img_id]
  79. if self.target_transform is not None:
  80. target = self.target_transform(target)
  81. return img, target
  82. def __len__(self) -> int:
  83. return len(self.ids)
  84. class Flickr30k(VisionDataset):
  85. """`Flickr30k Entities <https://bryanplummer.com/Flickr30kEntities/>`_ Dataset.
  86. Args:
  87. root (string): Root directory where images are downloaded to.
  88. ann_file (string): Path to annotation file.
  89. transform (callable, optional): A function/transform that takes in a PIL image
  90. and returns a transformed version. E.g, ``transforms.PILToTensor``
  91. target_transform (callable, optional): A function/transform that takes in the
  92. target and transforms it.
  93. """
  94. def __init__(
  95. self,
  96. root: str,
  97. ann_file: str,
  98. transform: Optional[Callable] = None,
  99. target_transform: Optional[Callable] = None,
  100. ) -> None:
  101. super().__init__(root, transform=transform, target_transform=target_transform)
  102. self.ann_file = os.path.expanduser(ann_file)
  103. # Read annotations and store in a dict
  104. self.annotations = defaultdict(list)
  105. with open(self.ann_file) as fh:
  106. for line in fh:
  107. img_id, caption = line.strip().split("\t")
  108. self.annotations[img_id[:-2]].append(caption)
  109. self.ids = list(sorted(self.annotations.keys()))
  110. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  111. """
  112. Args:
  113. index (int): Index
  114. Returns:
  115. tuple: Tuple (image, target). target is a list of captions for the image.
  116. """
  117. img_id = self.ids[index]
  118. # Image
  119. filename = os.path.join(self.root, img_id)
  120. img = Image.open(filename).convert("RGB")
  121. if self.transform is not None:
  122. img = self.transform(img)
  123. # Captions
  124. target = self.annotations[img_id]
  125. if self.target_transform is not None:
  126. target = self.target_transform(target)
  127. return img, target
  128. def __len__(self) -> int:
  129. return len(self.ids)