# dataset.py

import bisect
import warnings
import math
from typing import (
    Generic,
    Iterable,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union
)

# No 'default_generator' in torch/__init__.pyi
from torch import default_generator, randperm
from torch._utils import _accumulate

from ... import Generator, Tensor

__all__ = [
    "Dataset",
    "IterableDataset",
    "TensorDataset",
    "ConcatDataset",
    "ChainDataset",
    "Subset",
    "random_split",
]

T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')


class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
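
# A minimal map-style dataset sketch, shown for illustration only (the class
# name `SquaresDataset` is made up and not part of this module). It overrides
# `__getitem__` and `__len__` as the `Dataset` docstring describes:
#
#     class SquaresDataset(Dataset[int]):
#         def __init__(self, n: int) -> None:
#             self.n = n
#
#         def __getitem__(self, index: int) -> int:
#             # fetch the sample for a given integral key
#             return index * index
#
#         def __len__(self) -> int:
#             # size reported to samplers and the default DataLoader options
#             return self.n
#
#     # SquaresDataset(4)[2] == 4, len(SquaresDataset(4)) == 4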


class IterableDataset(Dataset[T_co]):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    Such form of datasets is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> # xdoctest: +SKIP("Fails on MacOS12")
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

        >>> # xdoctest: +REQUIRES(POSIX)
        >>> # Multi-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

        >>> # With even more workers
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 3, 4, 4, 5, 5, 6, 6]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...

        >>> # Multi-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
        [3, 4, 5, 6]
    """

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

    def __add__(self, other: Dataset[T_co]):
        return ChainDataset([self, other])

    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]


class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
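
# Usage sketch for TensorDataset, shown for illustration only (the tensors
# below are made up for the example). Indexing returns a tuple with one slice
# per wrapped tensor, taken along the first dimension:
#
#     import torch
#     features = torch.randn(100, 3)
#     labels = torch.randint(0, 2, (100,))
#     ds = TensorDataset(features, labels)
#     x, y = ds[0]           # (features[0], labels[0])
#     assert len(ds) == 100  # size of the first dimension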


class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
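
# Usage sketch for ConcatDataset, shown for illustration only (`ds_a` and
# `ds_b` stand in for any two map-style datasets). A global index is mapped to
# the right child dataset via `bisect_right` over `cumulative_sizes`:
#
#     import torch
#     ds_a = TensorDataset(torch.arange(3))   # length 3
#     ds_b = TensorDataset(torch.arange(5))   # length 5
#     combined = ConcatDataset([ds_a, ds_b])  # cumulative_sizes == [3, 8]
#     assert len(combined) == 8
#     combined[4]  # index 4 is past cumulative size 3, so this reads ds_b[1]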


class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = datasets

    def __iter__(self):
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            for x in d:
                yield x

    def __len__(self):
        total = 0
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            total += len(d)  # type: ignore[arg-type]
        return total
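
# Usage sketch for ChainDataset, shown for illustration only (`StreamDataset`
# is a made-up IterableDataset). The chained datasets are iterated one after
# another, on-the-fly, without materializing them:
#
#     class StreamDataset(IterableDataset):
#         def __init__(self, start, end):
#             self.start, self.end = start, end
#
#         def __iter__(self):
#             return iter(range(self.start, self.end))
#
#     chained = ChainDataset([StreamDataset(0, 3), StreamDataset(3, 5)])
#     assert list(chained) == [0, 1, 2, 3, 4]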


class Subset(Dataset[T_co]):
    r"""
    Subset of a dataset at specified indices.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        if isinstance(idx, list):
            return self.dataset[[self.indices[i] for i in idx]]
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)
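
# Usage sketch for Subset, shown for illustration only. Indices into the
# subset are re-mapped through `self.indices`, and a list index fetches
# several samples at once:
#
#     import torch
#     base = TensorDataset(torch.arange(10))
#     sub = Subset(base, indices=[2, 5, 7])
#     assert len(sub) == 3
#     sub[0]       # -> base[2]
#     sub[[0, 2]]  # -> base[[2, 7]]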


def random_split(dataset: Dataset[T], lengths: Sequence[Union[int, float]],
                 generator: Optional[Generator] = default_generator) -> List[Subset[T]]:
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    If a list of fractions that sum up to 1 is given,
    the lengths will be computed automatically as
    floor(frac * len(dataset)) for each fraction provided.

    After computing the lengths, if there are any remainders, 1 count will be
    distributed in round-robin fashion to the lengths
    until there are no remainders left.

    Optionally fix the generator for reproducible results, e.g.:

    Example:
        >>> # xdoctest: +SKIP
        >>> generator1 = torch.Generator().manual_seed(42)
        >>> generator2 = torch.Generator().manual_seed(42)
        >>> random_split(range(10), [3, 7], generator=generator1)
        >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths or fractions of splits to be produced
        generator (Generator): Generator used for the random permutation.
    """
    if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
        subset_lengths: List[int] = []
        for i, frac in enumerate(lengths):
            if frac < 0 or frac > 1:
                raise ValueError(f"Fraction at index {i} is not between 0 and 1")
            n_items_in_split = int(
                math.floor(len(dataset) * frac)  # type: ignore[arg-type]
            )
            subset_lengths.append(n_items_in_split)
        remainder = len(dataset) - sum(subset_lengths)  # type: ignore[arg-type]
        # add 1 to all the lengths in round-robin fashion until the remainder is 0
        for i in range(remainder):
            idx_to_add_at = i % len(subset_lengths)
            subset_lengths[idx_to_add_at] += 1
        lengths = subset_lengths
        for i, length in enumerate(lengths):
            if length == 0:
                warnings.warn(f"Length of split at index {i} is 0. "
                              f"This might result in an empty dataset.")

    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):  # type: ignore[arg-type]
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[call-overload]
    return [Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)]
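
# Worked example of the fractional-lengths path above, shown for illustration
# only: with len(dataset) == 7 and lengths == [0.5, 0.5], the floors give
# [3, 3], leaving a remainder of 1; the round-robin loop adds it to the first
# split, so the final lengths are [4, 3]:
#
#     import torch
#     splits = random_split(range(7), [0.5, 0.5],
#                           generator=torch.Generator().manual_seed(0))
#     assert [len(s) for s in splits] == [4, 3]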