# _traverse.py
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. import torch
  3. from typing import (
  4. Callable,
  5. Collection,
  6. List,
  7. Mapping,
  8. MutableMapping,
  9. Optional,
  10. Tuple,
  11. TypeVar,
  12. Union,
  13. cast,
  14. )
  15. from torch.distributed.checkpoint.metadata import (
  16. STATE_DICT_TYPE,
  17. )
  18. from torch.distributed._shard.sharded_tensor.api import ShardedTensor
  19. from torch.distributed._tensor import DTensor
  20. PATH_ITEM = Union[str, int]
  21. OBJ_PATH = Tuple[PATH_ITEM, ...]
  22. T = TypeVar("T")
  23. STATE_DICT_ITEM = object
  24. CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM]
  25. __all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"]
  26. def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool:
  27. return isinstance(value, torch.Tensor)
  28. # TODO: update docstring for traverse.py
  29. def traverse_state_dict(
  30. state_dict: STATE_DICT_TYPE,
  31. visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
  32. keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
  33. ) -> None:
  34. """
  35. Invoke ``visitor`` for each value recursively in ``state_dict``.
  36. Traversal is short-circuited when if finds a collection for which ``keep_visiting_tensors`` evaluates
  37. to false for all elements.
  38. By default, all collections with at least one ``torch.Tensor`` element are traversed.
  39. Visitor takes a path argument that is a tuple of the keys used to reach it.
  40. """
  41. # a value is terminal if it has no other containers values inside it
  42. def _is_terminal(value: STATE_DICT_ITEM) -> bool:
  43. values: Collection[STATE_DICT_ITEM]
  44. if isinstance(value, Mapping):
  45. values = value.values()
  46. elif isinstance(value, list):
  47. values = value
  48. else:
  49. return True
  50. for entry in values:
  51. if isinstance(entry, (Mapping, list)) and not _is_terminal(entry):
  52. return False
  53. if keep_traversing is not None and keep_traversing(entry):
  54. return False
  55. return True
  56. def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
  57. if _is_terminal(value):
  58. visitor(path, value)
  59. elif isinstance(value, Mapping):
  60. for k, v in value.items():
  61. _traverse_obj(path + (str(k),), v)
  62. elif isinstance(value, list):
  63. for i, v in enumerate(value):
  64. _traverse_obj(path + (i,), v)
  65. for key, value in state_dict.items():
  66. _traverse_obj((str(key),), value)
  67. def set_element(
  68. root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM
  69. ) -> None:
  70. """
  71. Set ``value`` in ``root_dict`` along the ``path`` object path.
  72. """
  73. cur_container = cast(CONTAINER_TYPE, root_dict)
  74. def extend_list(lst: List[STATE_DICT_ITEM], idx: int) -> None:
  75. while len(lst) <= idx:
  76. lst.append(None)
  77. for i in range(1, len(path)):
  78. prev_key = path[i - 1]
  79. key = path[i]
  80. def_val = cast(STATE_DICT_ITEM, {} if type(key) == str else [])
  81. if isinstance(cur_container, Mapping):
  82. cur_container = cast(
  83. CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
  84. )
  85. else:
  86. extend_list(cur_container, prev_key)
  87. if cur_container[prev_key] is None:
  88. cur_container[prev_key] = def_val
  89. cur_container = cur_container[prev_key]
  90. key = path[-1]
  91. if type(key) == int:
  92. extend_list(cast(List[STATE_DICT_ITEM], cur_container), key)
  93. cur_container[key] = value
  94. def get_element(
  95. root_dict: STATE_DICT_TYPE,
  96. path: OBJ_PATH,
  97. default_value: Optional[T] = None,
  98. ) -> Optional[T]:
  99. """
  100. Retrieve the value at ``path``from ``root_dict``, returning ``default_value`` if not found.
  101. """
  102. cur_value = cast(CONTAINER_TYPE, root_dict)
  103. for part in path:
  104. if type(part) is int:
  105. if not isinstance(cur_value, list) or len(cur_value) < part:
  106. return default_value
  107. elif not isinstance(cur_value, Mapping) or part not in cur_value:
  108. return default_value
  109. cur_value = cast(CONTAINER_TYPE, cur_value[part])
  110. return cast(Optional[T], cur_value)
  111. def _print_nested(
  112. value: STATE_DICT_ITEM,
  113. prefix: str = "",
  114. print_fun: Callable[[str], None] = print,
  115. ) -> None:
  116. if type(value) is ShardedTensor:
  117. print_fun(f"{prefix} ShardedTensor size: {value.size()}")
  118. for shard in value.local_shards():
  119. _print_nested(
  120. shard.tensor,
  121. f"{shard.metadata.shard_offsets} ",
  122. print_fun=print_fun,
  123. )
  124. elif type(value) is (DTensor):
  125. print_fun(f"{prefix} DistributedTensor size: {value.size()}")
  126. # TODO: add local offset for _local_tensor in print_nested.
  127. _print_nested(
  128. value._local_tensor,
  129. print_fun=print_fun,
  130. )
  131. elif isinstance(value, torch.Tensor):
  132. print_fun(f"{prefix} Tensor size: {value.size()}")
  133. else:
  134. print_fun(f"{prefix} Type: {type(value)}")
  135. def print_tensor(
  136. path: OBJ_PATH,
  137. value: STATE_DICT_ITEM,
  138. print_fun: Callable[[str], None] = print,
  139. ) -> None:
  140. """
  141. Callback that can be used with travese_state_dict to print its content.
  142. By default the content is printed using the builtin ``print`` but this can
  143. be change by passing a different ``print_fun` callable.
  144. """
  145. _print_nested(value, prefix=str(path), print_fun=print_fun)