fileopener.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. from io import IOBase
  2. from typing import Iterable, Tuple, Optional
  3. from torch.utils.data.datapipes._decorator import functional_datapipe
  4. from torch.utils.data.datapipes.datapipe import IterDataPipe
  5. from torch.utils.data.datapipes.utils.common import get_file_binaries_from_pathnames
  6. __all__ = [
  7. "FileOpenerIterDataPipe",
  8. ]
  9. @functional_datapipe("open_files")
  10. class FileOpenerIterDataPipe(IterDataPipe[Tuple[str, IOBase]]):
  11. r"""
  12. Given pathnames, opens files and yield pathname and file stream
  13. in a tuple (functional name: ``open_files``).
  14. Args:
  15. datapipe: Iterable datapipe that provides pathnames
  16. mode: An optional string that specifies the mode in which
  17. the file is opened by ``open()``. It defaults to ``r``, other options are
  18. ``b`` for reading in binary mode and ``t`` for text mode.
  19. encoding: An optional string that specifies the encoding of the
  20. underlying file. It defaults to ``None`` to match the default encoding of ``open``.
  21. length: Nominal length of the datapipe
  22. Note:
  23. The opened file handles will be closed by Python's GC periodically. Users can choose
  24. to close them explicitly.
  25. Example:
  26. >>> # xdoctest: +SKIP
  27. >>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader
  28. >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt'))
  29. >>> dp = FileOpener(dp)
  30. >>> dp = StreamReader(dp)
  31. >>> list(dp)
  32. [('./abc.txt', 'abc')]
  33. """
  34. def __init__(
  35. self,
  36. datapipe: Iterable[str],
  37. mode: str = 'r',
  38. encoding: Optional[str] = None,
  39. length: int = -1):
  40. super().__init__()
  41. self.datapipe: Iterable = datapipe
  42. self.mode: str = mode
  43. self.encoding: Optional[str] = encoding
  44. if self.mode not in ('b', 't', 'rb', 'rt', 'r'):
  45. raise ValueError("Invalid mode {}".format(mode))
  46. # TODO: enforce typing for each instance based on mode, otherwise
  47. # `argument_validation` with this DataPipe may be potentially broken
  48. if 'b' in mode and encoding is not None:
  49. raise ValueError("binary mode doesn't take an encoding argument")
  50. self.length: int = length
  51. # Remove annotation due to 'IOBase' is a general type and true type
  52. # is determined at runtime based on mode. Some `DataPipe` requiring
  53. # a subtype would cause mypy error.
  54. def __iter__(self):
  55. yield from get_file_binaries_from_pathnames(self.datapipe, self.mode, self.encoding)
  56. def __len__(self):
  57. if self.length == -1:
  58. raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
  59. return self.length