eurosat.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import os
  2. from typing import Callable, Optional
  3. from .folder import ImageFolder
  4. from .utils import download_and_extract_archive
  5. class EuroSAT(ImageFolder):
  6. """RGB version of the `EuroSAT <https://github.com/phelber/eurosat>`_ Dataset.
  7. Args:
  8. root (string): Root directory of dataset where ``root/eurosat`` exists.
  9. transform (callable, optional): A function/transform that takes in an PIL image
  10. and returns a transformed version. E.g, ``transforms.RandomCrop``
  11. target_transform (callable, optional): A function/transform that takes in the
  12. target and transforms it.
  13. download (bool, optional): If True, downloads the dataset from the internet and
  14. puts it in root directory. If dataset is already downloaded, it is not
  15. downloaded again. Default is False.
  16. """
  17. def __init__(
  18. self,
  19. root: str,
  20. transform: Optional[Callable] = None,
  21. target_transform: Optional[Callable] = None,
  22. download: bool = False,
  23. ) -> None:
  24. self.root = os.path.expanduser(root)
  25. self._base_folder = os.path.join(self.root, "eurosat")
  26. self._data_folder = os.path.join(self._base_folder, "2750")
  27. if download:
  28. self.download()
  29. if not self._check_exists():
  30. raise RuntimeError("Dataset not found. You can use download=True to download it")
  31. super().__init__(self._data_folder, transform=transform, target_transform=target_transform)
  32. self.root = os.path.expanduser(root)
  33. def __len__(self) -> int:
  34. return len(self.samples)
  35. def _check_exists(self) -> bool:
  36. return os.path.exists(self._data_folder)
  37. def download(self) -> None:
  38. if self._check_exists():
  39. return
  40. os.makedirs(self._base_folder, exist_ok=True)
  41. download_and_extract_archive(
  42. "https://madm.dfki.de/files/sentinel/EuroSAT.zip",
  43. download_root=self._base_folder,
  44. md5="c8fa014336c82ac7804f0398fcb19387",
  45. )