1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- from torch.utils.data.datapipes._decorator import functional_datapipe
- from torch.utils.data.datapipes.datapipe import MapDataPipe
- from typing import Sized, Tuple, TypeVar
# Public API of this module: the map-style DataPipes that combine several inputs.
__all__ = [
    "ConcaterMapDataPipe",
    "ZipperMapDataPipe",
]

# Covariant type variable for the element type read from the source DataPipes.
T_co = TypeVar("T_co", covariant=True)
@functional_datapipe('concat')
class ConcaterMapDataPipe(MapDataPipe):
    r"""
    Concatenate multiple Map DataPipes (functional name: ``concat``).

    The index space of the result is the concatenation of the index spaces of
    the source DataPipes: each source's indices are shifted by the cumulative
    length of all sources before it. For example, if there are 2 source
    DataPipes both with length 5, index 0 to 4 of the resulting
    `ConcaterMapDataPipe` would refer to elements of the first DataPipe, and
    5 to 9 would refer to elements of the second DataPipe. Negative indices
    count back from the end of the whole concatenation, as for Python sequences.

    Args:
        datapipes: Map DataPipes being concatenated

    Raises:
        ValueError: if no DataPipe is given.
        TypeError: if any input is not a ``MapDataPipe`` or is not ``Sized``.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp1 = SequenceWrapper(range(3))
        >>> dp2 = SequenceWrapper(range(3))
        >>> concat_dp = dp1.concat(dp2)
        >>> list(concat_dp)
        [0, 1, 2, 0, 1, 2]
    """

    # Variable-length tuple of the source DataPipes, in concatenation order.
    datapipes: Tuple[MapDataPipe, ...]

    def __init__(self, *datapipes: MapDataPipe):
        if not datapipes:
            raise ValueError("Expected at least one DataPipe, but got nothing")
        if not all(isinstance(dp, MapDataPipe) for dp in datapipes):
            raise TypeError("Expected all inputs to be `MapDataPipe`")
        # Lengths are needed to compute offsets, so every input must be Sized.
        if not all(isinstance(dp, Sized) for dp in datapipes):
            raise TypeError("Expected all inputs to be `Sized`")
        self.datapipes = datapipes

    def __getitem__(self, index) -> T_co:
        """Return the element at position ``index`` of the concatenation.

        Raises:
            IndexError: if ``index`` is not in ``[-len(self), len(self))``.
                Raising ``IndexError`` past the end also lets ``list(self)``
                terminate through the sequence iteration protocol.
        """
        requested = index
        if index < 0:
            # Resolve negative indices against the *total* length. Forwarding
            # them unchanged would make ``-1`` return the last element of the
            # first DataPipe instead of the last element of the concatenation.
            index += len(self)
            if index < 0:
                raise IndexError(f"Index {requested} is out of range.")
        offset = 0
        for dp in self.datapipes:
            dp_len = len(dp)
            if index - offset < dp_len:
                return dp[index - offset]
            offset += dp_len
        raise IndexError(f"Index {requested} is out of range.")

    def __len__(self) -> int:
        """Return the total number of elements across all source DataPipes."""
        return sum(len(dp) for dp in self.datapipes)
@functional_datapipe('zip')
class ZipperMapDataPipe(MapDataPipe[Tuple[T_co, ...]]):
    r"""
    Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).

    This DataPipe is out of bound as soon as the shortest input DataPipe is
    exhausted; its length is the minimum length of the inputs. A negative
    integer index counts back from that length, so every input is read at the
    same position. Any other index is forwarded unchanged to each input.

    Args:
        *datapipes: Map DataPipes being aggregated

    Raises:
        ValueError: if no DataPipe is given.
        TypeError: if any input is not a ``MapDataPipe`` or is not ``Sized``.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp1 = SequenceWrapper(range(3))
        >>> dp2 = SequenceWrapper(range(10, 13))
        >>> zip_dp = dp1.zip(dp2)
        >>> list(zip_dp)
        [(0, 10), (1, 11), (2, 12)]
    """

    # Variable-length tuple of the input DataPipes, in output-tuple order.
    datapipes: Tuple[MapDataPipe[T_co], ...]

    def __init__(self, *datapipes: MapDataPipe[T_co]) -> None:
        if not datapipes:
            raise ValueError("Expected at least one DataPipe, but got nothing")
        if not all(isinstance(dp, MapDataPipe) for dp in datapipes):
            raise TypeError("Expected all inputs to be `MapDataPipe`")
        # ``__len__`` takes the minimum length, so every input must be Sized.
        if not all(isinstance(dp, Sized) for dp in datapipes):
            raise TypeError("Expected all inputs to be `Sized`")
        self.datapipes = datapipes

    def __getitem__(self, index) -> Tuple[T_co, ...]:
        """Return a tuple holding ``dp[index]`` for each input DataPipe.

        Raises:
            IndexError: if a negative integer ``index`` is below ``-len(self)``,
                or if any input raises ``IndexError`` for ``index`` (chained
                as the cause).
        """
        # Only plain integers are resolved: non-integer keys (e.g. for
        # key-based MapDataPipes) must keep being forwarded untouched.
        if isinstance(index, int) and index < 0:
            # Resolve against the zipped length (the shortest input). Forwarding
            # a negative index unchanged would read different positions from
            # inputs of different lengths.
            resolved = index + len(self)
            if resolved < 0:
                raise IndexError(f"Index {index} is out of range.")
            index = resolved
        res = []
        for dp in self.datapipes:
            try:
                res.append(dp[index])
            except IndexError as e:
                raise IndexError(f"Index {index} is out of range for one of the input MapDataPipes {dp}.") from e
        return tuple(res)

    def __len__(self) -> int:
        """Return the length of the shortest input DataPipe."""
        return min(len(dp) for dp in self.datapipes)
|