utils.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import copy
  2. import warnings
  3. from torch.utils.data.datapipes.datapipe import MapDataPipe
  # Names exported when this module is star-imported.
  __all__ = [
      "SequenceWrapperMapDataPipe",
  ]
  class SequenceWrapperMapDataPipe(MapDataPipe):
      r"""
      Wraps a sequence object into a MapDataPipe.

      Args:
          sequence: Sequence object to be wrapped into a MapDataPipe
          deepcopy: Option to deepcopy input sequence object

      .. note::
          If ``deepcopy`` is set to False explicitly, users should ensure
          that the data pipeline doesn't contain any in-place operations over
          the wrapped instance, in order to prevent data inconsistency
          across iterations.

      .. note::
          If ``deepcopy`` is True but ``copy.deepcopy`` raises ``TypeError``,
          a warning is issued and the original object is stored by reference.

      Example:
          >>> # xdoctest: +SKIP
          >>> from torchdata.datapipes.map import SequenceWrapper
          >>> dp = SequenceWrapper(range(10))
          >>> list(dp)
          [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
          >>> dp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
          >>> dp['a']
          100
      """

      def __init__(self, sequence, deepcopy: bool = True) -> None:
          """Store ``sequence``, deep-copying it first when ``deepcopy`` is true.

          Args:
              sequence: Object supporting ``[]`` indexing and ``len()``.
              deepcopy: When true, try to store an independent deep copy.

          Warns:
              UserWarning: If ``copy.deepcopy(sequence)`` raises ``TypeError``;
                  the original object is then stored as-is.
          """
          if deepcopy:
              try:
                  self.sequence = copy.deepcopy(sequence)
              except TypeError:
                  # stacklevel=2 attributes the warning to the code that
                  # constructed the datapipe rather than to this line.
                  warnings.warn(
                      "The input sequence can not be deepcopied, "
                      "please be aware of in-place modification would affect source data",
                      stacklevel=2,
                  )
                  self.sequence = sequence
          else:
              self.sequence = sequence

      def __getitem__(self, index):
          """Return ``self.sequence[index]``; lookup errors propagate unchanged."""
          return self.sequence[index]

      def __len__(self) -> int:
          """Return the length of the wrapped object."""
          return len(self.sequence)