# Copyright (c) Meta Platforms, Inc. and affiliates import torch from typing import ( Callable, Collection, List, Mapping, MutableMapping, Optional, Tuple, TypeVar, Union, cast, ) from torch.distributed.checkpoint.metadata import ( STATE_DICT_TYPE, ) from torch.distributed._shard.sharded_tensor.api import ShardedTensor from torch.distributed._tensor import DTensor PATH_ITEM = Union[str, int] OBJ_PATH = Tuple[PATH_ITEM, ...] T = TypeVar("T") STATE_DICT_ITEM = object CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM] __all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"] def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool: return isinstance(value, torch.Tensor) # TODO: update docstring for traverse.py def traverse_state_dict( state_dict: STATE_DICT_TYPE, visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None], keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors, ) -> None: """ Invoke ``visitor`` for each value recursively in ``state_dict``. Traversal is short-circuited when if finds a collection for which ``keep_visiting_tensors`` evaluates to false for all elements. By default, all collections with at least one ``torch.Tensor`` element are traversed. Visitor takes a path argument that is a tuple of the keys used to reach it. """ # a value is terminal if it has no other containers values inside it def _is_terminal(value: STATE_DICT_ITEM) -> bool: values: Collection[STATE_DICT_ITEM] if isinstance(value, Mapping): values = value.values() elif isinstance(value, list): values = value else: return True for entry in values: if isinstance(entry, (Mapping, list)) and not _is_terminal(entry): return False if keep_traversing is not None and keep_traversing(entry): return False return True def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: if _is_terminal(value): visitor(path, value) elif isinstance(value, Mapping): for k, v in value.items(): _traverse_obj(path + (str(k),), v) elif isinstance(value, list): for i, v in enumerate(value): _traverse_obj(path + (i,), v) for key, value in state_dict.items(): _traverse_obj((str(key),), value) def set_element( root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM ) -> None: """ Set ``value`` in ``root_dict`` along the ``path`` object path. """ cur_container = cast(CONTAINER_TYPE, root_dict) def extend_list(lst: List[STATE_DICT_ITEM], idx: int) -> None: while len(lst) <= idx: lst.append(None) for i in range(1, len(path)): prev_key = path[i - 1] key = path[i] def_val = cast(STATE_DICT_ITEM, {} if type(key) == str else []) if isinstance(cur_container, Mapping): cur_container = cast( CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val) ) else: extend_list(cur_container, prev_key) if cur_container[prev_key] is None: cur_container[prev_key] = def_val cur_container = cur_container[prev_key] key = path[-1] if type(key) == int: extend_list(cast(List[STATE_DICT_ITEM], cur_container), key) cur_container[key] = value def get_element( root_dict: STATE_DICT_TYPE, path: OBJ_PATH, default_value: Optional[T] = None, ) -> Optional[T]: """ Retrieve the value at ``path``from ``root_dict``, returning ``default_value`` if not found. """ cur_value = cast(CONTAINER_TYPE, root_dict) for part in path: if type(part) is int: if not isinstance(cur_value, list) or len(cur_value) < part: return default_value elif not isinstance(cur_value, Mapping) or part not in cur_value: return default_value cur_value = cast(CONTAINER_TYPE, cur_value[part]) return cast(Optional[T], cur_value) def _print_nested( value: STATE_DICT_ITEM, prefix: str = "", print_fun: Callable[[str], None] = print, ) -> None: if type(value) is ShardedTensor: print_fun(f"{prefix} ShardedTensor size: {value.size()}") for shard in value.local_shards(): _print_nested( shard.tensor, f"{shard.metadata.shard_offsets} ", print_fun=print_fun, ) elif type(value) is (DTensor): print_fun(f"{prefix} DistributedTensor size: {value.size()}") # TODO: add local offset for _local_tensor in print_nested. _print_nested( value._local_tensor, print_fun=print_fun, ) elif isinstance(value, torch.Tensor): print_fun(f"{prefix} Tensor size: {value.size()}") else: print_fun(f"{prefix} Type: {type(value)}") def print_tensor( path: OBJ_PATH, value: STATE_DICT_ITEM, print_fun: Callable[[str], None] = print, ) -> None: """ Callback that can be used with travese_state_dict to print its content. By default the content is printed using the builtin ``print`` but this can be change by passing a different ``print_fun` callable. """ _print_nested(value, prefix=str(path), print_fun=print_fun)