123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272 |
- import torch
- from torch import Tensor
- from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union
# Names exported by ``from <this module> import *``; kept alphabetical.
__all__ = [
    "BatchSampler",
    "RandomSampler",
    "Sampler",
    "SequentialSampler",
    "SubsetRandomSampler",
    "WeightedRandomSampler",
]
T_co = TypeVar("T_co", covariant=True)


class Sampler(Generic[T_co]):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source: Optional[Sized] = None) -> None:
        # The base class stores nothing and ignores ``data_source``. It defaults
        # to ``None`` so subclasses may call ``super().__init__()`` without one.
        pass

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

    # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    #
    # Many times we have an abstract class representing a collection/iterable of
    # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
    # implementing a `__len__` method. In such cases, we must make sure to not
    # provide a default implementation, because both straightforward default
    # implementations have their issues:
    #
    #   + `return NotImplemented`:
    #     Calling `len(subclass_instance)` raises:
    #       TypeError: 'NotImplementedType' object cannot be interpreted as an integer
    #
    #   + `raise NotImplementedError()`:
    #     This prevents triggering some fallback behavior. E.g., the built-in
    #     `list(X)` tries to call `len(X)` first, and executes a different code
    #     path if the method is not found or `NotImplemented` is returned, while
    #     raising a `NotImplementedError` will propagate and make the call
    #     fail where it could have used `__iter__` to complete the call.
    #
    # Thus, the only two sensible things to do are
    #
    #   + **not** provide a default `__len__`.
    #
    #   + raise a `TypeError` instead, which is what Python uses when users call
    #     a method that is not defined on an object.
    #     (@ssnl verifies that this works on at least Python 3.7.)
class SequentialSampler(Sampler[int]):
    r"""Yields the indices ``0, 1, ..., len(data_source) - 1`` in order.

    The order never changes between epochs.

    Args:
        data_source (Dataset): dataset to sample from
    """

    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __len__(self) -> int:
        return len(self.data_source)

    def __iter__(self) -> Iterator[int]:
        # The size is read when iteration starts, so a dataset that changed
        # size since construction is reflected here.
        total = len(self.data_source)
        return iter(range(total))
class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.

    If with replacement, then user can specify :attr:`num_samples` to draw.
    Without replacement, a :attr:`num_samples` larger than the dataset is
    served by concatenating successive independent shuffles of the dataset
    (the last one truncated).

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.

    Raises:
        TypeError: if ``replacement`` is not a ``bool``.
        ValueError: if the resolved ``num_samples`` is not a positive ``int``
            (``bool`` values are rejected).
    """

    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError("replacement should be a boolean value, but got "
                            "replacement={}".format(self.replacement))

        resolved = self.num_samples
        # ``bool`` is a subclass of ``int``: without the explicit check,
        # ``num_samples=True`` would silently mean 1. WeightedRandomSampler and
        # BatchSampler reject bools the same way.
        if not isinstance(resolved, int) or isinstance(resolved, bool) or resolved <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(resolved))

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime, so the default is recomputed
        # on every access instead of being cached in __init__
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            # Derive a fresh seed from torch's default RNG for this iteration.
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        # Resolve once so the number of yielded indices stays consistent for
        # the whole pass.
        num_samples = self.num_samples

        if self.replacement:
            # Indices are drawn in chunks of 32 (presumably to avoid
            # materializing every sample up front).
            for _ in range(num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            for _ in range(num_samples // n):
                yield from torch.randperm(n, generator=generator).tolist()
            # Slice the tensor before converting so only the needed tail is
            # turned into Python ints; the drawn permutation is unchanged.
            yield from torch.randperm(n, generator=generator)[:num_samples % n].tolist()

    def __len__(self) -> int:
        return self.num_samples
class SubsetRandomSampler(Sampler[int]):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Each pass yields every element of ``indices`` exactly once, in an order
    given by a random permutation.

    Args:
        indices (sequence): a sequence of indices
        generator (Generator): Generator used in sampling.
    """

    indices: Sequence[int]

    def __init__(self, indices: Sequence[int], generator=None) -> None:
        self.indices = indices
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        # Convert the permutation to Python ints in one call. Iterating the
        # tensor directly would create a 0-d tensor per element and pass those
        # tensors to ``indices.__getitem__``, which is slow and rejected by
        # sequences that require real ``int`` positions.
        for i in torch.randperm(len(self.indices), generator=self.generator).tolist():
            yield self.indices[i]

    def __len__(self) -> int:
        return len(self.indices)
class WeightedRandomSampler(Sampler[int]):
    r"""Draws indices from ``[0,..,len(weights)-1]`` with probability proportional to ``weights``.

    Args:
        weights (sequence) : a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that once
            an index has been drawn it cannot be drawn again in the same pass.
        generator (Generator): Generator used in sampling.

    Raises:
        ValueError: if ``num_samples`` is not a positive ``int`` (``bool`` is
            rejected), if ``replacement`` is not a ``bool``, or if ``weights``
            is not one-dimensional.

    Example:
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    weights: Tensor
    num_samples: int
    replacement: bool

    def __init__(self, weights: Sequence[float], num_samples: int,
                 replacement: bool = True, generator=None) -> None:
        # ``bool`` is an ``int`` subclass, hence the explicit exclusion.
        count_is_valid = (isinstance(num_samples, int)
                          and not isinstance(num_samples, bool)
                          and num_samples > 0)
        if not count_is_valid:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))

        as_double = torch.as_tensor(weights, dtype=torch.double)
        if as_double.dim() != 1:
            raise ValueError("weights should be a 1d sequence but given "
                             "weights have shape {}".format(tuple(as_double.shape)))

        self.weights = as_double
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        # One multinomial draw per pass; its validity checks (e.g. on the
        # weights) therefore surface here rather than in __init__.
        picks = torch.multinomial(self.weights, self.num_samples,
                                  self.replacement, generator=self.generator)
        yield from picks.tolist()

    def __len__(self) -> int:
        return self.num_samples
class BatchSampler(Sampler[List[int]]):
    r"""Groups the indices produced by another sampler into mini-batches.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Raises:
        ValueError: if ``batch_size`` is not a positive ``int`` (``bool`` is
            rejected) or ``drop_last`` is not a ``bool``.

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
        # ``sampler`` is deliberately not checked against
        # ``collections.abc.Iterable``: that ABC does not recognise objects
        # that are iterable only through ``__getitem__``.
        size_ok = (isinstance(batch_size, int)
                   and not isinstance(batch_size, bool)
                   and batch_size > 0)
        if not size_ok:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        # Both strategies follow the benchmarking in https://github.com/pytorch/pytorch/pull/76951
        if self.drop_last:
            source = iter(self.sampler)
            while True:
                try:
                    # An exhausted sampler raises StopIteration part-way
                    # through the comprehension, so a short tail is discarded.
                    yield [next(source) for _ in range(self.batch_size)]
                except StopIteration:
                    return
        else:
            # Fill a preallocated list in place and hand it out when full.
            pending = [0] * self.batch_size
            filled = 0
            for idx in self.sampler:
                pending[filled] = idx
                filled += 1
                if filled < self.batch_size:
                    continue
                yield pending
                pending = [0] * self.batch_size
                filled = 0
            if filled:
                yield pending[:filled]

    def __len__(self) -> int:
        # Requires ``self.sampler`` to implement ``__len__``; that cannot be
        # enforced for arbitrary iterables, hence the type-check suppression.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        total = len(self.sampler)  # type: ignore[arg-type]
        if self.drop_last:
            return total // self.batch_size
        # Ceiling division: a trailing partial batch still counts.
        return -(-total // self.batch_size)
|