streamreader.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. from typing import Tuple
  2. from torch.utils.data.datapipes._decorator import functional_datapipe
  3. from torch.utils.data.datapipes.datapipe import IterDataPipe
  4. __all__ = ["StreamReaderIterDataPipe", ]
  5. @functional_datapipe('read_from_stream')
  6. class StreamReaderIterDataPipe(IterDataPipe[Tuple[str, bytes]]):
  7. r"""
  8. Given IO streams and their label names, yields bytes with label
  9. name in a tuple (functional name: ``read_from_stream``).
  10. Args:
  11. datapipe: Iterable DataPipe provides label/URL and byte stream
  12. chunk: Number of bytes to be read from stream per iteration.
  13. If ``None``, all bytes will be read until the EOF.
  14. Example:
  15. >>> # xdoctest: +SKIP
  16. >>> from torchdata.datapipes.iter import IterableWrapper, StreamReader
  17. >>> from io import StringIO
  18. >>> dp = IterableWrapper([("alphabet", StringIO("abcde"))])
  19. >>> list(StreamReader(dp, chunk=1))
  20. [('alphabet', 'a'), ('alphabet', 'b'), ('alphabet', 'c'), ('alphabet', 'd'), ('alphabet', 'e')]
  21. """
  22. def __init__(self, datapipe, chunk=None):
  23. self.datapipe = datapipe
  24. self.chunk = chunk
  25. def __iter__(self):
  26. for furl, stream in self.datapipe:
  27. while True:
  28. d = stream.read(self.chunk)
  29. if not d:
  30. stream.close()
  31. break
  32. yield (furl, d)