import os
from os.path import abspath, expanduser
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from PIL import Image

from .utils import download_and_extract_archive, download_file_from_google_drive, extract_archive, verify_str_arg
from .vision import VisionDataset


class WIDERFace(VisionDataset):
    """`WIDERFace <http://shuoyang1213.me/WIDERFACE/>`_ Dataset.

    Args:
        root (string): Root directory where images and annotations are downloaded to.
            Expects the following folder structure if download=False:

            .. code::

                <root>
                    └── widerface
                        ├── wider_face_split ('wider_face_split.zip' if compressed)
                        ├── WIDER_train ('WIDER_train.zip' if compressed)
                        ├── WIDER_val ('WIDER_val.zip' if compressed)
                        └── WIDER_test ('WIDER_test.zip' if compressed)
        split (string): The dataset split to use. One of {``train``, ``val``, ``test``}.
            Defaults to ``train``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version, e.g., ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is
            not downloaded again.
    """

    BASE_FOLDER = "widerface"
    FILE_LIST = [
        # File ID                             MD5 Hash                            Filename
        ("15hGDLhsx8bLgLcIRD5DhYt5iBxnjNF1M", "3fedf70df600953d25982bcd13d91ba2", "WIDER_train.zip"),
        ("1GUCogbp16PMGa39thoMMeWxp7Rp5oM8Q", "dfa7d7e790efa35df3788964cf0bbaea", "WIDER_val.zip"),
        ("1HIfDbVEWKmsYKJZm4lchTBDLW5N7dY5T", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip"),
    ]
    ANNOTATIONS_FILE = (
        "http://shuoyang1213.me/WIDERFACE/support/bbx_annotation/wider_face_split.zip",
        "0e3767bcf0e326556d407bf5bff5d27c",
        "wider_face_split.zip",
    )

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(
            root=os.path.join(root, self.BASE_FOLDER), transform=transform, target_transform=target_transform
        )
        # check arguments
        self.split = verify_str_arg(split, "split", ("train", "val", "test"))

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download and prepare it")

        self.img_info: List[Dict[str, Union[str, Dict[str, torch.Tensor]]]] = []
        if self.split in ("train", "val"):
            self.parse_train_val_annotations_file()
        else:
            self.parse_test_annotations_file()

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dict of annotations for all faces in the image.
            ``target`` is ``None`` for the test split.
        """
        # stay consistent with other datasets and return a PIL Image
        img = Image.open(self.img_info[index]["img_path"])

        if self.transform is not None:
            img = self.transform(img)

        target = None if self.split == "test" else self.img_info[index]["annotations"]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.img_info)

    def extra_repr(self) -> str:
        lines = ["Split: {split}"]
        return "\n".join(lines).format(**self.__dict__)

    def parse_train_val_annotations_file(self) -> None:
        filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt"
        filepath = os.path.join(self.root, "wider_face_split", filename)
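
        # Layout of the bbx_gt annotation file, as implied by the parser below: for each
        # image there is a path line, a line with the number of boxes, and then that many
        # lines of 10 space-separated integers (x, y, w, h followed by the attribute
        # fields split out below). The sample values here are illustrative only, e.g.:
        #
        #   0--Parade/0_Parade_marchingband_1_849.jpg
        #   1
        #   449 330 122 149 0 0 0 0 0 0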
        with open(filepath) as f:
            lines = f.readlines()
            file_name_line, num_boxes_line, box_annotation_line = True, False, False
            num_boxes, box_counter = 0, 0
            labels = []
            for line in lines:
                line = line.rstrip()
                if file_name_line:
                    img_path = os.path.join(self.root, "WIDER_" + self.split, "images", line)
                    img_path = abspath(expanduser(img_path))
                    file_name_line = False
                    num_boxes_line = True
                elif num_boxes_line:
                    num_boxes = int(line)
                    num_boxes_line = False
                    box_annotation_line = True
                elif box_annotation_line:
                    box_counter += 1
                    line_split = line.split(" ")
                    line_values = [int(x) for x in line_split]
                    labels.append(line_values)
                    if box_counter >= num_boxes:
                        box_annotation_line = False
                        file_name_line = True
                        labels_tensor = torch.tensor(labels)
                        self.img_info.append(
                            {
                                "img_path": img_path,
                                "annotations": {
                                    "bbox": labels_tensor[:, 0:4].clone(),  # x, y, width, height
                                    "blur": labels_tensor[:, 4].clone(),
                                    "expression": labels_tensor[:, 5].clone(),
                                    "illumination": labels_tensor[:, 6].clone(),
                                    "occlusion": labels_tensor[:, 7].clone(),
                                    "pose": labels_tensor[:, 8].clone(),
                                    "invalid": labels_tensor[:, 9].clone(),
                                },
                            }
                        )
                        box_counter = 0
                        labels.clear()
                else:
                    raise RuntimeError(f"Error parsing annotation file {filepath}")

    def parse_test_annotations_file(self) -> None:
        filepath = os.path.join(self.root, "wider_face_split", "wider_face_test_filelist.txt")
        filepath = abspath(expanduser(filepath))
        with open(filepath) as f:
            lines = f.readlines()
            for line in lines:
                line = line.rstrip()
                img_path = os.path.join(self.root, "WIDER_test", "images", line)
                img_path = abspath(expanduser(img_path))
                self.img_info.append({"img_path": img_path})

    def _check_integrity(self) -> bool:
        # Allow original archive to be deleted (zip). Only need the extracted images
        all_files = self.FILE_LIST.copy()
        all_files.append(self.ANNOTATIONS_FILE)
        for (_, md5, filename) in all_files:
            file, ext = os.path.splitext(filename)
            extracted_dir = os.path.join(self.root, file)
            if not os.path.exists(extracted_dir):
                return False
        return True

    def download(self) -> None:
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        # download and extract image data
        for (file_id, md5, filename) in self.FILE_LIST:
            download_file_from_google_drive(file_id, self.root, filename, md5)
            filepath = os.path.join(self.root, filename)
            extract_archive(filepath)

        # download and extract annotation files
        download_and_extract_archive(
            url=self.ANNOTATIONS_FILE[0], download_root=self.root, md5=self.ANNOTATIONS_FILE[1]
        )
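

# The block below is a minimal usage sketch, not part of the dataset class above. It
# assumes the WIDERFace archives are already downloaded and extracted under
# "<root>/widerface" (or that download=True succeeds); the "data" root is illustrative.
if __name__ == "__main__":
    from torchvision import datasets

    dataset = datasets.WIDERFace(root="data", split="val", download=False)
    print(len(dataset))  # number of images in the split

    img, target = dataset[0]  # PIL image and a dict of per-face annotation tensors
    print(img.size)
    print(target["bbox"].shape)  # (num_faces, 4), boxes given as x, y, width, height

    # target is None for the "test" split, since its annotations are not distributed.
    # Images contain varying numbers of faces, so batching targets with a DataLoader
    # requires a custom collate_fn.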