omniglot.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from os.path import join
  2. from typing import Any, Callable, List, Optional, Tuple
  3. from PIL import Image
  4. from .utils import check_integrity, download_and_extract_archive, list_dir, list_files
  5. from .vision import VisionDataset
  6. class Omniglot(VisionDataset):
  7. """`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.
  8. Args:
  9. root (string): Root directory of dataset where directory
  10. ``omniglot-py`` exists.
  11. background (bool, optional): If True, creates dataset from the "background" set, otherwise
  12. creates from the "evaluation" set. This terminology is defined by the authors.
  13. transform (callable, optional): A function/transform that takes in an PIL image
  14. and returns a transformed version. E.g, ``transforms.RandomCrop``
  15. target_transform (callable, optional): A function/transform that takes in the
  16. target and transforms it.
  17. download (bool, optional): If true, downloads the dataset zip files from the internet and
  18. puts it in root directory. If the zip files are already downloaded, they are not
  19. downloaded again.
  20. """
  21. folder = "omniglot-py"
  22. download_url_prefix = "https://raw.githubusercontent.com/brendenlake/omniglot/master/python"
  23. zips_md5 = {
  24. "images_background": "68d2efa1b9178cc56df9314c21c6e718",
  25. "images_evaluation": "6b91aef0f799c5bb55b94e3f2daec811",
  26. }
  27. def __init__(
  28. self,
  29. root: str,
  30. background: bool = True,
  31. transform: Optional[Callable] = None,
  32. target_transform: Optional[Callable] = None,
  33. download: bool = False,
  34. ) -> None:
  35. super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
  36. self.background = background
  37. if download:
  38. self.download()
  39. if not self._check_integrity():
  40. raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
  41. self.target_folder = join(self.root, self._get_target_folder())
  42. self._alphabets = list_dir(self.target_folder)
  43. self._characters: List[str] = sum(
  44. ([join(a, c) for c in list_dir(join(self.target_folder, a))] for a in self._alphabets), []
  45. )
  46. self._character_images = [
  47. [(image, idx) for image in list_files(join(self.target_folder, character), ".png")]
  48. for idx, character in enumerate(self._characters)
  49. ]
  50. self._flat_character_images: List[Tuple[str, int]] = sum(self._character_images, [])
  51. def __len__(self) -> int:
  52. return len(self._flat_character_images)
  53. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  54. """
  55. Args:
  56. index (int): Index
  57. Returns:
  58. tuple: (image, target) where target is index of the target character class.
  59. """
  60. image_name, character_class = self._flat_character_images[index]
  61. image_path = join(self.target_folder, self._characters[character_class], image_name)
  62. image = Image.open(image_path, mode="r").convert("L")
  63. if self.transform:
  64. image = self.transform(image)
  65. if self.target_transform:
  66. character_class = self.target_transform(character_class)
  67. return image, character_class
  68. def _check_integrity(self) -> bool:
  69. zip_filename = self._get_target_folder()
  70. if not check_integrity(join(self.root, zip_filename + ".zip"), self.zips_md5[zip_filename]):
  71. return False
  72. return True
  73. def download(self) -> None:
  74. if self._check_integrity():
  75. print("Files already downloaded and verified")
  76. return
  77. filename = self._get_target_folder()
  78. zip_filename = filename + ".zip"
  79. url = self.download_url_prefix + "/" + zip_filename
  80. download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])
  81. def _get_target_folder(self) -> str:
  82. return "images_background" if self.background else "images_evaluation"