coco.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import os.path
  2. from typing import Any, Callable, List, Optional, Tuple
  3. from PIL import Image
  4. from .vision import VisionDataset
  5. class CocoDetection(VisionDataset):
  6. """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.
  7. It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
  8. Args:
  9. root (string): Root directory where images are downloaded to.
  10. annFile (string): Path to json annotation file.
  11. transform (callable, optional): A function/transform that takes in an PIL image
  12. and returns a transformed version. E.g, ``transforms.PILToTensor``
  13. target_transform (callable, optional): A function/transform that takes in the
  14. target and transforms it.
  15. transforms (callable, optional): A function/transform that takes input sample and its target as entry
  16. and returns a transformed version.
  17. """
  18. def __init__(
  19. self,
  20. root: str,
  21. annFile: str,
  22. transform: Optional[Callable] = None,
  23. target_transform: Optional[Callable] = None,
  24. transforms: Optional[Callable] = None,
  25. ) -> None:
  26. super().__init__(root, transforms, transform, target_transform)
  27. from pycocotools.coco import COCO
  28. self.coco = COCO(annFile)
  29. self.ids = list(sorted(self.coco.imgs.keys()))
  30. def _load_image(self, id: int) -> Image.Image:
  31. path = self.coco.loadImgs(id)[0]["file_name"]
  32. return Image.open(os.path.join(self.root, path)).convert("RGB")
  33. def _load_target(self, id: int) -> List[Any]:
  34. return self.coco.loadAnns(self.coco.getAnnIds(id))
  35. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  36. id = self.ids[index]
  37. image = self._load_image(id)
  38. target = self._load_target(id)
  39. if self.transforms is not None:
  40. image, target = self.transforms(image, target)
  41. return image, target
  42. def __len__(self) -> int:
  43. return len(self.ids)
  44. class CocoCaptions(CocoDetection):
  45. """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.
  46. It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
  47. Args:
  48. root (string): Root directory where images are downloaded to.
  49. annFile (string): Path to json annotation file.
  50. transform (callable, optional): A function/transform that takes in an PIL image
  51. and returns a transformed version. E.g, ``transforms.PILToTensor``
  52. target_transform (callable, optional): A function/transform that takes in the
  53. target and transforms it.
  54. transforms (callable, optional): A function/transform that takes input sample and its target as entry
  55. and returns a transformed version.
  56. Example:
  57. .. code:: python
  58. import torchvision.datasets as dset
  59. import torchvision.transforms as transforms
  60. cap = dset.CocoCaptions(root = 'dir where images are',
  61. annFile = 'json annotation file',
  62. transform=transforms.PILToTensor())
  63. print('Number of samples: ', len(cap))
  64. img, target = cap[3] # load 4th sample
  65. print("Image Size: ", img.size())
  66. print(target)
  67. Output: ::
  68. Number of samples: 82783
  69. Image Size: (3L, 427L, 640L)
  70. [u'A plane emitting smoke stream flying over a mountain.',
  71. u'A plane darts across a bright blue sky behind a mountain covered in snow',
  72. u'A plane leaves a contrail above the snowy mountain top.',
  73. u'A mountain that has a plane flying overheard in the distance.',
  74. u'A mountain view with a plume of smoke in the background']
  75. """
  76. def _load_target(self, id: int) -> List[str]:
  77. return [ann["caption"] for ann in super()._load_target(id)]