_fsdp_extensions.py

from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor


class FSDPExtensions(ABC):
    """
    This defines customizable hooks that enable composability with tensor
    parallelism. To activate these hooks, use :func:`_set_fsdp_extensions` to
    set a custom :class:`FSDPExtensions` that implements the hooks.
    """

    @abstractmethod
    def pre_flatten_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        """E.g. converting ``DistributedTensor`` to local tensor."""
        ...

    @abstractmethod
    def post_unflatten_transform(
        self,
        tensor: torch.Tensor,
        param_extension: Any,
    ) -> torch.Tensor:
        """E.g. converting local tensor to ``DistributedTensor``."""
        ...

    @abstractmethod
    def chunk_tensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        world_size: int,
        num_devices_per_node: int,
        pg: dist.ProcessGroup,
    ) -> torch.Tensor:
        """Shards a tensor to chunks and returns the local chunk."""
        ...

    @abstractmethod
    def pre_load_state_dict_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[Shard]]:
        """
        This is to be called before loading a *sharded* model state dict and
        should return the tensor and list of shards from which to load data.
        """
        ...


_extensions: Optional[FSDPExtensions] = None


def _set_fsdp_extensions(flattener: FSDPExtensions) -> None:
    global _extensions
    _extensions = flattener


def _ext_pre_flatten_transform(
    tensor: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[Any]]:
    # Use the extension's transform only if it produced extension metadata;
    # otherwise fall through to the unmodified tensor.
    if _extensions is not None:
        new_tensor, extension = _extensions.pre_flatten_transform(tensor)
        if extension is not None:
            return new_tensor, extension
    return tensor, None


def _ext_post_unflatten_transform(
    tensor: torch.Tensor,
    param_extension: Any,
) -> torch.Tensor:
    # Undo the pre-flatten transform only when extension metadata was recorded.
    if _extensions is not None and param_extension is not None:
        return _extensions.post_unflatten_transform(tensor, param_extension)
    return tensor


def _ext_chunk_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
) -> torch.Tensor:
    # Prefer the registered extension's chunking; otherwise fall back to
    # FSDP's default helper, ``_create_chunk_sharded_tensor``.
    chunk_tensor_fn = (
        _extensions.chunk_tensor
        if _extensions is not None
        else _create_chunk_sharded_tensor
    )
    return chunk_tensor_fn(
        tensor,
        rank,
        world_size,
        num_devices_per_node,
        pg,
    )


def _ext_pre_load_state_dict_transform(
    tensor: torch.Tensor,
) -> Tuple[torch.Tensor, List[Shard]]:
    if _extensions is not None:
        return _extensions.pre_load_state_dict_transform(tensor)
    # Default path: a sharded state dict stores ShardedTensors, so load from
    # their local shards.
    assert type(tensor) is ShardedTensor
    shards = tensor.local_shards()
    return (tensor, shards)
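

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream module): a minimal pass-through
# ``FSDPExtensions`` subclass that mirrors FSDP's default behavior, to show the
# shape of a concrete implementation and how it would be registered with
# ``_set_fsdp_extensions``. A real tensor-parallelism integration would convert
# between distributed and local tensors in these hooks instead. The class name
# ``_PassThroughExtensions`` is hypothetical.
# ---------------------------------------------------------------------------
class _PassThroughExtensions(FSDPExtensions):
    def pre_flatten_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        # No transformation and no per-parameter extension metadata.
        return tensor, None

    def post_unflatten_transform(
        self,
        tensor: torch.Tensor,
        param_extension: Any,
    ) -> torch.Tensor:
        # Nothing to undo for the pass-through case.
        return tensor

    def chunk_tensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        world_size: int,
        num_devices_per_node: int,
        pg: dist.ProcessGroup,
    ) -> torch.Tensor:
        # Delegate to the same default helper used by ``_ext_chunk_tensor``.
        return _create_chunk_sharded_tensor(
            tensor,
            rank,
            world_size,
            num_devices_per_node,
            pg,
        )

    def pre_load_state_dict_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[Shard]]:
        # Mirror the default path: expect a ShardedTensor and return its local
        # shards as the data sources for loading.
        assert type(tensor) is ShardedTensor
        return tensor, tensor.local_shards()


# Registration would happen once during setup, e.g.:
#
#     _set_fsdp_extensions(_PassThroughExtensions())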