combining.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. from torch.utils.data.datapipes._decorator import functional_datapipe
  2. from torch.utils.data.datapipes.datapipe import MapDataPipe
  3. from typing import Sized, Tuple, TypeVar
  4. __all__ = ["ConcaterMapDataPipe", "ZipperMapDataPipe"]
  5. T_co = TypeVar('T_co', covariant=True)
  6. @functional_datapipe('concat')
  7. class ConcaterMapDataPipe(MapDataPipe):
  8. r"""
  9. Concatenate multiple Map DataPipes (functional name: ``concat``).
  10. The new index of is the cumulative sum of source DataPipes.
  11. For example, if there are 2 source DataPipes both with length 5,
  12. index 0 to 4 of the resulting `ConcatMapDataPipe` would refer to
  13. elements of the first DataPipe, and 5 to 9 would refer to elements
  14. of the second DataPipe.
  15. Args:
  16. datapipes: Map DataPipes being concatenated
  17. Example:
  18. >>> # xdoctest: +SKIP
  19. >>> from torchdata.datapipes.map import SequenceWrapper
  20. >>> dp1 = SequenceWrapper(range(3))
  21. >>> dp2 = SequenceWrapper(range(3))
  22. >>> concat_dp = dp1.concat(dp2)
  23. >>> list(concat_dp)
  24. [0, 1, 2, 0, 1, 2]
  25. """
  26. datapipes: Tuple[MapDataPipe]
  27. def __init__(self, *datapipes: MapDataPipe):
  28. if len(datapipes) == 0:
  29. raise ValueError("Expected at least one DataPipe, but got nothing")
  30. if not all(isinstance(dp, MapDataPipe) for dp in datapipes):
  31. raise TypeError("Expected all inputs to be `MapDataPipe`")
  32. if not all(isinstance(dp, Sized) for dp in datapipes):
  33. raise TypeError("Expected all inputs to be `Sized`")
  34. self.datapipes = datapipes # type: ignore[assignment]
  35. def __getitem__(self, index) -> T_co:
  36. offset = 0
  37. for dp in self.datapipes:
  38. if index - offset < len(dp):
  39. return dp[index - offset]
  40. else:
  41. offset += len(dp)
  42. raise IndexError("Index {} is out of range.".format(index))
  43. def __len__(self) -> int:
  44. return sum(len(dp) for dp in self.datapipes)
  45. @functional_datapipe('zip')
  46. class ZipperMapDataPipe(MapDataPipe[Tuple[T_co, ...]]):
  47. r"""
  48. Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).
  49. This MataPipe is out of bound as soon as the shortest input DataPipe is exhausted.
  50. Args:
  51. *datapipes: Map DataPipes being aggregated
  52. Example:
  53. >>> # xdoctest: +SKIP
  54. >>> from torchdata.datapipes.map import SequenceWrapper
  55. >>> dp1 = SequenceWrapper(range(3))
  56. >>> dp2 = SequenceWrapper(range(10, 13))
  57. >>> zip_dp = dp1.zip(dp2)
  58. >>> list(zip_dp)
  59. [(0, 10), (1, 11), (2, 12)]
  60. """
  61. datapipes: Tuple[MapDataPipe[T_co], ...]
  62. def __init__(self, *datapipes: MapDataPipe[T_co]) -> None:
  63. if len(datapipes) == 0:
  64. raise ValueError("Expected at least one DataPipe, but got nothing")
  65. if not all(isinstance(dp, MapDataPipe) for dp in datapipes):
  66. raise TypeError("Expected all inputs to be `MapDataPipe`")
  67. if not all(isinstance(dp, Sized) for dp in datapipes):
  68. raise TypeError("Expected all inputs to be `Sized`")
  69. self.datapipes = datapipes
  70. def __getitem__(self, index) -> Tuple[T_co, ...]:
  71. res = []
  72. for dp in self.datapipes:
  73. try:
  74. res.append(dp[index])
  75. except IndexError as e:
  76. raise IndexError(f"Index {index} is out of range for one of the input MapDataPipes {dp}.") from e
  77. return tuple(res)
  78. def __len__(self) -> int:
  79. return min(len(dp) for dp in self.datapipes)