123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166 |
- import glob
- import os
- from collections import defaultdict
- from html.parser import HTMLParser
- from typing import Any, Callable, Dict, List, Optional, Tuple
- from PIL import Image
- from .vision import VisionDataset
class Flickr8kParser(HTMLParser):
    """Parser for extracting captions from the Flickr8k dataset web page.

    Only text found inside ``<table>`` elements is considered:

    * the text of an ``<a>`` element is split on ``/`` and its second-to-last
      component is taken as the image id, which is resolved to a file named
      ``<id>_*.jpg`` inside ``root``;
    * the text of each following ``<li>`` element is stored (stripped) as a
      caption of that image;
    * the literal text ``Image Not Found`` clears the current image, so the
      captions that follow it are discarded.

    The result is available in ``annotations``, a dict mapping the resolved
    image path to its list of captions.

    Args:
        root (str): Directory containing the image files.
    """

    def __init__(self, root: str) -> None:
        super().__init__()

        self.root = root

        # Data structure to store captions: image path -> list of captions
        self.annotations: Dict[str, List[str]] = {}

        # State variables
        self.in_table = False  # True while inside a <table> element
        self.current_tag: Optional[str] = None  # most recently opened, not yet closed tag
        self.current_img: Optional[str] = None  # image path that receives the next captions

    def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
        """Remember the tag that was just opened and track entry into a table."""
        self.current_tag = tag

        if tag == "table":
            self.in_table = True

    def handle_endtag(self, tag: str) -> None:
        """Forget the current tag and track exit from a table."""
        self.current_tag = None

        if tag == "table":
            self.in_table = False

    def handle_data(self, data: str) -> None:
        """Record image links and captions found inside a table.

        Raises:
            IndexError: If the text of an ``<a>`` element has fewer than two
                ``/``-separated components, or if no matching image file
                exists under ``root``.
        """
        if not self.in_table:
            return

        if data == "Image Not Found":
            self.current_img = None
        elif self.current_tag == "a":
            img_path = self._find_image(data)
            self.current_img = img_path
            self.annotations[img_path] = []
        elif self.current_tag == "li" and self.current_img:
            self.annotations[self.current_img].append(data.strip())

    def _find_image(self, link_text: str) -> str:
        """Return the path of the image file under ``root`` for a link text."""
        img_id = link_text.split("/")[-2]

        # Escape the literal parts of the pattern so that glob metacharacters
        # in ``root`` (e.g. "data[v1]") are matched literally; only the
        # trailing "_*" is meant to be a wildcard.
        pattern = os.path.join(glob.escape(self.root), glob.escape(img_id) + "_*.jpg")

        # Sort so the chosen file does not depend on directory listing order.
        matches = sorted(glob.glob(pattern))
        if not matches:
            # IndexError is kept for backward compatibility with the previous
            # ``glob.glob(...)[0]`` lookup; only the message is new.
            raise IndexError(f"No image file matching {pattern!r} found for image id {img_id!r}")
        return matches[0]
class Flickr8k(VisionDataset):
    """`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset.

    The annotation file is an HTML page that is parsed once, at construction
    time, into a mapping from image path to its list of captions.

    Args:
        root (string): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(
        self,
        root: str,
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Load the whole annotation page, then let the parser collect the captions.
        with open(self.ann_file) as ann_fh:
            page = ann_fh.read()
        caption_parser = Flickr8kParser(self.root)
        caption_parser.feed(page)
        self.annotations = caption_parser.annotations

        # Sorted image paths give every sample a stable index.
        self.ids = sorted(self.annotations)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        # For this dataset the id of a sample is the path of its image file.
        img_path = self.ids[index]

        image = Image.open(img_path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        captions = self.annotations[img_path]
        if self.target_transform is not None:
            captions = self.target_transform(captions)

        return image, captions

    def __len__(self) -> int:
        """Return the number of images that have annotations."""
        return len(self.ids)
class Flickr30k(VisionDataset):
    """`Flickr30k Entities <https://bryanplummer.com/Flickr30kEntities/>`_ Dataset.

    The annotation file is read line by line at construction time. Each
    non-blank line is expected to have the form
    ``<image file name>#<caption index>\\t<caption>``; captions are grouped by
    image file name.

    Args:
        root (string): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Raises:
        ValueError: If a non-blank line of the annotation file contains no tab.
    """

    def __init__(
        self,
        root: str,
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Read annotations and store in a dict: image file name -> list of captions
        self.annotations: Dict[str, List[str]] = defaultdict(list)
        with open(self.ann_file) as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at end of file).
                    continue

                # Split on the first tab only so a caption may itself contain tabs.
                key, caption = line.split("\t", 1)

                # Drop the "#<index>" suffix whatever its width. Keys without a
                # "#" fall back to the historical "cut the last two characters".
                name, sep, _ = key.rpartition("#")
                img_id = name if sep else key[:-2]

                self.annotations[img_id].append(caption)

        self.ids = sorted(self.annotations)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_id = self.ids[index]

        # Image: ids are file names relative to the dataset root
        filename = os.path.join(self.root, img_id)
        img = Image.open(filename).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = self.annotations[img_id]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of images that have annotations."""
        return len(self.ids)
|