import contextlib
import functools
import importlib
import inspect
import itertools
import os
import pathlib
import platform
import random
import shutil
import string
import struct
import tarfile
import unittest
import unittest.mock
import zipfile
from collections import defaultdict
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union

import numpy as np
import PIL
import PIL.Image
import pytest
import torch
import torchvision.datasets
import torchvision.io
from common_utils import disable_console_output, get_tmp_dir
from torch.utils._pytree import tree_any
from torchvision.transforms.functional import get_dimensions

__all__ = [
    "UsageError",
    "lazy_importer",
    "test_all_configs",
    "DatasetTestCase",
    "ImageDatasetTestCase",
    "VideoDatasetTestCase",
    "create_image_or_video_tensor",
    "create_image_file",
    "create_image_folder",
    "create_video_file",
    "create_video_folder",
    "make_tar",
    "make_zip",
    "create_random_string",
]


class UsageError(Exception):
    """Should be raised in case an error happens in the setup rather than the test."""


class LazyImporter:
    r"""Lazy importer for additional dependencies.

    Some datasets require additional packages that are not direct dependencies of torchvision. Instances of this
    class provide modules listed in MODULES as attributes. They are only imported when accessed.
    """

    MODULES = (
        "av",
        "lmdb",
        "pycocotools",
        "requests",
        "scipy.io",
        "scipy.sparse",
        "h5py",
    )

    def __init__(self):
        modules = defaultdict(list)
        for module in self.MODULES:
            module, *submodules = module.split(".", 1)
            if submodules:
                modules[module].append(submodules[0])
            else:
                # This introduces the module so that it is known when we later iterate over the dictionary.
                modules.__missing__(module)

        for module, submodules in modules.items():
            # We need the quirky 'module=module' and 'submodules=submodules' arguments to the lambda since otherwise
            # the lookup for these would happen at runtime rather than at definition. Thus, without it, every
            # property would try to import the last item in 'modules'.
            setattr(
                type(self),
                module,
                property(lambda self, module=module, submodules=submodules: LazyImporter._import(module, submodules)),
            )

    @staticmethod
    def _import(package, subpackages):
        try:
            module = importlib.import_module(package)
        except ImportError as error:
            raise UsageError(
                f"Failed to import module '{package}'. "
                f"This probably means that the current test case needs '{package}' installed, "
                f"but it is not a dependency of torchvision. "
                f"You need to install it manually, for example 'pip install {package}'."
            ) from error
        for name in subpackages:
            importlib.import_module(f".{name}", package=package)
        return module


lazy_importer = LazyImporter()
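

# A minimal usage sketch: attribute access on ``lazy_importer`` performs the import lazily and raises a
# ``UsageError`` with an install hint if the package is missing. The helper below is purely illustrative and is not
# used by the tests.
def _example_lazy_import():
    av = lazy_importer.av  # imported on first access, not at module load
    return av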


def requires_lazy_imports(*modules):
    def outer_wrapper(fn):
        @functools.wraps(fn)
        def inner_wrapper(*args, **kwargs):
            for module in modules:
                getattr(lazy_importer, module.replace(".", "_"))
            return fn(*args, **kwargs)

        return inner_wrapper

    return outer_wrapper
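

# A minimal usage sketch (the function name is hypothetical): the decorator resolves the required module on
# ``lazy_importer`` before the body runs, so a missing package surfaces as a ``UsageError`` instead of a bare
# ``ImportError`` deep inside the call.
@requires_lazy_imports("av")
def _example_open_container(path):
    return lazy_importer.av.open(path)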


def test_all_configs(test):
    """Decorator to run a test against all configurations.

    Add this as a decorator to an arbitrary test to run it against all configurations. This includes
    :attr:`DatasetTestCase.DEFAULT_CONFIG` and :attr:`DatasetTestCase.ADDITIONAL_CONFIGS`.

    The current configuration is provided as the first parameter for the test:

    .. code-block::

        @test_all_configs
        def test_foo(self, config):
            pass

    .. note::

        This will try to remove duplicate configurations. During this process it will not preserve a potential
        ordering of the configurations or an inner ordering of a configuration.
    """

    def maybe_remove_duplicates(configs):
        try:
            return [dict(config_) for config_ in {tuple(sorted(config.items())) for config in configs}]
        except TypeError:
            # A TypeError will be raised if a value of any config is not hashable, e.g. a list. In that case
            # duplicate removal would be a lot more elaborate, and we simply bail out.
            return configs

    @functools.wraps(test)
    def wrapper(self):
        configs = []
        if self.DEFAULT_CONFIG is not None:
            configs.append(self.DEFAULT_CONFIG)
        if self.ADDITIONAL_CONFIGS is not None:
            configs.extend(self.ADDITIONAL_CONFIGS)

        if not configs:
            configs = [self._KWARG_DEFAULTS.copy()]
        else:
            configs = maybe_remove_duplicates(configs)

        for config in configs:
            with self.subTest(**config):
                test(self, config)

    return wrapper


class DatasetTestCase(unittest.TestCase):
    """Abstract base class for all dataset testcases.

    You have to overwrite the following class attributes:

    - DATASET_CLASS (torchvision.datasets.VisionDataset): Class of dataset to be tested.
    - FEATURE_TYPES (Sequence[Any]): Types of the elements returned by index access of the dataset. Instead of
      providing these manually, you can instead subclass ``ImageDatasetTestCase`` or ``VideoDatasetTestCase`` to
      get a reasonable default that should work for most cases. Each entry of the sequence may be a tuple,
      to indicate multiple possible values.

    Optionally, you can overwrite the following class attributes:

    - DEFAULT_CONFIG (Dict[str, Any]): Config that will be used by default. If omitted, this defaults to all
      keyword arguments of the dataset minus ``transform``, ``target_transform``, ``transforms``, and
      ``download``. Overwrite this if you want to use a default value for a parameter for which the dataset does
      not provide one.
    - ADDITIONAL_CONFIGS (Sequence[Dict[str, Any]]): Additional configs that should be tested. Each dictionary can
      contain an arbitrary combination of dataset parameters that are **not** ``transform``, ``target_transform``,
      ``transforms``, or ``download``.
    - REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
      available, the tests are skipped.

    Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests
    rely on. The fake data should resemble the original data as closely as necessary, while containing only a few
    examples. During the creation of the dataset the check-, download-, and extract-functions from
    ``torchvision.datasets.utils`` are disabled.

    Without further configuration, the testcase will test if

    1. the dataset raises a :class:`FileNotFoundError` or a :class:`RuntimeError` if the data files are not found
       or corrupted,
    2. the dataset inherits from `torchvision.datasets.VisionDataset`,
    3. the dataset can be turned into a string,
    4. the feature types of a returned example match ``FEATURE_TYPES``,
    5. the number of examples matches the injected fake data, and
    6. the dataset calls ``transform``, ``target_transform``, or ``transforms`` if available when accessing data.

    Cases 3. to 6. are tested against all configurations, i.e. ``DEFAULT_CONFIG`` and ``ADDITIONAL_CONFIGS``.

    To add dataset-specific tests, create a new method that takes no arguments with ``test_`` as a name prefix:

    .. code-block::

        def test_foo(self):
            pass

    If you want to run the test against all configs, add the ``@test_all_configs`` decorator to the definition and
    accept a single argument:

    .. code-block::

        @test_all_configs
        def test_bar(self, config):
            pass

    Within the test you can use the ``create_dataset()`` method that yields the dataset as well as additional
    information provided by the ``inject_fake_data()`` method:

    .. code-block::

        def test_baz(self):
            with self.create_dataset() as (dataset, info):
                pass
    """

    DATASET_CLASS = None
    FEATURE_TYPES = None

    DEFAULT_CONFIG = None
    ADDITIONAL_CONFIGS = None
    REQUIRED_PACKAGES = None

    # These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
    _TRANSFORM_KWARGS = {
        "transform",
        "target_transform",
        "transforms",
    }
    # These keyword arguments get a 'special' treatment and should not be set in DEFAULT_CONFIG or
    # ADDITIONAL_CONFIGS.
    _SPECIAL_KWARGS = {
        *_TRANSFORM_KWARGS,
        "download",
    }

    # These fields are populated during setUpClass() within _populate_private_class_attributes().

    # This will be a dictionary containing all keyword arguments with their respective default values extracted
    # from the dataset constructor.
    _KWARG_DEFAULTS = None
    # This will be a set of all _SPECIAL_KWARGS that the dataset constructor takes.
    _HAS_SPECIAL_KWARG = None

    # These functions are disabled during dataset creation in create_dataset().
    _CHECK_FUNCTIONS = {
        "check_md5",
        "check_integrity",
    }
    _DOWNLOAD_EXTRACT_FUNCTIONS = {
        "download_url",
        "download_file_from_google_drive",
        "extract_archive",
        "download_and_extract_archive",
    }

    def dataset_args(self, tmpdir: str, config: Dict[str, Any]) -> Sequence[Any]:
        """Define positional arguments passed to the dataset.

        .. note::

            The default behavior is only valid if the dataset to be tested has ``root`` as the only required
            parameter. Otherwise, you need to overwrite this method.

        Args:
            tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
                to be created and in turn also for the fake data injected here.
            config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at
                least fields for all dataset parameters with default values.

        Returns:
            (Tuple[str]): ``tmpdir`` which corresponds to ``root`` for most datasets.
        """
        return (tmpdir,)

    def inject_fake_data(self, tmpdir: str, config: Dict[str, Any]) -> Union[int, Dict[str, Any]]:
        """Inject fake data for dataset into a temporary directory.

        During the creation of the dataset the download and extract logic is disabled. Thus, the fake data injected
        here needs to resemble the raw data, i.e. the state of the dataset directly after the files are downloaded
        and potentially extracted.

        Args:
            tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
                to be created and in turn also for the fake data injected here.
            config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at
                least fields for all dataset parameters with default values.

        Needs to return one of the following:

            1. (int): Number of examples in the dataset to be created, or
            2. (Dict[str, Any]): Additional information about the injected fake data. Must contain the field
               ``"num_examples"`` that corresponds to the number of examples in the dataset to be created.
        """
        raise NotImplementedError("You need to provide fake data in order for the tests to run.")

    @contextlib.contextmanager
    def create_dataset(
        self,
        config: Optional[Dict[str, Any]] = None,
        inject_fake_data: bool = True,
        patch_checks: Optional[bool] = None,
        **kwargs: Any,
    ) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
        r"""Create the dataset in a temporary directory.

        The configuration passed to the dataset is populated to contain at least all parameters with default
        values. For this the following order of precedence is used:

        1. Parameters in :attr:`kwargs`.
        2. Configuration in :attr:`config`.
        3. Configuration in :attr:`~DatasetTestCase.DEFAULT_CONFIG`.
        4. Default parameters of the dataset.

        Args:
            config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset.
            inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data`
                before creating the dataset.
            patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If
                omitted defaults to the same value as ``inject_fake_data``.
            **kwargs (Any): Additional parameters passed to the dataset. These parameters take precedence in case
                they overlap with ``config``.

        Yields:
            dataset (torchvision.datasets.VisionDataset): Dataset.
            info (Dict[str, Any]): Additional information about the injected fake data. See
                :meth:`.inject_fake_data` for details.
        """
        if patch_checks is None:
            patch_checks = inject_fake_data

        special_kwargs, other_kwargs = self._split_kwargs(kwargs)

        complete_config = self._KWARG_DEFAULTS.copy()
        if self.DEFAULT_CONFIG:
            complete_config.update(self.DEFAULT_CONFIG)
        if config:
            complete_config.update(config)
        if other_kwargs:
            complete_config.update(other_kwargs)

        if "download" in self._HAS_SPECIAL_KWARG and special_kwargs.get("download", False):
            # Downloading is disabled during testing, so a truthy 'download' flag is forced to False before it is
            # passed to the dataset.
            special_kwargs["download"] = False

        patchers = self._patch_download_extract()
        if patch_checks:
            patchers.update(self._patch_checks())

        with get_tmp_dir() as tmpdir:
            args = self.dataset_args(tmpdir, complete_config)
            info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None

            with self._maybe_apply_patches(patchers), disable_console_output():
                dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs)

            yield dataset, info

    @classmethod
    def setUpClass(cls):
        cls._verify_required_public_class_attributes()
        cls._populate_private_class_attributes()
        cls._process_optional_public_class_attributes()
        super().setUpClass()

    @classmethod
    def _verify_required_public_class_attributes(cls):
        if cls.DATASET_CLASS is None:
            raise UsageError(
                "The class attribute 'DATASET_CLASS' needs to be overwritten. "
                "It should contain the class of the dataset to be tested."
            )
        if cls.FEATURE_TYPES is None:
            raise UsageError(
                "The class attribute 'FEATURE_TYPES' needs to be overwritten. "
                "It should contain a sequence of types that the dataset returns when accessed by index."
            )

    @classmethod
    def _populate_private_class_attributes(cls):
        defaults = []
        for cls_ in cls.DATASET_CLASS.__mro__:
            if cls_ is torchvision.datasets.VisionDataset:
                break

            argspec = inspect.getfullargspec(cls_.__init__)

            if not argspec.defaults:
                continue

            defaults.append(
                {
                    kwarg: default
                    for kwarg, default in zip(argspec.args[-len(argspec.defaults) :], argspec.defaults)
                    if not kwarg.startswith("_")
                }
            )

            if not argspec.varkw:
                break

        kwarg_defaults = dict()
        for config in reversed(defaults):
            kwarg_defaults.update(config)

        has_special_kwargs = set()
        for name in cls._SPECIAL_KWARGS:
            if name not in kwarg_defaults:
                continue

            del kwarg_defaults[name]
            has_special_kwargs.add(name)

        cls._KWARG_DEFAULTS = kwarg_defaults
        cls._HAS_SPECIAL_KWARG = has_special_kwargs

    @classmethod
    def _process_optional_public_class_attributes(cls):
        def check_config(config, name):
            special_kwargs = tuple(f"'{kwarg}'" for kwarg in cls._SPECIAL_KWARGS if kwarg in config)
            if special_kwargs:
                raise UsageError(
                    f"{name} contains a value for the parameter(s) {', '.join(special_kwargs)}. "
                    f"These are handled separately by the test case and should not be set here. "
                    f"If you need to test some custom behavior regarding these parameters, "
                    f"you need to write a custom test (*not* test case), e.g. test_custom_transform()."
                )

        if cls.DEFAULT_CONFIG is not None:
            check_config(cls.DEFAULT_CONFIG, "DEFAULT_CONFIG")

        if cls.ADDITIONAL_CONFIGS is not None:
            for idx, config in enumerate(cls.ADDITIONAL_CONFIGS):
                check_config(config, f"ADDITIONAL_CONFIGS[{idx}]")

        if cls.REQUIRED_PACKAGES:
            missing_pkgs = []
            for pkg in cls.REQUIRED_PACKAGES:
                try:
                    importlib.import_module(pkg)
                except ImportError:
                    missing_pkgs.append(f"'{pkg}'")

            if missing_pkgs:
                raise unittest.SkipTest(
                    f"The package(s) {', '.join(missing_pkgs)} are required to load the dataset "
                    f"'{cls.DATASET_CLASS.__name__}', but are not installed."
                )

    def _split_kwargs(self, kwargs):
        special_kwargs = kwargs.copy()
        other_kwargs = {key: special_kwargs.pop(key) for key in set(special_kwargs.keys()) - self._SPECIAL_KWARGS}
        return special_kwargs, other_kwargs

    def _inject_fake_data(self, tmpdir, config):
        info = self.inject_fake_data(tmpdir, config)
        if info is None:
            raise UsageError(
                "The method 'inject_fake_data' needs to return at least an integer indicating the number of "
                "examples for the current configuration."
            )
        elif isinstance(info, int):
            info = dict(num_examples=info)
        elif not isinstance(info, dict):
            raise UsageError(
                f"The additional information returned by the method 'inject_fake_data' must be either an "
                f"integer indicating the number of examples for the current configuration or a dictionary with "
                f"the same content. Got {type(info)} instead."
            )
        elif "num_examples" not in info:
            raise UsageError(
                "The information dictionary returned by the method 'inject_fake_data' must contain a "
                "'num_examples' field that holds the number of examples for the current configuration."
            )
        return info

    def _patch_download_extract(self):
        module = inspect.getmodule(self.DATASET_CLASS).__name__
        return {unittest.mock.patch(f"{module}.{function}") for function in self._DOWNLOAD_EXTRACT_FUNCTIONS}

    def _patch_checks(self):
        module = inspect.getmodule(self.DATASET_CLASS).__name__
        return {unittest.mock.patch(f"{module}.{function}", return_value=True) for function in self._CHECK_FUNCTIONS}

    @contextlib.contextmanager
    def _maybe_apply_patches(self, patchers):
        with contextlib.ExitStack() as stack:
            mocks = {}
            for patcher in patchers:
                # Not every dataset module imports every patched function; ignore the patchers whose target does
                # not exist there.
                with contextlib.suppress(AttributeError):
                    mocks[patcher.target] = stack.enter_context(patcher)
            yield mocks

    def test_not_found_or_corrupted(self):
        with pytest.raises((FileNotFoundError, RuntimeError)):
            with self.create_dataset(inject_fake_data=False):
                pass

    def test_smoke(self):
        with self.create_dataset() as (dataset, _):
            assert isinstance(dataset, torchvision.datasets.VisionDataset)

    @test_all_configs
    def test_str_smoke(self, config):
        with self.create_dataset(config) as (dataset, _):
            assert isinstance(str(dataset), str)

    @test_all_configs
    def test_feature_types(self, config):
        with self.create_dataset(config) as (dataset, _):
            example = dataset[0]

            if len(self.FEATURE_TYPES) > 1:
                actual = len(example)
                expected = len(self.FEATURE_TYPES)
                assert actual == expected, (
                    "The number of the returned features does not match the number of elements in FEATURE_TYPES: "
                    f"{actual} != {expected}"
                )
            else:
                example = (example,)

            for idx, (feature, expected_feature_type) in enumerate(zip(example, self.FEATURE_TYPES)):
                with self.subTest(idx=idx):
                    assert isinstance(feature, expected_feature_type)

    @test_all_configs
    def test_num_examples(self, config):
        with self.create_dataset(config) as (dataset, info):
            assert len(list(dataset)) == len(dataset) == info["num_examples"]

    @test_all_configs
    def test_transforms(self, config):
        mock = unittest.mock.Mock(wraps=lambda *args: args[0] if len(args) == 1 else args)
        for kwarg in self._TRANSFORM_KWARGS:
            if kwarg not in self._HAS_SPECIAL_KWARG:
                continue

            mock.reset_mock()

            with self.subTest(kwarg=kwarg):
                with self.create_dataset(config, **{kwarg: mock}) as (dataset, _):
                    dataset[0]

                mock.assert_called()

    @test_all_configs
    def test_transforms_v2_wrapper(self, config):
        from torchvision import tv_tensors
        from torchvision.datasets import wrap_dataset_for_transforms_v2

        try:
            with self.create_dataset(config) as (dataset, info):
                for target_keys in [None, "all"]:
                    if target_keys is not None and self.DATASET_CLASS not in {
                        torchvision.datasets.CocoDetection,
                        torchvision.datasets.VOCDetection,
                        torchvision.datasets.Kitti,
                        torchvision.datasets.WIDERFace,
                    }:
                        with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"):
                            wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
                        continue

                    wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
                    assert isinstance(wrapped_dataset, self.DATASET_CLASS)
                    assert len(wrapped_dataset) == info["num_examples"]

                    wrapped_sample = wrapped_dataset[0]
                    assert tree_any(
                        lambda item: isinstance(item, (tv_tensors.TVTensor, PIL.Image.Image)), wrapped_sample
                    )
        except TypeError as error:
            msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
            if str(error).startswith(msg):
                pytest.skip(msg)
            raise error
        except RuntimeError as error:
            if "currently not supported by this wrapper" in str(error):
                pytest.skip("Config is currently not supported by this wrapper")
            raise error
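

# A minimal sketch of a concrete test case. The dataset, layout, and numbers below are hypothetical; real
# subclasses live in the test modules that import this file. The sketch is kept in a comment so that test
# collection does not pick it up here.
#
#   class FakeThingsTestCase(ImageDatasetTestCase):
#       DATASET_CLASS = torchvision.datasets.ImageFolder  # stands in for the dataset under test
#
#       def inject_fake_data(self, tmpdir, config):
#           num_examples_per_class = 2
#           for cls in ("cat", "dog"):
#               create_image_folder(tmpdir, cls, lambda idx, cls=cls: f"{cls}_{idx}.png", num_examples_per_class)
#           return 2 * num_examples_per_class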


class ImageDatasetTestCase(DatasetTestCase):
    """Abstract base class for image dataset testcases.

    - Overwrites the FEATURE_TYPES class attribute to expect a :class:`PIL.Image.Image` and an integer label.
    """

    FEATURE_TYPES = (PIL.Image.Image, int)

    @contextlib.contextmanager
    def create_dataset(
        self,
        config: Optional[Dict[str, Any]] = None,
        inject_fake_data: bool = True,
        patch_checks: Optional[bool] = None,
        **kwargs: Any,
    ) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
        with super().create_dataset(
            config=config,
            inject_fake_data=inject_fake_data,
            patch_checks=patch_checks,
            **kwargs,
        ) as (dataset, info):
            # PIL.Image.open() only loads the image metadata upfront and keeps the file open until the first access
            # to the pixel data occurs. Trying to delete such a file results in a PermissionError on Windows. Thus,
            # we force-load opened images.
            # This problem only occurs during testing since some tests, e.g. DatasetTestCase.test_feature_types,
            # open an image but never use the underlying data. During normal operation it is reasonable to assume
            # that users want to work with the image they just opened rather than deleting the underlying file.
            with self._force_load_images():
                yield dataset, info

    @contextlib.contextmanager
    def _force_load_images(self):
        open = PIL.Image.open

        def new(fp, *args, **kwargs):
            image = open(fp, *args, **kwargs)
            if isinstance(fp, (str, pathlib.Path)):
                image.load()
            return image

        with unittest.mock.patch("PIL.Image.open", new=new):
            yield


class VideoDatasetTestCase(DatasetTestCase):
    """Abstract base class for video dataset testcases.

    - Overwrites the ``FEATURE_TYPES`` class attribute to expect two :class:`torch.Tensor` s for the video and
      audio as well as an integer label.
    - Overwrites the ``REQUIRED_PACKAGES`` class attribute to require PyAV (``av``).
    - Adds the ``FRAMES_PER_CLIP`` class attribute. If ``dataset_args()`` provides no value for
      ``frames_per_clip`` and it is the last parameter without a default value in the dataset constructor, the
      value of the ``FRAMES_PER_CLIP`` class attribute is appended to the output.
    """

    FEATURE_TYPES = (torch.Tensor, torch.Tensor, int)
    REQUIRED_PACKAGES = ("av",)

    FRAMES_PER_CLIP = 1

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset_args = self._set_default_frames_per_clip(self.dataset_args)

    def _set_default_frames_per_clip(self, dataset_args):
        argspec = inspect.getfullargspec(self.DATASET_CLASS.__init__)
        args_without_default = argspec.args[1 : (-len(argspec.defaults) if argspec.defaults else None)]
        frames_per_clip_last = args_without_default[-1] == "frames_per_clip"

        @functools.wraps(dataset_args)
        def wrapper(tmpdir, config):
            args = dataset_args(tmpdir, config)
            if frames_per_clip_last and len(args) == len(args_without_default) - 1:
                args = (*args, self.FRAMES_PER_CLIP)

            return args

        return wrapper

    def test_output_format(self):
        for output_format in ["TCHW", "THWC"]:
            with self.create_dataset(output_format=output_format) as (dataset, _):
                for video, *_ in dataset:
                    if output_format == "TCHW":
                        num_frames, num_channels, *_ = video.shape
                    else:  # output_format == "THWC"
                        num_frames, *_, num_channels = video.shape

                    assert num_frames == self.FRAMES_PER_CLIP
                    assert num_channels == 3

    @test_all_configs
    def test_transforms_v2_wrapper(self, config):
        # `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the `config` if it is set
        # explicitly or use the supported `"TCHW"`.
        if config.setdefault("output_format", "TCHW") == "THWC":
            return

        super().test_transforms_v2_wrapper.__wrapped__(self, config)


def _no_collate(batch):
    return batch


def check_transforms_v2_wrapper_spawn(dataset):
    # On Linux, the DataLoader forks worker processes by default. Forking is not available on macOS, so new worker
    # processes are spawned there instead. Spawning requires the whole pipeline, including the dataset, to be
    # picklable, which is what we are enforcing here.
    if platform.system() != "Darwin":
        pytest.skip("Multiprocessing spawning is only checked on macOS.")

    from torch.utils.data import DataLoader
    from torchvision import tv_tensors
    from torchvision.datasets import wrap_dataset_for_transforms_v2

    wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)

    dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)

    for wrapped_sample in dataloader:
        assert tree_any(
            lambda item: isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)), wrapped_sample
        )


def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
    r"""Create a random uint8 tensor.

    Args:
        size (Sequence[int]): Size of the tensor.
    """
    return torch.randint(0, 256, size, dtype=torch.uint8)


def create_image_file(
    root: Union[pathlib.Path, str], name: Union[pathlib.Path, str], size: Union[Sequence[int], int] = 10, **kwargs: Any
) -> pathlib.Path:
    """Create an image file from random data.

    Args:
        root (Union[str, pathlib.Path]): Root directory the image file will be placed in.
        name (Union[str, pathlib.Path]): Name of the image file.
        size (Union[Sequence[int], int]): Size of the image in ``(num_channels, height, width)`` format. If scalar,
            the value is used for the height and width. If no channel dimension is provided, three channels are
            assumed.
        kwargs (Any): Additional parameters passed to :meth:`PIL.Image.Image.save`.

    Returns:
        pathlib.Path: Path to the created image file.
    """
    if isinstance(size, int):
        size = (size, size)
    if len(size) == 2:
        size = (3, *size)
    if len(size) != 3:
        raise UsageError(
            f"The 'size' argument should either be an int or a sequence of length 2 or 3. Got {len(size)} instead"
        )

    image = create_image_or_video_tensor(size)
    file = pathlib.Path(root) / name

    # torch (num_channels, height, width) -> PIL / numpy (height, width, num_channels)
    image = image.permute(1, 2, 0)
    # For grayscale images PIL doesn't use a channel dimension.
    if image.shape[2] == 1:
        image = torch.squeeze(image, 2)
    PIL.Image.fromarray(image.numpy()).save(file, **kwargs)
    return file


def create_image_folder(
    root: Union[pathlib.Path, str],
    name: Union[pathlib.Path, str],
    file_name_fn: Callable[[int], str],
    num_examples: int,
    size: Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]] = None,
    **kwargs: Any,
) -> List[pathlib.Path]:
    """Create a folder of random images.

    Args:
        root (Union[str, pathlib.Path]): Root directory the image folder will be placed in.
        name (Union[str, pathlib.Path]): Name of the image folder.
        file_name_fn (Callable[[int], str]): Should return a file name if called with the file index.
        num_examples (int): Number of images to create.
        size (Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]]): Size of the images.
            If callable, will be called with the index of the corresponding file. If omitted, a random height and
            width between 3 and 10 pixels is selected on a per-image basis.
        kwargs (Any): Additional parameters passed to :func:`create_image_file`.

    Returns:
        List[pathlib.Path]: Paths to all created image files.

    .. seealso::

        - :func:`create_image_file`
    """
    if size is None:

        def size(idx: int) -> Tuple[int, int, int]:
            num_channels = 3
            height, width = torch.randint(3, 11, size=(2,), dtype=torch.int).tolist()
            return (num_channels, height, width)

    root = pathlib.Path(root) / name
    os.makedirs(root, exist_ok=True)

    return [
        create_image_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, **kwargs)
        for idx in range(num_examples)
    ]
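

# A minimal usage sketch (names are hypothetical): create an ImageFolder-style class directory with three
# fixed-size PNG images and return their paths. The image format is inferred by PIL from the file suffix.
def _example_image_folder(tmpdir):
    return create_image_folder(
        tmpdir,
        "cat",
        file_name_fn=lambda idx: f"cat_{idx}.png",
        num_examples=3,
        size=(3, 16, 16),
    )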


def shape_test_for_stereo(
    left: PIL.Image.Image,
    right: PIL.Image.Image,
    disparity: Optional[np.ndarray] = None,
    valid_mask: Optional[np.ndarray] = None,
):
    left_dims = get_dimensions(left)
    right_dims = get_dimensions(right)
    c, h, w = left_dims
    # check that left and right are the same size
    assert left_dims == right_dims
    assert c == 3

    # check that the disparity has the same spatial dimensions as the input
    if disparity is not None:
        assert disparity.ndim == 3
        assert disparity.shape == (1, h, w)

    if valid_mask is not None:
        # check that the valid mask is the same size as the disparity
        _, dh, dw = disparity.shape
        mh, mw = valid_mask.shape
        assert dh == mh
        assert dw == mw


@requires_lazy_imports("av")
def create_video_file(
    root: Union[pathlib.Path, str],
    name: Union[pathlib.Path, str],
    size: Union[Sequence[int], int] = (1, 3, 10, 10),
    fps: float = 25,
    **kwargs: Any,
) -> pathlib.Path:
    """Create a video file from random data.

    Args:
        root (Union[str, pathlib.Path]): Root directory the video file will be placed in.
        name (Union[str, pathlib.Path]): Name of the video file.
        size (Union[Sequence[int], int]): Size of the video in ``(num_frames, num_channels, height, width)``
            format. If scalar, the value is used for the height and width. If the leading dimensions are missing,
            ``num_frames=1`` and ``num_channels=3`` are assumed.
        fps (float): Frame rate in frames per second.
        kwargs (Any): Additional parameters passed to :func:`torchvision.io.write_video`.

    Returns:
        pathlib.Path: Path to the created video file.

    Raises:
        UsageError: If PyAV is not available.
    """
    if isinstance(size, int):
        size = (size, size)
    if len(size) == 2:
        size = (3, *size)
    if len(size) == 3:
        size = (1, *size)
    if len(size) != 4:
        raise UsageError(
            f"The 'size' argument should either be an int or a sequence of length 2, 3, or 4. Got {len(size)} instead"
        )

    video = create_image_or_video_tensor(size)
    file = pathlib.Path(root) / name
    # torchvision.io.write_video() expects the video in (num_frames, height, width, num_channels) format.
    torchvision.io.write_video(str(file), video.permute(0, 2, 3, 1), fps, **kwargs)
    return file


@requires_lazy_imports("av")
def create_video_folder(
    root: Union[str, pathlib.Path],
    name: Union[str, pathlib.Path],
    file_name_fn: Callable[[int], str],
    num_examples: int,
    size: Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]] = None,
    fps=25,
    **kwargs,
) -> List[pathlib.Path]:
    """Create a folder of random videos.

    Args:
        root (Union[str, pathlib.Path]): Root directory the video folder will be placed in.
        name (Union[str, pathlib.Path]): Name of the video folder.
        file_name_fn (Callable[[int], str]): Should return a file name if called with the file index.
        num_examples (int): Number of videos to create.
        size (Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]]): Size of the videos.
            If callable, will be called with the index of the corresponding file. If omitted, a random even height
            and width between 4 and 10 pixels is selected on a per-video basis.
        fps (float): Frame rate in frames per second.
        kwargs (Any): Additional parameters passed to :func:`create_video_file`.

    Returns:
        List[pathlib.Path]: Paths to all created video files.

    Raises:
        UsageError: If PyAV is not available.

    .. seealso::

        - :func:`create_video_file`
    """
    if size is None:

        def size(idx):
            num_frames = 1
            num_channels = 3
            # The 'libx264' video codec, which is the default of torchvision.io.write_video, requires the height
            # and width of the video to be divisible by 2.
            height, width = (torch.randint(2, 6, size=(2,), dtype=torch.int) * 2).tolist()
            return (num_frames, num_channels, height, width)

    root = pathlib.Path(root) / name
    os.makedirs(root, exist_ok=True)

    return [
        create_video_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, fps=fps, **kwargs)
        for idx in range(num_examples)
    ]
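

# A minimal usage sketch (names are hypothetical): requires PyAV. Creates two 4-frame clips with even spatial
# dimensions, as required by the default 'libx264' codec.
def _example_video_folder(tmpdir):
    return create_video_folder(
        tmpdir,
        "clips",
        file_name_fn=lambda idx: f"clip_{idx}.mp4",
        num_examples=2,
        size=(4, 3, 8, 8),
    )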


def _split_files_or_dirs(root, *files_or_dirs):
    files = set()
    dirs = set()
    for file_or_dir in files_or_dirs:
        path = pathlib.Path(file_or_dir)
        if not path.is_absolute():
            path = root / path
        if path.is_file():
            files.add(path)
        else:
            dirs.add(path)
            for sub_file_or_dir in path.glob("**/*"):
                if sub_file_or_dir.is_file():
                    files.add(sub_file_or_dir)
                else:
                    dirs.add(sub_file_or_dir)

    if root in dirs:
        dirs.remove(root)

    return files, dirs


def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
    # Normalize to a Path so that joining and membership checks below work for str inputs as well.
    root = pathlib.Path(root)
    archive = root / name
    if not files_or_dirs:
        # We need to invoke `Path.with_suffix("")` in a loop, since the call only applies to the last suffix if
        # multiple suffixes are present. For example, `pathlib.Path("foo.tar.gz").with_suffix("")` results in
        # `foo.tar`.
        file_or_dir = archive
        for _ in range(len(archive.suffixes)):
            file_or_dir = file_or_dir.with_suffix("")
        if file_or_dir.exists():
            files_or_dirs = (file_or_dir,)
        else:
            raise ValueError("No file or dir provided.")

    files, dirs = _split_files_or_dirs(root, *files_or_dirs)

    with opener(archive) as fh:
        for file in sorted(files):
            adder(fh, file, file.relative_to(root))

    if remove:
        for file in files:
            os.remove(file)
        for dir in dirs:
            shutil.rmtree(dir, ignore_errors=True)

    return archive


def make_tar(root, name, *files_or_dirs, remove=True, compression=None):
    # TODO: detect compression from name
    return _make_archive(
        root,
        name,
        *files_or_dirs,
        opener=lambda archive: tarfile.open(archive, f"w:{compression}" if compression else "w"),
        adder=lambda fh, file, relative_file: fh.add(file, arcname=relative_file),
        remove=remove,
    )


def make_zip(root, name, *files_or_dirs, remove=True):
    return _make_archive(
        root,
        name,
        *files_or_dirs,
        opener=lambda archive: zipfile.ZipFile(archive, "w"),
        adder=lambda fh, file, relative_file: fh.write(file, arcname=relative_file),
        remove=remove,
    )
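

# A minimal usage sketch (paths are hypothetical): bundle previously injected fake files into archives. By default
# the source files are removed afterwards, so the archive is all that a dataset's extraction logic can rely on.
def _example_archives(tmpdir):
    make_tar(tmpdir, "images.tar.gz", "images", compression="gz")
    make_zip(tmpdir, "annotations.zip", "annotations")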


def create_random_string(length: int, *digits: str) -> str:
    """Create a random string.

    Args:
        length (int): Number of characters in the generated string.
        *digits (str): Characters to sample from. If omitted, defaults to :attr:`string.ascii_lowercase`.
    """
    if not digits:
        digits = string.ascii_lowercase
    else:
        digits = "".join(itertools.chain(*digits))

    return "".join(random.choice(digits) for _ in range(length))
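

# A minimal usage sketch: an 8-character lowercase file stem and a 4-character numeric id. The outputs shown are
# just examples; the function is random.
#
#   create_random_string(8)                 -> e.g. 'qzjwhtds'
#   create_random_string(4, string.digits)  -> e.g. '0473'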


def make_fake_pfm_file(h, w, file_name):
    values = list(range(3 * h * w))
    # Everything is packed in little endian: the negative scale (-1.0) in the PFM header marks the data as little
    # endian, matching the "<" in the struct format.
    content = f"PF \n{w} {h} \n-1.0\n".encode() + struct.pack("<" + "f" * len(values), *values)
    with open(file_name, "wb") as f:
        f.write(content)


def make_fake_flo_file(h, w, file_name):
    """Creates a fake flow file in .flo format."""
    # Everything needs to be in little endian according to
    # https://vision.middlebury.edu/flow/code/flow-code/README.txt
    values = list(range(2 * h * w))
    content = (
        struct.pack("<4c", *(c.encode() for c in "PIEH"))
        + struct.pack("<i", w)
        + struct.pack("<i", h)
        + struct.pack("<" + "f" * len(values), *values)
    )
    with open(file_name, "wb") as f:
        f.write(content)
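

# A minimal read-back sketch (the helper is hypothetical) showing the .flo layout produced above: a 4-byte "PIEH"
# magic followed by the width and height as little-endian int32 values.
def _example_read_flo_header(file_name):
    with open(file_name, "rb") as f:
        magic = f.read(4)  # b"PIEH"
        w, h = struct.unpack("<2i", f.read(8))
    return magic, w, h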