123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167 |
- import io
- import os.path
- import pickle
- import string
- from collections.abc import Iterable
- from typing import Any, Callable, cast, List, Optional, Tuple, Union
- from PIL import Image
- from .utils import iterable_to_str, verify_str_arg
- from .vision import VisionDataset
class LSUNClass(VisionDataset):
    """One LSUN category backed by a single LMDB database.

    Samples are returned as ``(image, target)``. This class carries no label of
    its own, so ``target`` is ``None`` unless ``target_transform`` maps it to
    something else.

    Args:
        root (string): Path of the LMDB database to open.
        transform (callable, optional): A function/transform applied to the
            decoded RGB PIL image.
        target_transform (callable, optional): A function/transform applied to
            the (``None``) target.
    """

    def __init__(
        self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None
    ) -> None:
        # Imported here rather than at module level so that importing this
        # module does not require ``lmdb`` to be installed.
        import lmdb

        super().__init__(root, transform=transform, target_transform=target_transform)

        self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = txn.stat()["entries"]

        # The list of keys is cached in a pickle file so that later runs can skip
        # the full cursor scan. The file name is relative (so it lives in the
        # current working directory) and is built only from the ASCII letters of
        # ``root``.
        # NOTE(review): two roots that differ only in digits/punctuation map to
        # the same cache file -- confirm whether that can happen in practice.
        cache_file = "_cache_" + "".join(c for c in root if c in string.ascii_letters)
        if os.path.isfile(cache_file):
            # NOTE(review): unpickling executes whatever the file contains; the
            # cache file must not be writable by untrusted parties.
            with open(cache_file, "rb") as fh:
                self.keys = pickle.load(fh)
        else:
            with self.env.begin(write=False) as txn:
                self.keys = list(txn.cursor().iternext(keys=True, values=False))
            # Close the handle explicitly so the cache is fully flushed to disk.
            with open(cache_file, "wb") as fh:
                pickle.dump(self.keys, fh)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Position of the sample in ``self.keys``.

        Returns:
            tuple: Tuple (image, target) where image is the RGB PIL image (after
            ``transform``) and target is ``None`` (after ``target_transform``).
        """
        target = None
        with self.env.begin(write=False) as txn:
            imgbuf = txn.get(self.keys[index])

        # Wrap the raw bytes in a file-like object so PIL can decode them.
        buf = io.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        img = Image.open(buf).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of entries reported by the LMDB database."""
        return self.length
class LSUN(VisionDataset):
    """`LSUN <https://www.yf.io/p/lsun>`_ dataset.

    You will need to install the ``lmdb`` package to use this dataset: run
    ``pip install lmdb``

    Args:
        root (string): Root directory for the database files.
        classes (string or list): One of {'train', 'val', 'test'} or a list of
            categories to load, e.g. ['bedroom_train', 'church_outdoor_train'].
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(
        self,
        root: str,
        classes: Union[str, List[str]] = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.classes = self._verify_classes(classes)

        # One LSUNClass per requested class, each reading ``<root>/<class>_lmdb``.
        # ``target_transform`` is deliberately not forwarded: it is applied to
        # the class index in ``LSUN.__getitem__``.
        self.dbs = [LSUNClass(root=os.path.join(root, f"{c}_lmdb"), transform=transform) for c in self.classes]

        # ``indices[i]`` is the cumulative number of samples in dbs[0..i], i.e.
        # the exclusive upper bound of the global indices served by ``dbs[i]``.
        self.indices = []
        count = 0
        for db in self.dbs:
            count += len(db)
            self.indices.append(count)

        self.length = count

    def _verify_classes(self, classes: Union[str, List[str]]) -> List[str]:
        """Validate ``classes`` and expand it into a list of ``<category>_<split>`` names.

        Args:
            classes (string or list): Either a split name ('train', 'val' or
                'test') or an iterable of names such as 'bedroom_train'.

        Returns:
            list: ``['test']`` for the 'test' split, every category suffixed with
            the split for 'train'/'val', or the validated input as a list.

        Raises:
            ValueError: If ``classes`` is neither a split name nor an iterable,
                or if an element fails validation of its category or postfix.
        """
        categories = [
            "bedroom",
            "bridge",
            "church_outdoor",
            "classroom",
            "conference_room",
            "dining_room",
            "kitchen",
            "living_room",
            "restaurant",
            "tower",
        ]
        dset_opts = ["train", "val", "test"]

        try:
            # Only this validation is expected to raise; a ValueError here means
            # ``classes`` is not one of the plain split names.
            split = cast(str, classes)
            verify_str_arg(split, "classes", dset_opts)
        except ValueError:
            if not isinstance(classes, Iterable):
                msg = "Expected type str or Iterable for argument classes, but got type {}."
                raise ValueError(msg.format(type(classes)))

            classes = list(classes)
            msg_fmtstr_type = "Expected type str for elements in argument classes, but got type {}."
            msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
            for c in classes:
                verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c)))
                # The split is the part after the last underscore; the category
                # itself may contain underscores (e.g. "church_outdoor").
                c_short = c.split("_")
                category, dset_opt = "_".join(c_short[:-1]), c_short[-1]

                msg = msg_fmtstr.format(category, "LSUN class", iterable_to_str(categories))
                verify_str_arg(category, valid_values=categories, custom_msg=msg)

                msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
                verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
        else:
            if split == "test":
                classes = [split]
            else:
                classes = [c + "_" + split for c in categories]

        return classes

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index. Negative values count from the end of the dataset.

        Returns:
            tuple: Tuple (image, target) where target is the index of the target category.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        requested = index
        if index < 0:
            # Resolve negative indices against the whole dataset; otherwise they
            # would all fall into the first database with target 0.
            index += self.length
        if not 0 <= index < self.length:
            raise IndexError(f"index {requested} is out of range for dataset of length {self.length}")

        # Find the first database whose cumulative upper bound exceeds ``index``.
        target = 0
        sub = 0
        for ind in self.indices:
            if index < ind:
                break
            target += 1
            sub = ind

        db = self.dbs[target]
        # Convert the global index into an index local to the chosen database.
        index = index - sub

        if self.target_transform is not None:
            target = self.target_transform(target)

        img, _ = db[index]
        return img, target

    def __len__(self) -> int:
        """Return the total number of samples across all loaded classes."""
        return self.length

    def extra_repr(self) -> str:
        """Return the class list as an extra line for the dataset's repr."""
        return f"Classes: {self.classes}"
|