from typing import (
    List,
    Callable,
    Optional,
    Union,
    TypeVar,
    Dict,
    Any,
    cast,
    Sequence,
)

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard

from .api import (
    CheckpointException,
    _wrap_exception,
    _is_wrapped_exception,
    WRAPPED_EXCEPTION,
)
from .metadata import (
    STATE_DICT_TYPE,
    MetadataIndex,
)

__all__ = ["find_tensor_shard", "find_state_dict_object"]

T = TypeVar("T")
R = TypeVar("R")


def _get_failure_dict(
    results: List[Union[T, WRAPPED_EXCEPTION]]
) -> Dict[int, WRAPPED_EXCEPTION]:
    return cast(
        Dict[int, WRAPPED_EXCEPTION],
        {i: err for i, err in enumerate(results) if _is_wrapped_exception(err)},
    )


class _DistWrapper:
    """
    This is a wrapper around a process group that provides a series of features
    around object collectives.

    It also works when torch.distributed is not initialized, in which case most
    collectives turn into no-ops.

    All variants that take functions are exception-robust, meaning that if one
    or more ranks raise errors, all ranks will observe those errors.
    """

    def __init__(
        self,
        group: Optional[dist.ProcessGroup],
        use_dist: bool,
        coordinator_rank: int,
    ):
        self.group = group
        self.use_dist = use_dist
        self.coordinator_rank = coordinator_rank
        if self.use_dist:
            self.rank = dist.get_rank(group)
            self.is_coordinator = self.rank == coordinator_rank
        else:
            self.rank = 0
            self.is_coordinator = True
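
    # Illustrative usage sketch: construct the wrapper once per checkpoint
    # operation; when torch.distributed is not initialized it degrades to a
    # single-rank mode where collectives become local no-ops.
    #
    #   dist_wrapper = _DistWrapper(
    #       group=None,  # default process group, if one exists
    #       use_dist=dist.is_available() and dist.is_initialized(),
    #       coordinator_rank=0,
    #   )
    #   print(dist_wrapper.get_rank(), dist_wrapper.get_world_size())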

    def get_rank(self) -> int:
        return self.rank

    def get_world_size(self) -> int:
        if self.use_dist:
            return dist.get_world_size(self.group)
        return 1

    def broadcast_object(self, object: Optional[T]) -> T:
        """
        Same as c10d::broadcast_object_list but works without distributed enabled.
        """
        object_list = [object]
        if self.use_dist:
            dist.broadcast_object_list(
                object_list=object_list,
                group=self.group,
                src=self.coordinator_rank,
            )
        return cast(T, object_list[0])

    def gather_object(self, object: T) -> Optional[List[T]]:
        """
        Same as c10d::gather_object but works without distributed enabled.
        """
        if self.use_dist:
            gather_objs = (
                cast(List[T], [None] * dist.get_world_size(self.group))
                if self.is_coordinator
                else None
            )
            dist.gather_object(
                obj=object,
                object_gather_list=gather_objs if self.is_coordinator else None,
                dst=self.coordinator_rank,
                group=self.group,
            )
            result = gather_objs
        else:
            result = [object]
        return result

    def all_gather_object(self, object: T) -> List[T]:
        """
        Same as c10d::all_gather_object but works without distributed enabled.
        """
        if self.use_dist:
            gather_objs = cast(
                List[T], [None] * dist.get_world_size(self.group)
            )
            dist.all_gather_object(
                object_list=gather_objs, obj=object, group=self.group
            )
        else:
            gather_objs = [object]
        return gather_objs

    def scatter_object(self, object_list: Optional[List[T]]) -> T:
        """
        Same as c10d::scatter_object but works without distributed enabled.
        """
        if self.use_dist:
            gather_result = cast(List[T], [None])
            dist.scatter_object_list(
                scatter_object_output_list=gather_result,
                scatter_object_input_list=object_list
                if self.is_coordinator
                else None,
                src=self.coordinator_rank,
                group=self.group,
            )
            local_reply = gather_result[0]
        else:
            assert object_list is not None
            local_reply = object_list[0]
        return local_reply
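
    # Illustrative sketch: the object collectives above mirror their c10d
    # counterparts but degrade to local no-ops when ``use_dist`` is False.
    #
    #   ranks = dist_wrapper.all_gather_object(dist_wrapper.get_rank())
    #   gathered = dist_wrapper.gather_object("ping")  # list on the coordinator, None elsewhere
    #   reply = dist_wrapper.scatter_object(
    #       ["pong"] * dist_wrapper.get_world_size()
    #       if dist_wrapper.is_coordinator
    #       else None
    #   )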

    def reduce_scatter(
        self,
        step: str,
        map_fun: Callable[[], T],
        reduce_fun: Callable[[List[T]], List[R]],
    ) -> R:
        """
        Compute a value on each rank, then do a centralized reduce on a single
        rank, followed by a scatter.

        This method operates in the following way:
            Run ``map_fun`` on all ranks
            Gather results on rank 0
            Call ``reduce_fun`` on all those values
            Scatter part of the result to each rank
        """
        local_data: Union[WRAPPED_EXCEPTION, T]
        try:
            local_data = map_fun()
        except BaseException as e:
            local_data = _wrap_exception(e)

        all_data = self.gather_object(local_data)
        all_results: Optional[List[Union[R, CheckpointException]]] = None
        if self.is_coordinator:
            assert all_data is not None
            node_failures = _get_failure_dict(all_data)

            if len(node_failures) == 0:
                try:
                    # N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]?
                    all_results = cast(
                        List[Union[R, CheckpointException]],
                        reduce_fun(cast(List[T], all_data)),
                    )
                except BaseException as e:
                    node_failures[self.rank] = _wrap_exception(e)

            if len(node_failures) > 0:
                all_results = [
                    CheckpointException(step, node_failures)
                ] * self.get_world_size()

        result = self.scatter_object(all_results)
        if isinstance(result, CheckpointException):
            raise result
        return result
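
    # Illustrative sketch of a typical caller (``create_local_plan`` and
    # ``create_global_plan`` are hypothetical placeholders): every rank proposes
    # a local value, rank 0 reduces the gathered list into one entry per rank,
    # and each rank receives its own entry back.
    #
    #   local_plan = dist_wrapper.reduce_scatter(
    #       "plan",
    #       map_fun=create_local_plan,      # runs on every rank
    #       reduce_fun=create_global_plan,  # runs on rank 0; returns a per-rank list
    #   )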

    def all_reduce(
        self,
        step: str,
        map_fun: Callable[[], T],
        reduce_fun: Callable[[List[T]], R],
    ) -> R:
        """
        Compute a value on each rank, then do a centralized reduce on a single
        rank, followed by a broadcast.

        This method operates in the following way:
            Run ``map_fun`` on all ranks
            Gather results on rank 0
            Call ``reduce_fun`` on all those values
            Broadcast the reduced value to all ranks
        """
        local_data: Union[T, WRAPPED_EXCEPTION]
        try:
            local_data = map_fun()
        except BaseException as e:
            local_data = _wrap_exception(e)

        all_data = self.gather_object(local_data)
        result: Optional[Union[R, CheckpointException]] = None
        if self.is_coordinator:
            assert all_data is not None
            node_failures = _get_failure_dict(all_data)
            if len(node_failures) == 0:
                try:
                    result = reduce_fun(cast(List[T], all_data))
                except BaseException as e:
                    node_failures[self.rank] = _wrap_exception(e)

            if len(node_failures) > 0:
                result = CheckpointException(step, node_failures)

        final_result = self.broadcast_object(result)
        if isinstance(final_result, CheckpointException):
            raise final_result
        return cast(R, final_result)
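
    # Illustrative sketch (``write_local_data`` and ``finish_checkpoint`` are
    # hypothetical placeholders): every rank reports a value, rank 0 reduces the
    # gathered list to a single result, and that result is broadcast to all ranks.
    #
    #   metadata = dist_wrapper.all_reduce(
    #       "write",
    #       map_fun=write_local_data,      # runs on every rank
    #       reduce_fun=finish_checkpoint,  # runs on rank 0 only
    #   )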

    def all_gather(
        self,
        step: str,
        map_fun: Callable[[], T],
    ) -> List[T]:
        """
        Compute a value on each rank, then all_gather them.

        This method operates in the following way:
            Run ``map_fun`` on all ranks
            all_gather the values to all ranks
        """
        result: Union[T, WRAPPED_EXCEPTION]
        try:
            result = map_fun()
        except BaseException as e:
            result = _wrap_exception(e)

        all_results = self.all_gather_object(result)

        node_failures = _get_failure_dict(all_results)
        if len(node_failures) > 0:
            raise CheckpointException(step, node_failures)
        return cast(List[T], all_results)
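
    # Illustrative sketch: every rank computes a value and all ranks receive the
    # full list; if any rank raised, all ranks raise the same CheckpointException.
    #
    #   world_ranks = dist_wrapper.all_gather("ranks", dist_wrapper.get_rank)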

    def broadcast(
        self,
        step: str,
        map_fun: Callable[[], T],
    ) -> T:
        """
        Compute a value on rank 0 and broadcast it.

        This method operates in the following way:
            Run ``map_fun`` on rank 0
            Broadcast the value to all ranks
        """
        result: Optional[Union[T, CheckpointException]] = None
        if self.is_coordinator:
            try:
                result = map_fun()
            except BaseException as e:
                result = CheckpointException(
                    step, {self.rank: _wrap_exception(e)}
                )
        final_result = self.broadcast_object(result)
        if isinstance(final_result, CheckpointException):
            raise final_result
        return cast(T, final_result)
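
    # Illustrative sketch (``read_metadata_from_storage`` is a hypothetical
    # placeholder): only rank 0 runs the function; every rank receives the same
    # value, or the same CheckpointException if rank 0 failed.
    #
    #   metadata = dist_wrapper.broadcast("read_metadata", read_metadata_from_storage)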


def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard:
    if index.offset is None:
        raise ValueError(
            f"Cannot look up {index.fqn} since it's a ShardedTensor and no offset was provided"
        )

    shards = tensor.local_shards()
    # index fast path
    if index.index is not None:
        if (
            len(shards) > index.index
            and torch.Size(shards[index.index].metadata.shard_offsets)
            == index.offset
        ):
            return shards[index.index]

    for shard in shards:
        if torch.Size(shard.metadata.shard_offsets) == index.offset:
            return shard
    raise ValueError(
        f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'"
    )


def find_tensor_shard(
    tensor: torch.Tensor, index: MetadataIndex
) -> torch.Tensor:
    if isinstance(tensor, ShardedTensor):
        return _find_shard(tensor, index).tensor
    if index.offset is not None:
        # special case looking up a tensor by origin
        if index.offset == torch.Size([0] * len(tensor.size())):
            return tensor
        raise ValueError(
            f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'"
        )
    return tensor


def find_state_dict_object(
    state_dict: STATE_DICT_TYPE, index: MetadataIndex
) -> Any:
    if index.fqn not in state_dict:
        raise ValueError(f"Could not find FQN: '{index.fqn}'")
    obj = state_dict[index.fqn]
    if isinstance(obj, torch.Tensor):
        return find_tensor_shard(obj, index)
    elif index.offset is not None:
        raise ValueError(
            f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'"
        )
    return obj
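

# Illustrative sketch: resolving a MetadataIndex against a state dict. For a
# regular tensor only the FQN matters; for a ShardedTensor the offset selects
# the matching local shard (exact MetadataIndex construction follows its
# definition in .metadata).
#
#   state_dict = {"weight": torch.zeros(4, 4)}
#   tensor = find_state_dict_object(state_dict, MetadataIndex("weight"))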


def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> List[int]:
    return [i_a + i_b for i_a, i_b in zip(a, b)]


def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> List[int]:
    return [i_a - i_b for i_a, i_b in zip(a, b)]
|