callable.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
  2. from typing import Callable, TypeVar
  3. from torch.utils.data.datapipes._decorator import functional_datapipe
  4. from torch.utils.data.datapipes.datapipe import MapDataPipe
  5. __all__ = ["MapperMapDataPipe", "default_fn"]
  6. T_co = TypeVar('T_co', covariant=True)
  7. # Default function to return each item directly
  8. # In order to keep datapipe picklable, eliminates the usage
  9. # of python lambda function
  10. def default_fn(data):
  11. return data
  12. @functional_datapipe('map')
  13. class MapperMapDataPipe(MapDataPipe[T_co]):
  14. r"""
  15. Apply the input function over each item from the source DataPipe (functional name: ``map``).
  16. The function can be any regular Python function or partial object. Lambda
  17. function is not recommended as it is not supported by pickle.
  18. Args:
  19. datapipe: Source MapDataPipe
  20. fn: Function being applied to each item
  21. Example:
  22. >>> # xdoctest: +SKIP
  23. >>> from torchdata.datapipes.map import SequenceWrapper, Mapper
  24. >>> def add_one(x):
  25. ... return x + 1
  26. >>> dp = SequenceWrapper(range(10))
  27. >>> map_dp_1 = dp.map(add_one)
  28. >>> list(map_dp_1)
  29. [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  30. >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
  31. >>> list(map_dp_2)
  32. [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  33. """
  34. datapipe: MapDataPipe
  35. fn: Callable
  36. def __init__(
  37. self,
  38. datapipe: MapDataPipe,
  39. fn: Callable = default_fn,
  40. ) -> None:
  41. super().__init__()
  42. self.datapipe = datapipe
  43. _check_unpickable_fn(fn)
  44. self.fn = fn # type: ignore[assignment]
  45. def __len__(self) -> int:
  46. return len(self.datapipe)
  47. def __getitem__(self, index) -> T_co:
  48. return self.fn(self.datapipe[index])