1234567891011121314151617181920212223242526272829303132333435363738 |
- from typing import Tuple
- from torch.utils.data.datapipes._decorator import functional_datapipe
- from torch.utils.data.datapipes.datapipe import IterDataPipe
- __all__ = ["StreamReaderIterDataPipe", ]
@functional_datapipe('read_from_stream')
class StreamReaderIterDataPipe(IterDataPipe[Tuple[str, bytes]]):
    r"""
    Given IO streams and their label names, yields chunks read from each
    stream together with its label name in a tuple
    (functional name: ``read_from_stream``).

    The type of each yielded chunk is whatever ``stream.read`` returns:
    ``bytes`` for binary streams, ``str`` for text streams such as
    ``io.StringIO`` (as in the example below).

    Every stream is closed once this DataPipe is done with it: after it has
    been read to EOF, when ``read`` raises, or when iteration is abandoned
    before EOF (the generator is closed or garbage-collected).

    Args:
        datapipe: Iterable DataPipe provides label/URL and byte stream
        chunk: Number of bytes to be read from stream per iteration.
            If ``None``, all bytes will be read until the EOF.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper, StreamReader
        >>> from io import StringIO
        >>> dp = IterableWrapper([("alphabet", StringIO("abcde"))])
        >>> list(StreamReader(dp, chunk=1))
        [('alphabet', 'a'), ('alphabet', 'b'), ('alphabet', 'c'), ('alphabet', 'd'), ('alphabet', 'e')]
    """

    def __init__(self, datapipe, chunk=None):
        self.datapipe = datapipe
        self.chunk = chunk

    def __iter__(self):
        for furl, stream in self.datapipe:
            # ``finally`` guarantees the stream is released on every exit
            # path: EOF, an exception from ``read``, or GeneratorExit when
            # the consumer stops iterating early.
            try:
                while True:
                    d = stream.read(self.chunk)
                    if not d:
                        # Empty read (b"" or "") signals EOF for this stream.
                        break
                    yield (furl, d)
            finally:
                stream.close()
|