from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor

class FSDPExtensions(ABC):
    """
    This exposes customizable hooks that enable composability with tensor
    parallelism. To activate these hooks, use :func:`_set_fsdp_extensions` to
    set a custom :class:`FSDPExtensions` that implements the hooks.
    """

    @abstractmethod
    def pre_flatten_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        """E.g. converting ``DistributedTensor`` to local tensor."""
        ...

    @abstractmethod
    def post_unflatten_transform(
        self,
        tensor: torch.Tensor,
        param_extension: Any,
    ) -> torch.Tensor:
        """E.g. converting local tensor to ``DistributedTensor``."""
        ...

    @abstractmethod
    def chunk_tensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        world_size: int,
        num_devices_per_node: int,
        pg: dist.ProcessGroup,
    ) -> torch.Tensor:
        """Shards a tensor into chunks and returns the local chunk."""
        ...

    @abstractmethod
    def pre_load_state_dict_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[Shard]]:
        """
        This is to be called before loading a *sharded* model state dict and
        should return the tensor and list of shards from which to load data.
        """
        ...

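# Illustration only -- a minimal sketch, not part of the upstream module.
# ``_ExampleNoOpExtensions`` is a hypothetical subclass whose hooks simply
# reproduce the default (no-extension) behavior of the ``_ext_*`` helpers
# below; a real tensor-parallelism integration would instead convert between
# its distributed tensor type and plain local tensors in these hooks.
class _ExampleNoOpExtensions(FSDPExtensions):
    def pre_flatten_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        # Returning ``None`` as the extension signals that no post-unflatten
        # transform is needed for this parameter.
        return tensor, None

    def post_unflatten_transform(
        self,
        tensor: torch.Tensor,
        param_extension: Any,
    ) -> torch.Tensor:
        return tensor

    def chunk_tensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        world_size: int,
        num_devices_per_node: int,
        pg: dist.ProcessGroup,
    ) -> torch.Tensor:
        # Defer to the same helper used when no extensions are registered.
        return _create_chunk_sharded_tensor(
            tensor, rank, world_size, num_devices_per_node, pg
        )

    def pre_load_state_dict_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[Shard]]:
        assert type(tensor) is ShardedTensor
        return tensor, tensor.local_shards()
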
_extensions: Optional[FSDPExtensions] = None


def _set_fsdp_extensions(flattener: FSDPExtensions) -> None:
    global _extensions
    _extensions = flattener

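# Sketch of activation (an assumption about the call site, not taken from
# this file): an integration would register its extensions once, before
# wrapping the model with FSDP, e.g.
#
#     _set_fsdp_extensions(_ExampleNoOpExtensions())
#
# after which every ``_ext_*`` helper below dispatches to the registered
# instance instead of the default behavior.
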
def _ext_pre_flatten_transform(
    tensor: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[Any]]:
    if _extensions is not None:
        new_tensor, extension = _extensions.pre_flatten_transform(tensor)
        # Only use the transformed tensor when the hook actually produced an
        # extension; otherwise fall through to the unmodified tensor.
        if extension is not None:
            return new_tensor, extension
    return tensor, None

def _ext_post_unflatten_transform(
    tensor: torch.Tensor,
    param_extension: Any,
) -> torch.Tensor:
    if _extensions is not None and param_extension is not None:
        return _extensions.post_unflatten_transform(tensor, param_extension)
    return tensor

def _ext_chunk_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
) -> torch.Tensor:
    # Dispatch to the registered extensions' ``chunk_tensor`` if set;
    # otherwise fall back to the default ``ShardedTensor`` chunking helper.
    chunk_tensor_fn = (
        _extensions.chunk_tensor
        if _extensions is not None
        else _create_chunk_sharded_tensor
    )
    return chunk_tensor_fn(
        tensor,
        rank,
        world_size,
        num_devices_per_node,
        pg,
    )

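# Illustrative call (hypothetical caller, not part of this module): chunking
# a rank-local full parameter into this rank's shard, e.g. when producing a
# sharded state dict. ``full_param`` and ``pg`` are placeholders.
#
#     sharded = _ext_chunk_tensor(
#         full_param,
#         rank=dist.get_rank(pg),
#         world_size=dist.get_world_size(pg),
#         num_devices_per_node=torch.cuda.device_count(),
#         pg=pg,
#     )
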
def _ext_pre_load_state_dict_transform(
    tensor: torch.Tensor,
) -> Tuple[torch.Tensor, List[Shard]]:
    if _extensions is not None:
        return _extensions.pre_load_state_dict_transform(tensor)
    # Default path: without extensions, sharded state dict values are
    # expected to be ``ShardedTensor``s, and data is loaded from their
    # local shards.
    assert type(tensor) is ShardedTensor
    shards = tensor.local_shards()
    return (tensor, shards)

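# Illustrative use (an assumption about the caller, not from this file): when
# loading a sharded model state dict, the caller obtains the target tensor and
# the local shards whose ``.tensor`` payloads receive the loaded data.
# ``state_dict`` and ``fqn`` below are placeholders.
#
#     tensor, shards = _ext_pre_load_state_dict_transform(state_dict[fqn])
#     local_chunks = [shard.tensor for shard in shards]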