# graph.py
  1. import io
  2. import pickle
  3. import warnings
  4. from collections.abc import Collection
  5. from typing import Dict, List, Optional, Set, Tuple, Type, Union
  6. from torch.utils.data import IterDataPipe, MapDataPipe
  7. from torch.utils.data._utils.serialization import DILL_AVAILABLE
# Public API of this module.
__all__ = ["traverse", "traverse_dps"]

# A graph node is either flavor of datapipe.
DataPipe = Union[IterDataPipe, MapDataPipe]

# Graph representation: maps id(datapipe) -> (datapipe instance, sub-graph of
# the datapipes it reads from). The alias is recursive, hence the string form.
DataPipeGraph = Dict[int, Tuple[DataPipe, "DataPipeGraph"]]  # type: ignore[misc]
  11. def _stub_unpickler():
  12. return "STUB"
# TODO(VitalyFedyunin): Make sure it works without dill module installed
def _list_connected_datapipes(scan_obj: DataPipe, only_datapipe: bool, cache: Set[int]) -> List[DataPipe]:
    """Return the datapipes that ``scan_obj`` directly references.

    Works by pickling ``scan_obj`` with temporary class-level hooks installed on
    ``IterDataPipe``/``MapDataPipe``: every nested datapipe the pickler reaches
    is recorded and replaced by a stub, so the dump doubles as a one-level scan.

    Args:
        scan_obj: the datapipe whose direct connections are wanted
        only_datapipe: when True, restrict ``__getstate__`` to attributes that
            are datapipes or collections (which may contain datapipes)
        cache: ids of datapipes already seen at this level; mutated in place
            to dedupe repeats within the same serialization pass

    Returns:
        List of directly connected DataPipe instances.
    """
    f = io.BytesIO()
    p = pickle.Pickler(f)  # Not going to work for lambdas, but dill infinite loops on typing and can't be used as is
    if DILL_AVAILABLE:
        from dill import Pickler as dill_Pickler
        d = dill_Pickler(f)
    else:
        d = None

    # Filled by reduce_hook as the pickler walks the object graph.
    captured_connections = []

    def getstate_hook(ori_state):
        # Filter the instance state down to values that could hold datapipes:
        # datapipes themselves or arbitrary collections. Mirrors the shape of
        # the original state (dict -> dict, tuple/list -> list).
        state = None
        if isinstance(ori_state, dict):
            state = {}  # type: ignore[assignment]
            for k, v in ori_state.items():
                if isinstance(v, (IterDataPipe, MapDataPipe, Collection)):
                    state[k] = v  # type: ignore[attr-defined]
        elif isinstance(ori_state, (tuple, list)):
            state = []  # type: ignore[assignment]
            for v in ori_state:
                if isinstance(v, (IterDataPipe, MapDataPipe, Collection)):
                    state.append(v)  # type: ignore[attr-defined]
        elif isinstance(ori_state, (IterDataPipe, MapDataPipe, Collection)):
            state = ori_state  # type: ignore[assignment]
        return state

    def reduce_hook(obj):
        # Raising NotImplementedError tells the datapipe's __reduce_ex__ to fall
        # back to normal pickling: used for the root object itself and for
        # datapipes already recorded in ``cache``.
        if obj == scan_obj or id(obj) in cache:
            raise NotImplementedError
        else:
            captured_connections.append(obj)
            # Adding id to remove duplicate DataPipe serialized at the same level
            cache.add(id(obj))
            # Serialize as a stub so the pickler does not descend further.
            return _stub_unpickler, ()

    datapipe_classes: Tuple[Type[DataPipe]] = (IterDataPipe, MapDataPipe)  # type: ignore[assignment]

    try:
        # Install hooks on both datapipe base classes for the duration of the dump.
        for cls in datapipe_classes:
            cls.set_reduce_ex_hook(reduce_hook)
            if only_datapipe:
                cls.set_getstate_hook(getstate_hook)
        try:
            p.dump(scan_obj)
        except (pickle.PickleError, AttributeError, TypeError):
            # Plain pickle failed (e.g. lambdas in state); retry with dill if present.
            if DILL_AVAILABLE:
                d.dump(scan_obj)
            else:
                raise
    finally:
        # Always uninstall the hooks, even if both dumps failed.
        for cls in datapipe_classes:
            cls.set_reduce_ex_hook(None)
            if only_datapipe:
                cls.set_getstate_hook(None)
        if DILL_AVAILABLE:
            from dill import extend as dill_extend
            dill_extend(False)  # Undo change to dispatch table
    return captured_connections
  68. def traverse_dps(datapipe: DataPipe) -> DataPipeGraph:
  69. r"""
  70. Traverse the DataPipes and their attributes to extract the DataPipe graph.
  71. This only looks into the attribute from each DataPipe that is either a
  72. DataPipe and a Python collection object such as ``list``, ``tuple``,
  73. ``set`` and ``dict``.
  74. Args:
  75. datapipe: the end DataPipe of the graph
  76. Returns:
  77. A graph represented as a nested dictionary, where keys are ids of DataPipe instances
  78. and values are tuples of DataPipe instance and the sub-graph
  79. """
  80. cache: Set[int] = set()
  81. return _traverse_helper(datapipe, only_datapipe=True, cache=cache)
  82. def traverse(datapipe: DataPipe, only_datapipe: Optional[bool] = None) -> DataPipeGraph:
  83. r"""
  84. [Deprecated] Traverse the DataPipes and their attributes to extract the DataPipe graph. When
  85. ``only_dataPipe`` is specified as ``True``, it would only look into the attribute
  86. from each DataPipe that is either a DataPipe and a Python collection object such as
  87. ``list``, ``tuple``, ``set`` and ``dict``.
  88. Note:
  89. This function is deprecated. Please use `traverse_dps` instead.
  90. Args:
  91. datapipe: the end DataPipe of the graph
  92. only_datapipe: If ``False`` (default), all attributes of each DataPipe are traversed.
  93. This argument is deprecating and will be removed after the next release.
  94. Returns:
  95. A graph represented as a nested dictionary, where keys are ids of DataPipe instances
  96. and values are tuples of DataPipe instance and the sub-graph
  97. """
  98. msg = "`traverse` function and will be removed after 1.13. " \
  99. "Please use `traverse_dps` instead."
  100. if not only_datapipe:
  101. msg += " And, the behavior will be changed to the equivalent of `only_datapipe=True`."
  102. warnings.warn(msg, FutureWarning)
  103. if only_datapipe is None:
  104. only_datapipe = False
  105. cache: Set[int] = set()
  106. return _traverse_helper(datapipe, only_datapipe, cache)
  107. # Add cache here to prevent infinite recursion on DataPipe
  108. def _traverse_helper(datapipe: DataPipe, only_datapipe: bool, cache: Set[int]) -> DataPipeGraph:
  109. if not isinstance(datapipe, (IterDataPipe, MapDataPipe)):
  110. raise RuntimeError("Expected `IterDataPipe` or `MapDataPipe`, but {} is found".format(type(datapipe)))
  111. dp_id = id(datapipe)
  112. if dp_id in cache:
  113. return {}
  114. cache.add(dp_id)
  115. # Using cache.copy() here is to prevent the same DataPipe pollutes the cache on different paths
  116. items = _list_connected_datapipes(datapipe, only_datapipe, cache.copy())
  117. d: DataPipeGraph = {dp_id: (datapipe, {})}
  118. for item in items:
  119. # Using cache.copy() here is to prevent recursion on a single path rather than global graph
  120. # Single DataPipe can present multiple times in different paths in graph
  121. d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
  122. return d