# utils.py (1.7 KB)

  1. import copy
  2. import warnings
  3. from torch.utils.data.datapipes.datapipe import IterDataPipe
  4. __all__ = ["IterableWrapperIterDataPipe", ]
  5. class IterableWrapperIterDataPipe(IterDataPipe):
  6. r"""
  7. Wraps an iterable object to create an IterDataPipe.
  8. Args:
  9. iterable: Iterable object to be wrapped into an IterDataPipe
  10. deepcopy: Option to deepcopy input iterable object for each
  11. iterator. The copy is made when the first element is read in ``iter()``.
  12. .. note::
  13. If ``deepcopy`` is explicitly set to ``False``, users should ensure
  14. that the data pipeline doesn't contain any in-place operations over
  15. the iterable instance to prevent data inconsistency across iterations.
  16. Example:
  17. >>> # xdoctest: +SKIP
  18. >>> from torchdata.datapipes.iter import IterableWrapper
  19. >>> dp = IterableWrapper(range(10))
  20. >>> list(dp)
  21. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
  22. """
  23. def __init__(self, iterable, deepcopy=True):
  24. self.iterable = iterable
  25. self.deepcopy = deepcopy
  26. def __iter__(self):
  27. source_data = self.iterable
  28. if self.deepcopy:
  29. try:
  30. source_data = copy.deepcopy(self.iterable)
  31. # For the case that data cannot be deep-copied,
  32. # all in-place operations will affect iterable variable.
  33. # When this DataPipe is iterated second time, it will
  34. # yield modified items.
  35. except TypeError:
  36. warnings.warn(
  37. "The input iterable can not be deepcopied, "
  38. "please be aware of in-place modification would affect source data."
  39. )
  40. yield from source_data
  41. def __len__(self):
  42. return len(self.iterable)