# Copyright (c) Meta Platforms, Inc. and affiliates
import copy
import dataclasses
from typing import Dict, List, Optional, Sequence, Tuple, Union, cast

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed._shard.api import _shard_tensor
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import (
    ChunkShardingSpec,
)
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadPlan
from torch.distributed.checkpoint.planner_helpers import (
    _create_read_items,
    _create_sharded_read_items,
)
from torch.distributed.checkpoint.utils import (
    _element_wise_add,
    _element_wise_sub,
)
from torch.distributed.remote_device import _remote_device

STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]]


# TODO: Update docstrings for optimizer.py
__all__ = [
    "load_sharded_optimizer_state_dict",
]


def _gen_rank_device(global_rank: int) -> str:
    if torch.cuda.is_available():
        return f"cuda:{global_rank % torch.cuda.device_count()}"
    return "cpu"


def _create_colwise_spec(
    pg: Optional[dist.ProcessGroup] = None,
) -> ChunkShardingSpec:
    if pg is None:
        placements = [
            f"rank:{idx}/{_gen_rank_device(idx)}"
            for idx in range(dist.get_world_size())
        ]
    else:
        placements = [
            f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx))}"
            for idx in range(pg.size())
        ]
    return ChunkShardingSpec(
        dim=0,
        placements=cast(List[Union[_remote_device, str]], placements),
    )
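
# Rough sketch of the resulting spec (assuming a 2-rank group with one GPU per
# rank): ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"]),
# i.e. tensors are chunked along dim 0, one contiguous chunk per rank.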


def _is_nested_tensor(val: torch.Tensor) -> bool:
    if type(val) is ShardedTensor:
        if len(val.local_shards()) == 0:
            return False
        if type(val.local_shards()[0].tensor) is ShardedTensor:
            return True
        if type(val.local_shards()[0].tensor) is DTensor:
            raise ValueError("Cannot handle DTensor nested inside ShardedTensor")
    elif type(val) is DTensor and (
        type(val._local_tensor) is DTensor or type(val._local_tensor) is ShardedTensor
    ):
        raise ValueError("Cannot handle nested DTensor")
    return False
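
# A 2D-parallel (FSDP + TP) parameter typically surfaces as a ShardedTensor over
# the data-parallel dimension whose single local shard wraps another
# ShardedTensor over the TP group; that nesting is what _is_nested_tensor
# detects. A plain FSDP parameter, whose local shard is a regular torch.Tensor,
# returns False.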


def _alloc_tensor(props: TensorProperties, size: Sequence[int]) -> torch.Tensor:
    return torch.empty(
        size=size,
        dtype=props.dtype,
        layout=props.layout,
        requires_grad=props.requires_grad,
        pin_memory=props.pin_memory,
        device=cast(torch.device, torch.cuda.current_device()),
    )
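
# Note: the buffer is always allocated on the current CUDA device, so this load
# path assumes a GPU is available on every rank; the tensor contents are
# uninitialized until the storage reader fills them in.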


def _get_state_dict_2d_layout(
    state_dict: STATE_DICT_TYPE,
) -> Tuple[STATE_DICT_2D_LAYOUT, Optional[dist.ProcessGroup]]:
    """
    Figure out the right TP slice of the optimizer state to load.

    This is not easy since the per-tensor slicing can't be inferred from
    checkpoint metadata. We take advantage of the model state_dict producing a
    sliced ST to figure out what we need to load. This is pretty fragile and it
    might be easier for FSDP to compute this info for us.

    Returns a dictionary whose keys are the same as those of ``state_dict`` and
    whose values are tuples of ``(offset, size)`` for the current rank's TP
    slice.

    N.B. The state_dict *MUST* come from FSDP.sharded_state_dict.
    """
    specs: STATE_DICT_2D_LAYOUT = {}
    dp_pg: Optional[dist.ProcessGroup] = None
    for key, value in state_dict.items():
        specs[key] = (None, value.size())
        if _is_nested_tensor(value):
            assert (
                len(value.local_shards()) == 1
            ), "Cannot handle ST with multiple shards"
            assert isinstance(
                value, ShardedTensor
            ), "Can only handle nested ShardedTensor"
            shard = value.local_shards()[0]
            specs[key] = (
                shard.metadata.shard_offsets,
                shard.metadata.shard_sizes,
            )
            dp_pg = shard.tensor._process_group  # type: ignore[attr-defined]

    return (
        specs,
        dp_pg,
    )
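
# Illustrative sketch (hypothetical shapes): for a TP-sharded weight "layer.w"
# whose local slice starts at offset [0, 512] with size [1024, 512], and an
# unsharded "layer.bias" of size [2048], the helper returns roughly
#   ({"layer.w": ([0, 512], [1024, 512]),
#     "layer.bias": (None, torch.Size([2048]))},
#    <data-parallel ProcessGroup>)
# where a None offset means the tensor is not TP-sliced and is loaded whole.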


class _ReaderWithOffset(DefaultLoadPlanner):
    """
    Load planner that shifts read offsets for TP-sliced optimizer state.

    ``fqn_to_offset`` maps a flattened key to the global offset of this rank's
    TP slice; read items are created against the checkpoint's global
    coordinates and translated back to local indices in ``lookup_tensor``.
    """

    translation: Dict[MetadataIndex, MetadataIndex]
    state_dict: STATE_DICT_TYPE
    metadata: Metadata

    def __init__(self, fqn_to_offset: Dict[str, Sequence[int]]) -> None:
        super().__init__()
        self.fqn_to_offset = fqn_to_offset
        self.metadata = Metadata({})
        self.state_dict = {}
        self.translation = {}

    def create_local_plan(self) -> LoadPlan:
        requests = []
        self.translation = {}
        for fqn, obj in self.state_dict.items():
            md = self.metadata.state_dict_metadata[fqn]
            if not isinstance(obj, ShardedTensor):
                requests += _create_read_items(fqn, md, obj)
                continue

            if fqn not in self.fqn_to_offset:
                requests += _create_read_items(fqn, md, obj)
                continue

            offset = self.fqn_to_offset[fqn]

            assert len(obj.local_shards()) == 1
            original_shard = obj.local_shards()[0]
            shard_md = copy.deepcopy(original_shard.metadata)
            shard_md.shard_offsets = _element_wise_add(
                shard_md.shard_offsets, offset
            )
            local_shards = [Shard(original_shard.tensor, shard_md)]

            reqs = _create_sharded_read_items(
                fqn, cast(TensorStorageMetadata, md), local_shards
            )
            # TODO: The ReadItems will have a displaced MetadataIndex, fix it.
            # TODO: we should change _create_sharded_read_items to have a more ergonomic API
            for wi in reqs:
                assert wi.dest_index.offset is not None
                original_offset = _element_wise_sub(wi.dest_index.offset, offset)
                original_index = dataclasses.replace(
                    wi.dest_index, offset=torch.Size(original_offset)
                )
                self.translation[wi.dest_index] = original_index

            requests += reqs
        return LoadPlan(requests)

    def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
        return super().lookup_tensor(self.translation.get(index, index))
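
# Worked example of the offset translation (hypothetical numbers): if this
# rank's TP slice of "layer.w" sits at global offset [512, 0] of the full
# parameter, read items are created against checkpoint coordinates
# (local shard offset + [512, 0]), while the destination tensor is indexed
# from [0, 0]. ``translation`` maps each displaced MetadataIndex back to the
# local one so lookup_tensor resolves the correct destination tensor.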


def load_sharded_optimizer_state_dict(
    model_state_dict: STATE_DICT_TYPE,
    optimizer_key: str,
    storage_reader: dist_cp.StorageReader,
) -> STATE_DICT_TYPE:
    """
    Load a state_dict to be used in conjunction with FSDP sharded optimizer state.

    This is the current recommended way to checkpoint FSDP.

    >>> # xdoctest: +SKIP
    >>> import torch.distributed.checkpoint as dist_cp
    >>> # Save
    >>> model: torch.nn.Module
    >>> optim_params = model.parameters()
    >>> optim = torch.optim.SGD(optim_params, lr=0.01)
    >>>
    >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    >>>     state_dict = {
    >>>         "optimizer": FSDP.sharded_optim_state_dict(model, optim, optim_params),
    >>>         "model": model.state_dict()
    >>>     }
    >>>     dist_cp.save_state_dict(
    >>>         state_dict=state_dict,
    >>>         storage_writer=dist_cp.FileSystemWriter("checkpoint"),
    >>>         planner=dist_cp.DefaultSavePlanner(),
    >>>     )
    >>>
    >>> # Load
    >>> with FSDP.state_dict_type(model_tp, StateDictType.SHARDED_STATE_DICT):
    >>>     model_state_dict = model_tp.state_dict()
    >>>     checkpoint = {
    >>>         "model": model_state_dict
    >>>     }
    >>>     dist_cp.load_state_dict(
    >>>         state_dict=checkpoint,
    >>>         storage_reader=dist_cp.FileSystemReader("checkpoint"),
    >>>         planner=dist_cp.DefaultLoadPlanner(),
    >>>     )
    >>>     model_tp.load_state_dict(checkpoint["model"])
    >>>
    >>>     optim_state = dist_cp.load_sharded_optimizer_state_dict(
    >>>         model_state_dict,
    >>>         optimizer_key="optimizer",
    >>>         storage_reader=dist_cp.FileSystemReader("checkpoint"),
    >>>     )
    >>>
    >>>     flattened_osd = FSDP.flatten_sharded_optim_state_dict(
    >>>         optim_state["optimizer"], model_tp, optim
    >>>     )
    >>>
    >>>     optim.load_state_dict(flattened_osd)
    """
    metadata = storage_reader.read_metadata()

    layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict)
    if dp_pg is None:
        sharding_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                f"rank:{i}/cuda:{i}" for i in range(dist.get_world_size())
            ],
        )
    else:
        sharding_spec = _create_colwise_spec(dp_pg)

    # Create a state_dict for optimizer state
    state_dict: STATE_DICT_TYPE = {}

    fqn_to_offset: Dict[str, Sequence[int]] = {}
    # planner_data maps each flattened key back to its original nested path;
    # for FSDP sharded optimizer state this is roughly
    # ("optimizer", "state", <param FQN>, <state name>), so key_path[0] selects
    # the optimizer subtree and key_path[2] is the parameter FQN.
    for key, value in metadata.state_dict_metadata.items():
        key_path = metadata.planner_data[key]
        if key_path[0] != optimizer_key:
            continue

        if isinstance(value, BytesStorageMetadata):
            state_dict[key] = "<bytes_io>"
            continue

        # value: TensorStorageMetadata
        if value.size.numel() == 1:
            state_dict[key] = _alloc_tensor(value.properties, value.size)
        elif dp_pg is None:
            state_dict[key] = _shard_tensor(
                _alloc_tensor(value.properties, value.size), sharding_spec
            )
        else:
            spec_key = key_path[2]
            alloc_size = layout_specs.get(spec_key, (None, value.size))[1]

            st_md = sharding_spec.build_metadata(
                torch.Size(alloc_size), value.properties
            )
            local_shards = []
            current_rank = dist.get_rank(dp_pg)
            for shard_md in st_md.shards_metadata:
                if cast(_remote_device, shard_md.placement).rank() != current_rank:
                    continue
                local_shards.append(
                    Shard(
                        tensor=_alloc_tensor(value.properties, shard_md.shard_sizes),
                        metadata=shard_md,
                    )
                )

            st = ShardedTensor._init_from_local_shards_and_global_metadata(
                local_shards, st_md, process_group=dp_pg
            )

            if spec_key in layout_specs and layout_specs[spec_key][0] is not None:
                fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0])

            state_dict[key] = st

    # Whether we unflatten before or after doesn't matter
    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=storage_reader,
        # FIXME the type of planner is wrong in load_state_dict
        planner=_ReaderWithOffset(fqn_to_offset) if dp_pg is not None else None,
    )

    state_dict = unflatten_state_dict(state_dict, metadata.planner_data)

    return state_dict