123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218 |
- import os
- import shutil
- import tempfile
- from contextlib import contextmanager
- from typing import Any, Dict, Iterator, List, Optional, Tuple
- import torch
- from .folder import ImageFolder
- from .utils import check_integrity, extract_archive, verify_str_arg
# For each part of the ImageNet 2012 download: (archive file name, expected MD5
# checksum). The parse_* helpers in this module look entries up by these keys.
ARCHIVE_META = dict(
    train=("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
    val=("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
    devkit=("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
)

# File name (relative to the dataset root) of the cached meta information.
META_FILE = "meta.bin"
class ImageNet(ImageFolder):
    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.

    .. note::
        The ImageNet 2012 archives are not downloaded by this class. Fetch them from
        `here <https://image-net.org/challenges/LSVRC/2012/2012-downloads.php>`_ and
        put ``ILSVRC2012_devkit_t12.tar.gz`` together with ``ILSVRC2012_img_train.tar``
        or ``ILSVRC2012_img_val.tar`` (depending on ``split``) into the root directory.

    Args:
        root (string): Root directory of the ImageNet Dataset.
        split (string, optional): The dataset split, either ``train`` or ``val``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        classes (list): List of the class name tuples.
        class_to_idx (dict): Dict with items (class_name, class_index).
        wnids (list): List of the WordNet IDs.
        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
        imgs (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(self, root: str, split: str = "train", **kwargs: Any) -> None:
        expanded_root = os.path.expanduser(root)
        # ``parse_archives`` and ``split_folder`` read ``self.root`` and ``self.split``,
        # so both have to exist before the parent constructor is called.
        self.root = expanded_root
        self.split = verify_str_arg(split, "split", ("train", "val"))

        self.parse_archives()
        wnid_to_classes = load_meta_file(self.root)[0]

        super().__init__(self.split_folder, **kwargs)
        # The parent was handed the split folder as its root and presumably stored it;
        # point ``root`` back at the dataset root.
        self.root = expanded_root

        # After the parent constructor, ``classes`` / ``class_to_idx`` are keyed by
        # folder name, i.e. by WordNet ID. Keep those under wnid-specific names ...
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx

        # ... and re-key the public attributes by the class names from the meta file.
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
        name_to_idx = {}
        for class_idx, names in enumerate(self.classes):
            for name in names:
                name_to_idx[name] = class_idx
        self.class_to_idx = name_to_idx

    def parse_archives(self) -> None:
        """Create the meta file and the split folder from the archives if missing."""
        meta_path = os.path.join(self.root, META_FILE)
        if not check_integrity(meta_path):
            parse_devkit_archive(self.root)

        if os.path.isdir(self.split_folder):
            # The split folder already exists; nothing to extract.
            return

        if self.split == "train":
            parse_train_archive(self.root)
        elif self.split == "val":
            parse_val_archive(self.root)

    @property
    def split_folder(self) -> str:
        """Path of the folder that holds the images of the selected split."""
        return os.path.join(self.root, self.split)

    def extra_repr(self) -> str:
        """Return a one-line description of the selected split."""
        return "Split: {split}".format_map(self.__dict__)
def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str], List[str]]:
    """Load the cached meta information of the ImageNet dataset.

    Args:
        root (str): Directory that contains the meta file.
        file (str, optional): Name of the meta file inside ``root``. Defaults to
            ``META_FILE``.

    Returns:
        The object stored in the meta file. As written by ``parse_devkit_archive``
        this is a ``(wnid_to_classes, val_wnids)`` tuple.

    Raises:
        RuntimeError: If the meta file does not pass ``check_integrity``.
    """
    meta_path = os.path.join(root, META_FILE if file is None else file)

    if not check_integrity(meta_path):
        template = (
            "The meta file {} is not present in the root directory or is corrupted. "
            "This file is automatically created by the ImageNet dataset."
        )
        raise RuntimeError(template.format(meta_path))

    # NOTE(review): torch.load unpickles the file; only point this at meta files
    # that were produced locally by this module.
    return torch.load(meta_path)
def _verify_archive(root: str, file: str, md5: str) -> None:
    """Raise unless ``file`` inside ``root`` passes the integrity check for ``md5``.

    Args:
        root (str): Directory expected to contain the archive.
        file (str): Name of the archive inside ``root``.
        md5 (str): Checksum handed to ``check_integrity`` together with the path.

    Raises:
        RuntimeError: If ``check_integrity`` rejects the archive.
    """
    archive_path = os.path.join(root, file)
    if check_integrity(archive_path, md5):
        return

    template = (
        "The archive {} is not present in the root directory or is corrupted. "
        "You need to download it externally and place it in {}."
    )
    raise RuntimeError(template.format(file, root))
def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
    """Parse the devkit archive of the ImageNet2012 classification dataset and save
    the meta information in a binary file.

    The archive is extracted into a temporary directory, and the tuple
    ``(wnid_to_classes, val_wnids)`` is stored as ``META_FILE`` inside ``root``.

    Args:
        root (str): Root directory containing the devkit archive
        file (str, optional): Name of devkit archive. Defaults to
            'ILSVRC2012_devkit_t12.tar.gz'
    """
    import scipy.io as sio

    def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, Tuple[str, ...]]]:
        # Read the synset records of the devkit and keep only those whose fifth
        # field (the number of children) is zero.
        mat_path = os.path.join(devkit_root, "data", "meta.mat")
        synsets = sio.loadmat(mat_path, squeeze_me=True)["synsets"]
        child_counts = list(zip(*synsets))[4]
        leaves = [synsets[pos] for pos, count in enumerate(child_counts) if count == 0]

        # The first three fields of every record: index, WordNet ID, class names.
        idcs, wnids, names = list(zip(*leaves))[:3]
        # Class names come as one comma-separated string per synset.
        name_tuples = [tuple(entry.split(", ")) for entry in names]
        idx_to_wnid = dict(zip(idcs, wnids))
        wnid_to_classes = dict(zip(wnids, name_tuples))
        return idx_to_wnid, wnid_to_classes

    def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
        # One integer synset index per line of the ground-truth file.
        truth_path = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt")
        with open(truth_path) as txtfh:
            lines = txtfh.readlines()
        return [int(line) for line in lines]

    @contextmanager
    def get_tmp_dir() -> Iterator[str]:
        # Scratch directory that is removed again even if parsing fails.
        scratch = tempfile.mkdtemp()
        try:
            yield scratch
        finally:
            shutil.rmtree(scratch)

    default_file, md5 = ARCHIVE_META["devkit"]
    if file is None:
        file = default_file

    # The expected checksum is always the one of the default devkit archive.
    _verify_archive(root, file, md5)

    with get_tmp_dir() as tmp_dir:
        extract_archive(os.path.join(root, file), tmp_dir)

        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
        idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
        # Translate the per-image synset indices of the validation set to WordNet IDs.
        val_wnids = [idx_to_wnid[idx] for idx in parse_val_groundtruth_txt(devkit_root)]

        torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
def parse_train_archive(root: str, file: Optional[str] = None, folder: str = "train") -> None:
    """Parse the train images archive of the ImageNet2012 classification dataset and
    prepare it for usage with the ImageNet dataset.

    The outer archive is extracted into ``root/folder``. Every entry found there
    afterwards is itself extracted into a directory named after the entry without
    its extension.

    Args:
        root (str): Root directory containing the train images archive
        file (str, optional): Name of train images archive. Defaults to
            'ILSVRC2012_img_train.tar'
        folder (str, optional): Optional name for train images folder. Defaults to
            'train'
    """
    default_file, md5 = ARCHIVE_META["train"]
    archive_name = default_file if file is None else file

    # The expected checksum is always the one of the default train archive.
    _verify_archive(root, archive_name, md5)

    train_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, archive_name), train_root)

    # The directory listing is taken once, before any inner archive is unpacked.
    for entry in os.listdir(train_root):
        inner_archive = os.path.join(train_root, entry)
        target_dir = os.path.splitext(inner_archive)[0]
        # remove_finished=True presumably deletes the inner archive after extraction.
        extract_archive(inner_archive, target_dir, remove_finished=True)
def parse_val_archive(
    root: str, file: Optional[str] = None, wnids: Optional[List[str]] = None, folder: str = "val"
) -> None:
    """Parse the validation images archive of the ImageNet2012 classification dataset
    and prepare it for usage with the ImageNet dataset.

    The archive is extracted into ``root/folder``. The extracted files, taken in
    sorted order, are paired with ``wnids`` and each file is moved into a
    sub-directory named after its WordNet ID.

    Args:
        root (str): Root directory containing the validation images archive
        file (str, optional): Name of validation images archive. Defaults to
            'ILSVRC2012_img_val.tar'
        wnids (list, optional): List of WordNet IDs of the validation images. If None
            is given, the IDs are loaded from the meta file in the root directory
        folder (str, optional): Optional name for validation images folder. Defaults to
            'val'
    """
    default_file, md5 = ARCHIVE_META["val"]
    if file is None:
        file = default_file
    if wnids is None:
        # Second element of the cached meta tuple: one WordNet ID per validation image.
        wnids = load_meta_file(root)[1]

    # The expected checksum is always the one of the default validation archive.
    _verify_archive(root, file, md5)

    val_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), val_root)

    # List the images before the class folders are created so that only the
    # extracted files take part in the pairing below.
    image_names = sorted(os.listdir(val_root))

    for wnid in set(wnids):
        os.mkdir(os.path.join(val_root, wnid))

    for wnid, image_name in zip(wnids, image_names):
        source = os.path.join(val_root, image_name)
        shutil.move(source, os.path.join(val_root, wnid, image_name))
|