import os.path
from typing import Any, Callable, List, Optional, Tuple

from PIL import Image

from .vision import VisionDataset


class CocoDetection(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root: str,
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        # pycocotools is imported lazily so it is only required when COCO datasets are used.
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        self.ids = list(sorted(self.coco.imgs.keys()))

    def _load_image(self, id: int) -> Image.Image:
        path = self.coco.loadImgs(id)[0]["file_name"]
        return Image.open(os.path.join(self.root, path)).convert("RGB")

    def _load_target(self, id: int) -> List[Any]:
        return self.coco.loadAnns(self.coco.getAnnIds(id))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        id = self.ids[index]
        image = self._load_image(id)
        target = self._load_target(id)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        return len(self.ids)


class CocoCaptions(CocoDetection):
    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Example:

        .. code:: python

            import torchvision.datasets as dset
            import torchvision.transforms as transforms

            cap = dset.CocoCaptions(root='dir where images are',
                                    annFile='json annotation file',
                                    transform=transforms.PILToTensor())

            print('Number of samples: ', len(cap))
            img, target = cap[3]  # load 4th sample

            print("Image Size: ", img.size())
            print(target)

        Output: ::

            Number of samples: 82783
            Image Size: (3L, 427L, 640L)
            [u'A plane emitting smoke stream flying over a mountain.',
            u'A plane darts across a bright blue sky behind a mountain covered in snow',
            u'A plane leaves a contrail above the snowy mountain top.',
            u'A mountain that has a plane flying overheard in the distance.',
            u'A mountain view with a plume of smoke in the background']

    """

    def _load_target(self, id: int) -> List[str]:
        return [ann["caption"] for ann in super()._load_target(id)]
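

# Usage sketch (not part of the original module): a minimal, hedged example of how
# ``CocoDetection`` might be exercised. The ``root`` and ``annFile`` paths are
# placeholders that must point to a locally downloaded COCO image directory and a
# matching instance-annotation JSON file; run this file as a module so the relative
# import of ``.vision`` above resolves.
if __name__ == "__main__":
    import torchvision.transforms as transforms

    dataset = CocoDetection(
        root="path/to/coco/train2017",  # placeholder: directory containing the images
        annFile="path/to/annotations/instances_train2017.json",  # placeholder: annotation file
        transform=transforms.PILToTensor(),
    )

    print("Number of samples:", len(dataset))
    image, target = dataset[0]  # image tensor and the list of COCO annotation dicts
    print("Image shape:", tuple(image.shape))
    print("Annotations in first sample:", len(target))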
|