import io
import pickle
import warnings
from collections.abc import Collection
from typing import Dict, List, Optional, Set, Tuple, Type, Union

from torch.utils.data import IterDataPipe, MapDataPipe
from torch.utils.data._utils.serialization import DILL_AVAILABLE

__all__ = ["traverse", "traverse_dps"]

DataPipe = Union[IterDataPipe, MapDataPipe]
DataPipeGraph = Dict[int, Tuple[DataPipe, "DataPipeGraph"]]  # type: ignore[misc]
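# For example (illustrative names), the graph for a two-node chain
# ``source_dp -> mapped_dp`` has the form
# ``{id(mapped_dp): (mapped_dp, {id(source_dp): (source_dp, {})})}``;
# see ``traverse_dps`` below.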


def _stub_unpickler():
    # Stands in for captured DataPipes in the pickle stream so that they are
    # recorded but never actually serialized (see ``reduce_hook`` below)
    return "STUB"


# TODO(VitalyFedyunin): Make sure it works without dill module installed
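# Pickles ``scan_obj`` with custom hooks installed so that every DataPipe
# reachable from its state is captured and stubbed out instead of serialized;
# the captured DataPipes are the direct predecessors in the graph.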
def _list_connected_datapipes(scan_obj: DataPipe, only_datapipe: bool, cache: Set[int]) -> List[DataPipe]:
    f = io.BytesIO()
    p = pickle.Pickler(f)  # Not going to work for lambdas, but dill infinite loops on typing and can't be used as is
    if DILL_AVAILABLE:
        from dill import Pickler as dill_Pickler
        d = dill_Pickler(f)
    else:
        d = None

    captured_connections = []

    def getstate_hook(ori_state):
        # When ``only_datapipe`` is set, restrict the pickled state to the
        # attributes that may reference other DataPipes: DataPipes themselves
        # and Python collections that may contain them
        state = None
        if isinstance(ori_state, dict):
            state = {}  # type: ignore[assignment]
            for k, v in ori_state.items():
                if isinstance(v, (IterDataPipe, MapDataPipe, Collection)):
                    state[k] = v  # type: ignore[attr-defined]
        elif isinstance(ori_state, (tuple, list)):
            state = []  # type: ignore[assignment]
            for v in ori_state:
                if isinstance(v, (IterDataPipe, MapDataPipe, Collection)):
                    state.append(v)  # type: ignore[attr-defined]
        elif isinstance(ori_state, (IterDataPipe, MapDataPipe, Collection)):
            state = ori_state  # type: ignore[assignment]
        return state

    def reduce_hook(obj):
        if obj == scan_obj or id(obj) in cache:
            # Fall back to default pickling for the object being scanned and
            # for DataPipes that have already been captured
            raise NotImplementedError
        else:
            captured_connections.append(obj)
            # Adding the id removes duplicate DataPipes serialized at the same level
            cache.add(id(obj))
            return _stub_unpickler, ()

    datapipe_classes: Tuple[Type[DataPipe]] = (IterDataPipe, MapDataPipe)  # type: ignore[assignment]

    try:
        for cls in datapipe_classes:
            cls.set_reduce_ex_hook(reduce_hook)
            if only_datapipe:
                cls.set_getstate_hook(getstate_hook)
        try:
            p.dump(scan_obj)
        except (pickle.PickleError, AttributeError, TypeError):
            if DILL_AVAILABLE:
                d.dump(scan_obj)
            else:
                raise
    finally:
        for cls in datapipe_classes:
            cls.set_reduce_ex_hook(None)
            if only_datapipe:
                cls.set_getstate_hook(None)
        if DILL_AVAILABLE:
            from dill import extend as dill_extend
            dill_extend(False)  # Undo change to dispatch table
    return captured_connections


def traverse_dps(datapipe: DataPipe) -> DataPipeGraph:
    r"""
    Traverse the DataPipes and their attributes to extract the DataPipe graph.

    This only looks into attributes of each DataPipe that are either a
    DataPipe or a Python collection object such as ``list``, ``tuple``,
    ``set`` and ``dict``.

    Args:
        datapipe: the end DataPipe of the graph

    Returns:
        A graph represented as a nested dictionary, where keys are ids of DataPipe instances
        and values are tuples of the DataPipe instance and its sub-graph
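
    Example:
        A minimal sketch; it assumes ``IterableWrapper`` from
        ``torch.utils.data.datapipes.iter`` and its functional ``map`` form
        are available:

        >>> from torch.utils.data.datapipes.iter import IterableWrapper
        >>> source_dp = IterableWrapper(range(10))
        >>> mapped_dp = source_dp.map(lambda x: x + 1)
        >>> graph = traverse_dps(mapped_dp)
        >>> # graph == {id(mapped_dp): (mapped_dp, {id(source_dp): (source_dp, {})})}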
- """
- cache: Set[int] = set()
- return _traverse_helper(datapipe, only_datapipe=True, cache=cache)


def traverse(datapipe: DataPipe, only_datapipe: Optional[bool] = None) -> DataPipeGraph:
    r"""
    [Deprecated] Traverse the DataPipes and their attributes to extract the DataPipe graph.

    When ``only_datapipe`` is specified as ``True``, it only looks into attributes
    of each DataPipe that are either a DataPipe or a Python collection object such as
    ``list``, ``tuple``, ``set`` and ``dict``.

    Note:
        This function is deprecated. Please use ``traverse_dps`` instead.

    Args:
        datapipe: the end DataPipe of the graph
        only_datapipe: If ``False`` (default), all attributes of each DataPipe are traversed.
          This argument is deprecated and will be removed after the next release.

    Returns:
        A graph represented as a nested dictionary, where keys are ids of DataPipe instances
        and values are tuples of the DataPipe instance and its sub-graph
    """
    msg = "`traverse` function is deprecated and will be removed after 1.13. " \
          "Please use `traverse_dps` instead."
    if not only_datapipe:
        msg += " And, the behavior will be changed to the equivalent of `only_datapipe=True`."
    warnings.warn(msg, FutureWarning)
    if only_datapipe is None:
        only_datapipe = False
    cache: Set[int] = set()
    return _traverse_helper(datapipe, only_datapipe, cache)


# Add cache here to prevent infinite recursion on DataPipe
def _traverse_helper(datapipe: DataPipe, only_datapipe: bool, cache: Set[int]) -> DataPipeGraph:
    if not isinstance(datapipe, (IterDataPipe, MapDataPipe)):
        raise RuntimeError("Expected `IterDataPipe` or `MapDataPipe`, but {} was found".format(type(datapipe)))

    dp_id = id(datapipe)
    if dp_id in cache:
        return {}
    cache.add(dp_id)
    # Using cache.copy() here prevents the same DataPipe from polluting the cache on different paths
    items = _list_connected_datapipes(datapipe, only_datapipe, cache.copy())
    d: DataPipeGraph = {dp_id: (datapipe, {})}
    for item in items:
        # Using cache.copy() here prevents recursion on a single path rather than on the global graph;
        # a single DataPipe can be present multiple times in different paths of the graph
        d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
    return d