# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
from typing import (
    Callable,
    Collection,
    List,
    Mapping,
    MutableMapping,
    Optional,
    Tuple,
    TypeVar,
    Union,
    cast,
)

from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._tensor import DTensor

PATH_ITEM = Union[str, int]
OBJ_PATH = Tuple[PATH_ITEM, ...]
T = TypeVar("T")

STATE_DICT_ITEM = object
CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM]

__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"]


def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool:
    return isinstance(value, torch.Tensor)
# TODO: update docstring for traverse.py
def traverse_state_dict(
    state_dict: STATE_DICT_TYPE,
    visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
    keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
) -> None:
    """
    Invoke ``visitor`` for each value recursively in ``state_dict``.

    Traversal is short-circuited when it finds a collection for which
    ``keep_traversing`` evaluates to false for all elements.
    By default, all collections with at least one ``torch.Tensor`` element are traversed.
    The visitor takes a path argument that is a tuple of the keys used to reach it.
    """

    # a value is terminal if it contains no container values (mappings or lists)
    def _is_terminal(value: STATE_DICT_ITEM) -> bool:
        values: Collection[STATE_DICT_ITEM]
        if isinstance(value, Mapping):
            values = value.values()
        elif isinstance(value, list):
            values = value
        else:
            return True

        for entry in values:
            if isinstance(entry, (Mapping, list)) and not _is_terminal(entry):
                return False
            if keep_traversing is not None and keep_traversing(entry):
                return False
        return True

    def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
        if _is_terminal(value):
            visitor(path, value)
        elif isinstance(value, Mapping):
            for k, v in value.items():
                _traverse_obj(path + (str(k),), v)
        elif isinstance(value, list):
            for i, v in enumerate(value):
                _traverse_obj(path + (i,), v)

    for key, value in state_dict.items():
        _traverse_obj((str(key),), value)
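# Example (illustrative sketch, not part of the module): collect the path of
# every tensor in a small, hypothetical state_dict. Mapping keys become strings
# in the path, while list indices stay integers.
#
#     sd = {"model": {"layers": [torch.zeros(2), torch.ones(3)]}, "step": 7}
#     paths = []
#     traverse_state_dict(sd, lambda path, value: paths.append(path))
#     # the two tensors are visited individually; "step" is a terminal value:
#     # [("model", "layers", 0), ("model", "layers", 1), ("step",)]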
def set_element(
    root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM
) -> None:
    """
    Set ``value`` in ``root_dict`` along the ``path`` object path.
    """
    cur_container = cast(CONTAINER_TYPE, root_dict)

    def extend_list(lst: List[STATE_DICT_ITEM], idx: int) -> None:
        while len(lst) <= idx:
            lst.append(None)

    for i in range(1, len(path)):
        prev_key = path[i - 1]
        key = path[i]
        def_val = cast(STATE_DICT_ITEM, {} if type(key) == str else [])

        if isinstance(cur_container, Mapping):
            cur_container = cast(
                CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
            )
        else:
            extend_list(cur_container, prev_key)
            if cur_container[prev_key] is None:
                cur_container[prev_key] = def_val
            cur_container = cur_container[prev_key]

    key = path[-1]
    if type(key) == int:
        extend_list(cast(List[STATE_DICT_ITEM], cur_container), key)

    cur_container[key] = value
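# Example (illustrative sketch, not part of the module): build a nested
# structure by writing along an object path. String path items create dicts,
# integer path items create lists padded with ``None`` as needed.
#
#     out: STATE_DICT_TYPE = {}
#     set_element(out, ("model", "layers", 1), torch.ones(3))
#     # out == {"model": {"layers": [None, tensor([1., 1., 1.])]}}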
def get_element(
    root_dict: STATE_DICT_TYPE,
    path: OBJ_PATH,
    default_value: Optional[T] = None,
) -> Optional[T]:
    """
    Retrieve the value at ``path`` from ``root_dict``, returning ``default_value`` if not found.
    """
    cur_value = cast(CONTAINER_TYPE, root_dict)
    for part in path:
        if type(part) is int:
            if not isinstance(cur_value, list) or len(cur_value) <= part:
                return default_value
        elif not isinstance(cur_value, Mapping) or part not in cur_value:
            return default_value

        cur_value = cast(CONTAINER_TYPE, cur_value[part])
    return cast(Optional[T], cur_value)
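# Example (illustrative sketch, not part of the module): read a value back
# along a path, falling back to a default when any step is missing.
#
#     sd = {"model": {"layers": [torch.zeros(2)]}}
#     get_element(sd, ("model", "layers", 0))    # -> tensor([0., 0.])
#     get_element(sd, ("model", "missing"), -1)  # -> -1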
def _print_nested(
    value: STATE_DICT_ITEM,
    prefix: str = "",
    print_fun: Callable[[str], None] = print,
) -> None:
    if type(value) is ShardedTensor:
        print_fun(f"{prefix} ShardedTensor size: {value.size()}")
        for shard in value.local_shards():
            _print_nested(
                shard.tensor,
                f"{shard.metadata.shard_offsets} ",
                print_fun=print_fun,
            )
    elif type(value) is DTensor:
        print_fun(f"{prefix} DistributedTensor size: {value.size()}")
        # TODO: add local offset for _local_tensor in print_nested.
        _print_nested(
            value._local_tensor,
            print_fun=print_fun,
        )
    elif isinstance(value, torch.Tensor):
        print_fun(f"{prefix} Tensor size: {value.size()}")
    else:
        print_fun(f"{prefix} Type: {type(value)}")
def print_tensor(
    path: OBJ_PATH,
    value: STATE_DICT_ITEM,
    print_fun: Callable[[str], None] = print,
) -> None:
    """
    Callback that can be used with ``traverse_state_dict`` to print its content.

    By default the content is printed using the builtin ``print``, but this can
    be changed by passing a different ``print_fun`` callable.
    """
    _print_nested(value, prefix=str(path), print_fun=print_fun)
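# Example (illustrative sketch, not part of the module): dump the shape of
# every tensor in a state_dict by pairing the traversal with this callback.
#
#     sd = {"weight": torch.zeros(4, 2), "meta": {"epoch": 3}}
#     traverse_state_dict(sd, print_tensor)
#     # ('weight',) Tensor size: torch.Size([4, 2])
#     # ('meta',) Type: <class 'dict'>   (no tensors inside, so it is terminal)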