fetch.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch
  2. data from an iterable-style or map-style dataset. This logic is shared in both
  3. single- and multi-processing data loading.
  4. """
  5. class _BaseDatasetFetcher:
  6. def __init__(self, dataset, auto_collation, collate_fn, drop_last):
  7. self.dataset = dataset
  8. self.auto_collation = auto_collation
  9. self.collate_fn = collate_fn
  10. self.drop_last = drop_last
  11. def fetch(self, possibly_batched_index):
  12. raise NotImplementedError()
  13. class _IterableDatasetFetcher(_BaseDatasetFetcher):
  14. def __init__(self, dataset, auto_collation, collate_fn, drop_last):
  15. super().__init__(dataset, auto_collation, collate_fn, drop_last)
  16. self.dataset_iter = iter(dataset)
  17. self.ended = False
  18. def fetch(self, possibly_batched_index):
  19. if self.ended:
  20. raise StopIteration
  21. if self.auto_collation:
  22. data = []
  23. for _ in possibly_batched_index:
  24. try:
  25. data.append(next(self.dataset_iter))
  26. except StopIteration:
  27. self.ended = True
  28. break
  29. if len(data) == 0 or (
  30. self.drop_last and len(data) < len(possibly_batched_index)
  31. ):
  32. raise StopIteration
  33. else:
  34. data = next(self.dataset_iter)
  35. return self.collate_fn(data)
  36. class _MapDatasetFetcher(_BaseDatasetFetcher):
  37. def fetch(self, possibly_batched_index):
  38. if self.auto_collation:
  39. if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
  40. data = self.dataset.__getitems__(possibly_batched_index)
  41. else:
  42. data = [self.dataset[idx] for idx in possibly_batched_index]
  43. else:
  44. data = self.dataset[possibly_batched_index]
  45. return self.collate_fn(data)