# graph_settings.py (4.8 KB)
  1. import inspect
  2. import warnings
  3. from typing import Any, List, Optional, Set
  4. import torch
  5. from torch.utils.data.datapipes.iter.sharding import (
  6. _ShardingIterDataPipe,
  7. SHARDING_PRIORITIES,
  8. )
  9. from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps
# Names exported by this module when a caller uses ``from <module> import *``.
__all__ = [
    "apply_random_seed",
    "apply_sharding",
    "apply_shuffle_seed",
    "apply_shuffle_settings",
    "get_all_graph_pipes",
]
  17. def get_all_graph_pipes(graph: DataPipeGraph) -> List[DataPipe]:
  18. return _get_all_graph_pipes_helper(graph, set())
  19. def _get_all_graph_pipes_helper(graph: DataPipeGraph, id_cache: Set[int]) -> List[DataPipe]:
  20. results: List[DataPipe] = []
  21. for dp_id, (datapipe, sub_graph) in graph.items():
  22. if dp_id in id_cache:
  23. continue
  24. id_cache.add(dp_id)
  25. results.append(datapipe)
  26. results.extend(_get_all_graph_pipes_helper(sub_graph, id_cache))
  27. return results
  28. def apply_sharding(datapipe: DataPipe,
  29. num_of_instances: int,
  30. instance_id: int,
  31. sharding_group=SHARDING_PRIORITIES.DEFAULT) -> DataPipe:
  32. r"""
  33. Apply dynamic sharding over the ``sharding_filter`` DataPipe that has a method ``apply_sharding``.
  34. RuntimeError will be raised when multiple ``sharding_filter`` are presented in the same branch.
  35. """
  36. graph = traverse_dps(datapipe)
  37. def _helper(graph, prev_applied=None):
  38. for _, (dp, sub_graph) in graph.items():
  39. applied = None
  40. if isinstance(dp, _ShardingIterDataPipe):
  41. if prev_applied is not None:
  42. raise RuntimeError("Sharding twice on a single pipeline is likely unintended and will cause data loss. "
  43. f"Sharding already applied to {prev_applied} while trying to apply to {dp}")
  44. dp.apply_sharding(num_of_instances, instance_id, sharding_group=sharding_group)
  45. applied = dp
  46. if applied is None:
  47. applied = prev_applied
  48. _helper(sub_graph, applied)
  49. _helper(graph)
  50. return datapipe
  51. def _is_shuffle_datapipe(datapipe: DataPipe) -> bool:
  52. if not hasattr(datapipe, "set_shuffle") or not hasattr(datapipe, "set_seed"):
  53. return False
  54. if not inspect.ismethod(datapipe.set_shuffle) or not inspect.ismethod(datapipe.set_seed):
  55. return False
  56. return True
  57. def apply_shuffle_settings(datapipe: DataPipe, shuffle: Optional[bool] = None) -> DataPipe:
  58. r"""
  59. Traverse the graph of ``DataPipes`` to find and set shuffle attribute
  60. to each `DataPipe` that has APIs of ``set_shuffle`` and ``set_seed``.
  61. Args:
  62. datapipe: DataPipe that needs to set shuffle attribute
  63. shuffle: Shuffle option (default: ``None`` and no-op to the graph)
  64. """
  65. if shuffle is None:
  66. return datapipe
  67. graph = traverse_dps(datapipe)
  68. all_pipes = get_all_graph_pipes(graph)
  69. shufflers = [pipe for pipe in all_pipes if _is_shuffle_datapipe(pipe)]
  70. if not shufflers and shuffle:
  71. warnings.warn(
  72. "`shuffle=True` was set, but the datapipe does not contain a `Shuffler`. Adding one at the end. "
  73. "Be aware that the default buffer size might not be sufficient for your task."
  74. )
  75. datapipe = datapipe.shuffle()
  76. shufflers = [datapipe, ] # type: ignore[list-item]
  77. for shuffler in shufflers:
  78. shuffler.set_shuffle(shuffle)
  79. return datapipe
  80. def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe:
  81. warnings.warn(
  82. "`apply_shuffle_seed` is deprecated since 1.12 and will be removed in the future releases."
  83. "\nPlease use `apply_random_seed` instead."
  84. )
  85. return apply_random_seed(datapipe, rng)
  86. def _is_random_datapipe(datapipe: DataPipe) -> bool:
  87. if hasattr(datapipe, "set_seed") and inspect.ismethod(datapipe.set_seed):
  88. return True
  89. return False
  90. def apply_random_seed(datapipe: DataPipe, rng: torch.Generator) -> DataPipe:
  91. r"""
  92. Traverse the graph of ``DataPipes`` to find random ``DataPipe`` with an API of
  93. ``set_seed`` then set the random seed based on the provided RNG.
  94. Args:
  95. datapipe: DataPipe that needs to set randomness
  96. rng: Random number generator to generate random seeds
  97. """
  98. graph = traverse_dps(datapipe)
  99. all_pipes = get_all_graph_pipes(graph)
  100. # Using a set to track id of DataPipe to prevent setting randomness per DataPipe more than once.
  101. # And, `id` is used in case of unhashable DataPipe
  102. cache = set()
  103. random_datapipes = []
  104. for pipe in all_pipes:
  105. if id(pipe) in cache:
  106. continue
  107. if _is_random_datapipe(pipe):
  108. random_datapipes.append(pipe)
  109. cache.add(id(pipe))
  110. for pipe in random_datapipes:
  111. random_seed = int(torch.empty((), dtype=torch.int64).random_(generator=rng).item())
  112. pipe.set_seed(random_seed)
  113. return datapipe