123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
- from io import IOBase
- from typing import Iterable, Tuple, Optional
- from torch.utils.data.datapipes._decorator import functional_datapipe
- from torch.utils.data.datapipes.datapipe import IterDataPipe
- from torch.utils.data.datapipes.utils.common import get_file_binaries_from_pathnames
# Public API of this module: only the datapipe class is exported by a star-import.
__all__ = ["FileOpenerIterDataPipe"]
@functional_datapipe("open_files")
class FileOpenerIterDataPipe(IterDataPipe[Tuple[str, IOBase]]):
    r"""
    Open every incoming pathname and yield a ``(pathname, file stream)`` tuple
    for each one (functional name: ``open_files``).

    Args:
        datapipe: Iterable datapipe that supplies the pathnames to open
        mode: Optional string giving the mode in which each file is opened by
            ``open()``. The default is ``r``; ``b`` selects reading in binary mode
            and ``t`` selects text mode. Accepted values are ``b``, ``t``, ``rb``,
            ``rt`` and ``r``.
        encoding: Optional string naming the encoding of the underlying file.
            The default ``None`` matches the default encoding of ``open``. It
            must stay ``None`` whenever ``mode`` contains ``b``.
        length: Nominal length of the datapipe; ``-1`` (the default) means that
            no length is available and ``len()`` raises ``TypeError``

    Raises:
        ValueError: If ``mode`` is not one of the accepted values, or if an
            ``encoding`` is supplied together with a binary mode.

    Note:
        The opened file handles will be closed by Python's GC periodically. Users
        can choose to close them explicitly.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader
        >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt'))
        >>> dp = FileOpener(dp)
        >>> dp = StreamReader(dp)
        >>> list(dp)
        [('./abc.txt', 'abc')]
    """

    def __init__(
        self,
        datapipe: Iterable[str],
        mode: str = 'r',
        encoding: Optional[str] = None,
        length: int = -1,
    ):
        super().__init__()
        self.datapipe: Iterable = datapipe
        self.mode: str = mode
        self.encoding: Optional[str] = encoding

        # Only read-style modes are accepted.
        mode_is_supported = self.mode in ('b', 't', 'rb', 'rt', 'r')
        if not mode_is_supported:
            raise ValueError("Invalid mode {}".format(mode))

        # TODO: enforce typing for each instance based on mode, otherwise
        #       `argument_validation` with this DataPipe may be potentially broken

        # An encoding only makes sense for text streams, so reject it for binary modes.
        binary_requested = 'b' in mode
        if binary_requested and encoding is not None:
            raise ValueError("binary mode doesn't take an encoding argument")

        self.length: int = length

    # The return annotation is intentionally omitted: 'IOBase' is only a general
    # type and the concrete stream type is decided at runtime from ``mode``.
    # Annotating it would make mypy complain for any `DataPipe` that requires
    # a more specific subtype.
    def __iter__(self):
        streams = get_file_binaries_from_pathnames(self.datapipe, self.mode, self.encoding)
        yield from streams

    def __len__(self):
        # -1 is the sentinel for "no nominal length was provided".
        if self.length != -1:
            return self.length
        raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
|