123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189 |
import csv
import os
from collections import namedtuple
from typing import Any, Callable, List, Optional, Tuple, Union

import PIL
import PIL.Image
import torch

from .utils import check_integrity, download_file_from_google_drive, extract_archive, verify_str_arg
from .vision import VisionDataset
# Lightweight record produced by ``CelebA._load_csv``:
#   header -- the parsed header row (an empty list when the file has none)
#   index  -- the first column of every data row (the image filenames)
#   data   -- the remaining columns of every row, converted to an integer tensor
CSV = namedtuple("CSV", "header index data")
class CelebA(VisionDataset):
    """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        split (string): One of {'train', 'valid', 'test', 'all'}.
            Accordingly dataset is selected.
        target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
            or ``landmarks``. Can also be a list to output a tuple with all specified target types.
            The targets represent:

                - ``attr`` (Tensor shape=(40,) dtype=int): binary (0, 1) labels for attributes
                - ``identity`` (Tensor shape=() dtype=int): label for each person (data points with the
                  same identity are the same person)
                - ``bbox`` (Tensor shape=(4,) dtype=int): bounding box (x, y, width, height)
                - ``landmarks`` (Tensor shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
                  righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)

            Defaults to ``attr``. If empty, ``None`` will be returned as target.

        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If ``target_transform`` is given while ``target_type`` is empty, or if
            the dataset files are missing or fail their integrity check.
    """

    base_folder = "celeba"
    # There currently does not appear to be an easy way to extract 7z in python (without introducing additional
    # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
    # right now.
    file_list = [
        # File ID                                      MD5 Hash                            Filename
        ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
        # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc","b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
        # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
        ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
        ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
        ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
        ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
        # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
        ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
    ]

    def __init__(
        self,
        root: str,
        split: str = "train",
        target_type: Union[List[str], str] = "attr",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.split = split
        # Normalise to a list so __getitem__ can always iterate over target types.
        if isinstance(target_type, list):
            self.target_type = target_type
        else:
            self.target_type = [target_type]

        # With no target types the target is None, so a target_transform could never run.
        if not self.target_type and self.target_transform is not None:
            raise RuntimeError("target_transform is specified but target_type is empty")

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # Partition ids as stored in list_eval_partition.txt; None selects every row.
        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,
        }
        # verify_str_arg rejects unknown split names (presumably with ValueError -- see .utils).
        split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]
        splits = self._load_csv("list_eval_partition.txt")
        identity = self._load_csv("identity_CelebA.txt")
        bbox = self._load_csv("list_bbox_celeba.txt", header=1)
        landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1)
        attr = self._load_csv("list_attr_celeba.txt", header=1)

        if split_ is None:  # split == "all": keep every row
            mask = slice(None)
            self.filename = splits.index
        else:
            # Squeeze only the trailing column axis so the row axis is never collapsed.
            mask = (splits.data == split_).squeeze(-1)
            # flatten() keeps a 1-D index list even when exactly one row matches;
            # squeeze() would produce a 0-d tensor that cannot be iterated.
            selected_rows = torch.nonzero(mask).flatten().tolist()
            self.filename = [splits.index[i] for i in selected_rows]

        self.identity = identity.data[mask]
        self.bbox = bbox.data[mask]
        self.landmarks_align = landmarks_align.data[mask]
        self.attr = attr.data[mask]
        # map from {-1, 1} to {0, 1}
        self.attr = torch.div(self.attr + 1, 2, rounding_mode="floor")
        self.attr_names = attr.header

    def _load_csv(
        self,
        filename: str,
        header: Optional[int] = None,
    ) -> CSV:
        """Parse a space-separated annotation file under ``root/base_folder``.

        Args:
            filename: Name of the file inside ``root/base_folder``.
            header: Row number holding the column names. Rows before it are
                discarded and rows after it are treated as data. ``None`` means
                the file has no header and every row is data.

        Returns:
            A ``CSV`` record whose ``header`` is the header row (``[]`` if none),
            ``index`` is the first column of each data row, and ``data`` is the
            remaining columns converted to an integer tensor.
        """
        with open(os.path.join(self.root, self.base_folder, filename)) as csv_file:
            # skipinitialspace collapses the runs of spaces used for column alignment.
            data = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True))

        if header is not None:
            headers = data[header]
            data = data[header + 1 :]
        else:
            headers = []

        indices = [row[0] for row in data]
        data = [row[1:] for row in data]
        data_int = [list(map(int, i)) for i in data]

        return CSV(headers, indices, torch.tensor(data_int))

    def _check_integrity(self) -> bool:
        """Return True if the annotation files verify and the image folder exists.

        Every non-archive entry of ``file_list`` must pass ``check_integrity`` against
        its MD5. Archives (``.zip``/``.7z``) are skipped, so they may be deleted after
        extraction. Only the existence of the ``img_align_celeba`` directory is
        checked, not its contents.
        """
        for (_, md5, filename) in self.file_list:
            fpath = os.path.join(self.root, self.base_folder, filename)
            _, ext = os.path.splitext(filename)
            # Allow original archive to be deleted (zip and 7z)
            # Only need the extracted images
            if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
                return False

        # Should check a hash of the images
        return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))

    def download(self) -> None:
        """Download every file in ``file_list`` and extract the aligned-image archive.

        Prints a message and returns early when ``_check_integrity`` already passes.
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        for (file_id, md5, filename) in self.file_list:
            download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)

        extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for the sample at ``index``.

        The image is opened with ``PIL.Image.open`` and passed through ``transform``
        when one is set. The target is built from ``target_type``:

        * one type: that target alone;
        * several types: a tuple in the order given;
        * no types: ``None``.

        ``target_transform`` is applied only to a non-empty target.

        Raises:
            ValueError: If an entry of ``target_type`` is not recognised.
        """
        X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

        target: Any = []
        for t in self.target_type:
            if t == "attr":
                target.append(self.attr[index, :])
            elif t == "identity":
                target.append(self.identity[index, 0])
            elif t == "bbox":
                target.append(self.bbox[index, :])
            elif t == "landmarks":
                target.append(self.landmarks_align[index, :])
            else:
                # TODO: refactor with utils.verify_str_arg
                raise ValueError(f'Target type "{t}" is not recognized.')

        if self.transform is not None:
            X = self.transform(X)

        if target:
            target = tuple(target) if len(target) > 1 else target[0]

            if self.target_transform is not None:
                target = self.target_transform(target)
        else:
            target = None

        return X, target

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self.attr)

    def extra_repr(self) -> str:
        """Return the target type(s) and split, one per line, for the dataset repr."""
        lines = ["Target type: {target_type}", "Split: {split}"]
        return "\n".join(lines).format(**self.__dict__)
|