123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- import inspect
- import warnings
- from typing import Any, List, Optional, Set
- import torch
- from torch.utils.data.datapipes.iter.sharding import (
- _ShardingIterDataPipe,
- SHARDING_PRIORITIES,
- )
- from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps
# Names exported by a star-import of this module. The underscore-prefixed
# helpers defined in this file are intentionally left out of the list.
__all__ = [
    "apply_random_seed",
    "apply_sharding",
    "apply_shuffle_seed",
    "apply_shuffle_settings",
    "get_all_graph_pipes",
]
def get_all_graph_pipes(graph: DataPipeGraph) -> List[DataPipe]:
    r"""
    Flatten a ``DataPipe`` graph into a list of its ``DataPipes``.

    The graph is a nested mapping of ``key -> (datapipe, sub_graph)``. It is
    walked depth-first in pre-order, and every key contributes its DataPipe
    at most once, even when the same key is reachable through several branches.

    Args:
        graph: Nested mapping to flatten

    Returns:
        The DataPipes in the order they are first encountered.
    """
    visited: Set[int] = set()
    ordered: List[DataPipe] = []
    # Each stack entry is an iterator over one nesting level, so a level
    # resumes exactly where it left off once its children are exhausted.
    pending = [iter(graph.items())]
    while pending:
        entry = next(pending[-1], None)
        if entry is None:
            pending.pop()
            continue
        key, (pipe, children) = entry
        if key in visited:
            continue
        visited.add(key)
        ordered.append(pipe)
        pending.append(iter(children.items()))
    return ordered
def _get_all_graph_pipes_helper(graph: DataPipeGraph, id_cache: Set[int]) -> List[DataPipe]:
    r"""
    Recursively collect the ``DataPipes`` of ``graph`` in depth-first pre-order.

    Args:
        graph: Nested mapping of ``key -> (datapipe, sub_graph)``
        id_cache: Keys that were already collected; updated in place so the
            same set is shared by every level of the recursion

    Returns:
        The DataPipes whose keys were not yet present in ``id_cache``.
    """
    collected: List[DataPipe] = []
    for key, (pipe, children) in graph.items():
        if key not in id_cache:
            # Record the key before descending so that a repeated occurrence
            # deeper in the graph is skipped.
            id_cache.add(key)
            collected.append(pipe)
            collected += _get_all_graph_pipes_helper(children, id_cache)
    return collected
def apply_sharding(datapipe: DataPipe,
                   num_of_instances: int,
                   instance_id: int,
                   sharding_group=SHARDING_PRIORITIES.DEFAULT) -> DataPipe:
    r"""
    Apply dynamic sharding to every ``_ShardingIterDataPipe`` found in the graph of ``datapipe``.

    The graph returned by ``traverse_dps`` is walked branch by branch, and
    ``apply_sharding`` is called on each sharding DataPipe that is found.

    Args:
        datapipe: DataPipe whose graph is searched for sharding DataPipes
        num_of_instances: Forwarded to each sharding DataPipe's ``apply_sharding``
        instance_id: Forwarded to each sharding DataPipe's ``apply_sharding``
        sharding_group: Forwarded as the ``sharding_group`` keyword argument

    Returns:
        The same ``datapipe`` object that was passed in.

    Raises:
        RuntimeError: When a sharding DataPipe is found on a branch on which
            another sharding DataPipe was already configured.
    """
    def _shard_branch(nodes, already_sharded=None):
        for pipe, children in nodes.values():
            if not isinstance(pipe, _ShardingIterDataPipe):
                # Nothing to configure here; keep propagating whatever was
                # already sharded earlier on this branch.
                _shard_branch(children, already_sharded)
                continue
            if already_sharded is not None:
                raise RuntimeError("Sharding twice on a single pipeline is likely unintended and will cause data loss. "
                                   f"Sharding already applied to {already_sharded} while trying to apply to {pipe}")
            pipe.apply_sharding(num_of_instances, instance_id, sharding_group=sharding_group)
            _shard_branch(children, pipe)

    _shard_branch(traverse_dps(datapipe))
    return datapipe
def _is_shuffle_datapipe(datapipe: DataPipe) -> bool:
    r"""
    Tell whether ``datapipe`` exposes both ``set_shuffle`` and ``set_seed`` as bound methods.

    Args:
        datapipe: Object to inspect

    Returns:
        ``True`` only when both attributes exist and both are bound methods.
    """
    required_apis = ("set_shuffle", "set_seed")
    # First make sure both attributes exist, then check what kind of object
    # each one is; plain functions and other attributes are rejected.
    for api_name in required_apis:
        if not hasattr(datapipe, api_name):
            return False
    for api_name in required_apis:
        if not inspect.ismethod(getattr(datapipe, api_name)):
            return False
    return True
def apply_shuffle_settings(datapipe: DataPipe, shuffle: Optional[bool] = None) -> DataPipe:
    r"""
    Set the shuffle option on every ``DataPipe`` in the graph that offers both ``set_shuffle`` and ``set_seed``.

    When ``shuffle`` is truthy and no such DataPipe is found, a warning is
    emitted, ``datapipe.shuffle()`` is appended to the pipeline, and the
    resulting DataPipe is returned instead of the original one.

    Args:
        datapipe: DataPipe that needs to set shuffle attribute
        shuffle: Shuffle option (default: ``None`` and no-op to the graph)

    Returns:
        ``datapipe`` itself, or the result of ``datapipe.shuffle()`` when one had to be added.
    """
    if shuffle is None:
        return datapipe

    pipes_in_graph = get_all_graph_pipes(traverse_dps(datapipe))
    targets = list(filter(_is_shuffle_datapipe, pipes_in_graph))

    if not targets and shuffle:
        warnings.warn(
            "`shuffle=True` was set, but the datapipe does not contain a `Shuffler`. Adding one at the end. "
            "Be aware that the default buffer size might not be sufficient for your task."
        )
        datapipe = datapipe.shuffle()
        targets = [datapipe]  # type: ignore[list-item]

    for target in targets:
        target.set_shuffle(shuffle)

    return datapipe
def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe:
    r"""
    Deprecated entry point that forwards to ``apply_random_seed``.

    A ``FutureWarning`` is emitted on every call. It is attributed to the
    caller's line (``stacklevel=2``) so that users can locate the call site
    that still relies on the deprecated name.

    Args:
        datapipe: DataPipe forwarded to ``apply_random_seed``
        rng: Random number generator forwarded to ``apply_random_seed``

    Returns:
        Whatever ``apply_random_seed(datapipe, rng)`` returns.
    """
    warnings.warn(
        "`apply_shuffle_seed` is deprecated since 1.12 and will be removed in the future releases."
        "\nPlease use `apply_random_seed` instead.",
        FutureWarning,
        stacklevel=2,
    )
    return apply_random_seed(datapipe, rng)
def _is_random_datapipe(datapipe: DataPipe) -> bool:
    r"""
    Tell whether ``datapipe`` exposes ``set_seed`` as a bound method.

    Args:
        datapipe: Object to inspect

    Returns:
        ``True`` when the attribute exists and is a bound method, ``False`` otherwise.
    """
    # ``hasattr`` short-circuits, so the attribute is only inspected when present.
    return hasattr(datapipe, "set_seed") and inspect.ismethod(datapipe.set_seed)
def apply_random_seed(datapipe: DataPipe, rng: torch.Generator) -> DataPipe:
    r"""
    Seed every random ``DataPipe`` found in the graph of ``datapipe``.

    A DataPipe counts as random when it has a bound ``set_seed`` method. Each
    one receives its own seed, drawn from ``rng`` in the order the DataPipes
    are discovered.

    Args:
        datapipe: DataPipe that needs to set randomness
        rng: Random number generator to generate random seeds

    Returns:
        The same ``datapipe`` object that was passed in.
    """
    # DataPipes are tracked by ``id`` rather than by value so that unhashable
    # DataPipes are supported and none is seeded more than once.
    seeded_ids = set()
    seedable = []
    for candidate in get_all_graph_pipes(traverse_dps(datapipe)):
        candidate_id = id(candidate)
        if candidate_id not in seeded_ids and _is_random_datapipe(candidate):
            seedable.append(candidate)
            seeded_ids.add(candidate_id)

    for target in seedable:
        # Fill a 0-dim int64 tensor from ``rng`` and hand it over as a plain int.
        drawn = torch.empty((), dtype=torch.int64).random_(generator=rng)
        target.set_seed(int(drawn.item()))

    return datapipe
|