# utils.py

from typing import (
    List,
    Callable,
    Optional,
    Union,
    TypeVar,
    Dict,
    Any,
    cast,
    Sequence,
)

import torch.distributed as dist

from .api import (
    CheckpointException,
    _wrap_exception,
    _is_wrapped_exception,
    WRAPPED_EXCEPTION,
)

import torch

from torch.distributed._shard.sharded_tensor import (
    ShardedTensor,
)
from torch.distributed._shard.sharded_tensor.shard import Shard

from .metadata import (
    STATE_DICT_TYPE,
    MetadataIndex,
)

__all__ = ["find_tensor_shard", "find_state_dict_object"]

T = TypeVar("T")
R = TypeVar("R")


def _get_failure_dict(
    results: List[Union[T, WRAPPED_EXCEPTION]]
) -> Dict[int, WRAPPED_EXCEPTION]:
    return cast(
        Dict[int, WRAPPED_EXCEPTION],
        {i: err for i, err in enumerate(results) if _is_wrapped_exception(err)},
    )


class _DistWrapper:
    """
    This is a wrapper around a ProcessGroup that provides a series of features around object collectives.

    It works without distributed initialized, where most collectives turn into no-ops.

    All variants that take functions are exception robust, meaning that if one or more
    ranks raise errors, all ranks will observe those.
    """

    def __init__(
        self,
        group: Optional[dist.ProcessGroup],
        use_dist: bool,
        coordinator_rank: int,
    ):
        self.group = group
        self.use_dist = use_dist
        self.coordinator_rank = coordinator_rank
        if self.use_dist:
            self.rank = dist.get_rank(group)
            self.is_coordinator = self.rank == coordinator_rank
        else:
            self.rank = 0
            self.is_coordinator = True
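
    # A minimal usage sketch (not from this file, assumptions noted): wrap the
    # default process group when torch.distributed is initialized, or pass
    # use_dist=False to run the same code path on a single process.
    #
    #   dw = _DistWrapper(group=None, use_dist=dist.is_initialized(), coordinator_rank=0)
    #   if dw.is_coordinator:
    #       ...  # coordinator-only work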

    def get_rank(self) -> int:
        return self.rank

    def get_world_size(self) -> int:
        if self.use_dist:
            return dist.get_world_size(self.group)
        return 1

    def broadcast_object(self, object: Optional[T]) -> T:
        """
        Same as c10d::broadcast_object_list but works without distributed enabled.
        """
        object_list = [object]
        if self.use_dist:
            dist.broadcast_object_list(
                object_list=object_list,
                group=self.group,
                src=self.coordinator_rank,
            )
        return cast(T, object_list[0])
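
    # Hedged example (dw is a _DistWrapper as in the sketch above): the
    # coordinator supplies a value, every rank receives it, and
    # non-coordinator ranks may pass None.
    #
    #   value = dw.broadcast_object("v1" if dw.is_coordinator else None)
    #   # on every rank: value == "v1"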

    def gather_object(self, object: T) -> Optional[List[T]]:
        """
        Same as c10d::gather_object but works without distributed enabled.
        """
        if self.use_dist:
            gather_objs = (
                cast(List[T], [None] * dist.get_world_size(self.group))
                if self.is_coordinator
                else None
            )
            dist.gather_object(
                obj=object,
                object_gather_list=gather_objs if self.is_coordinator else None,
                dst=self.coordinator_rank,
                group=self.group,
            )
            result = gather_objs
        else:
            result = [object]
        return result
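
    # Hedged example: only the coordinator receives the gathered list,
    # ordered by rank; every other rank gets None back.
    #
    #   gathered = dw.gather_object(dw.get_rank())
    #   # coordinator: [0, 1, ..., world_size - 1]; other ranks: None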

    def all_gather_object(self, object: T) -> List[T]:
        """
        Same as c10d::all_gather_object but works without distributed enabled.
        """
        if self.use_dist:
            gather_objs = cast(
                List[T], [None] * dist.get_world_size(self.group)
            )
            dist.all_gather_object(
                object_list=gather_objs, obj=object, group=self.group
            )
        else:
            gather_objs = [object]
        return gather_objs
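
    # Hedged example: every rank receives the same rank-ordered list.
    #
    #   everyone = dw.all_gather_object(dw.get_rank())
    #   # on every rank: [0, 1, ..., world_size - 1]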

    def scatter_object(self, object_list: Optional[List[T]]) -> T:
        """
        Same as c10d::scatter_object but works without distributed enabled.
        """
        if self.use_dist:
            gather_result = cast(List[T], [None])
            dist.scatter_object_list(
                scatter_object_output_list=gather_result,
                scatter_object_input_list=object_list
                if self.is_coordinator
                else None,
                src=self.coordinator_rank,
                group=self.group,
            )
            local_reply = gather_result[0]
        else:
            assert object_list is not None
            local_reply = object_list[0]
        return local_reply
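
    # Hedged example: the coordinator supplies one entry per rank and each
    # rank receives the entry at its own index; other ranks pass None.
    #
    #   per_rank = list(range(dw.get_world_size())) if dw.is_coordinator else None
    #   mine = dw.scatter_object(per_rank)
    #   # on rank i: mine == i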

    def reduce_scatter(
        self,
        step: str,
        map_fun: Callable[[], T],
        reduce_fun: Callable[[List[T]], List[R]],
    ) -> R:
        """
        Compute a value on each rank, then do a centralized reduce on a single rank, followed by a scatter.

        This method operates in the following way:
            Run ``map_fun`` on all ranks
            Gather results on rank 0
            Call ``reduce_fun`` on all those values
            Scatter to each rank part of the result.
        """
        local_data: Union[WRAPPED_EXCEPTION, T]
        try:
            local_data = map_fun()
        except BaseException as e:
            local_data = _wrap_exception(e)

        all_data = self.gather_object(local_data)
        all_results: Optional[List[Union[R, CheckpointException]]] = None
        if self.is_coordinator:
            assert all_data is not None
            node_failures = _get_failure_dict(all_data)

            if len(node_failures) == 0:
                try:
                    # N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]?
                    all_results = cast(
                        List[Union[R, CheckpointException]],
                        reduce_fun(cast(List[T], all_data)),
                    )
                except BaseException as e:
                    node_failures[self.rank] = _wrap_exception(e)

            if len(node_failures) > 0:
                all_results = [
                    CheckpointException(step, node_failures)
                ] * self.get_world_size()

        result = self.scatter_object(all_results)
        if isinstance(result, CheckpointException):
            raise result
        return result
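
    # Hedged sketch of the map/reduce/scatter pattern; ``local_plan`` and
    # ``make_global_plans`` are hypothetical callables, not part of this module:
    #
    #   my_part = dw.reduce_scatter(
    #       "plan",
    #       map_fun=local_plan,            # runs on every rank
    #       reduce_fun=make_global_plans,  # runs on the coordinator; must return one item per rank
    #   )
    #   # if any rank failed in map_fun, or the coordinator failed in reduce_fun,
    #   # every rank raises CheckpointException("plan", {rank: failure, ...})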

    def all_reduce(
        self,
        step: str,
        map_fun: Callable[[], T],
        reduce_fun: Callable[[List[T]], R],
    ) -> R:
        """
        Compute a value on each rank, then do a centralized reduce on a single rank, followed by a broadcast.

        This method operates in the following way:
            Run ``map_fun`` on all ranks
            Gather results on rank 0
            Call ``reduce_fun`` on all those values
            Broadcast the reduced value to all ranks.
        """
        local_data: Union[T, WRAPPED_EXCEPTION]
        try:
            local_data = map_fun()
        except BaseException as e:
            local_data = _wrap_exception(e)

        all_data = self.gather_object(local_data)
        result: Optional[Union[R, CheckpointException]] = None
        if self.is_coordinator:
            assert all_data is not None
            node_failures = _get_failure_dict(all_data)
            if len(node_failures) == 0:
                try:
                    result = reduce_fun(cast(List[T], all_data))
                except BaseException as e:
                    node_failures[self.rank] = _wrap_exception(e)

            if len(node_failures) > 0:
                result = CheckpointException(step, node_failures)

        final_result = self.broadcast_object(result)
        if isinstance(final_result, CheckpointException):
            raise final_result
        return cast(R, final_result)
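
    # Hedged example: each rank contributes a value, the coordinator reduces,
    # and the single reduced value is broadcast back to every rank.
    #
    #   total = dw.all_reduce("count", map_fun=lambda: 1, reduce_fun=sum)
    #   # on every rank: total == world_size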

    def all_gather(
        self,
        step: str,
        map_fun: Callable[[], T],
    ) -> List[T]:
        """
        Compute a value on each rank, then all_gather them.

        This method operates in the following way:
            Run ``map_fun`` on all ranks
            all_gather the values to all ranks
        """
        result: Union[T, WRAPPED_EXCEPTION]
        try:
            result = map_fun()
        except BaseException as e:
            result = _wrap_exception(e)

        all_results = self.all_gather_object(result)

        node_failures = _get_failure_dict(all_results)
        if len(node_failures) > 0:
            raise CheckpointException(step, node_failures)
        return cast(List[T], all_results)
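
    # Hedged example: like all_gather_object, except a map_fun failure on any
    # rank surfaces as a CheckpointException on every rank.
    #
    #   ranks = dw.all_gather("collect", map_fun=dw.get_rank)
    #   # on every rank: [0, 1, ..., world_size - 1]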

    def broadcast(
        self,
        step: str,
        map_fun: Callable[[], T],
    ) -> T:
        """
        Compute a value on rank 0 and broadcast it.

        This method operates in the following way:
            Run ``map_fun`` on rank 0
            broadcast the value
        """
        result: Optional[Union[T, CheckpointException]] = None
        if self.is_coordinator:
            try:
                result = map_fun()
            except BaseException as e:
                result = CheckpointException(
                    step, {self.rank: _wrap_exception(e)}
                )
        final_result = self.broadcast_object(result)
        if isinstance(final_result, CheckpointException):
            raise final_result
        return cast(T, final_result)
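
# Hedged example for _DistWrapper.broadcast: only the coordinator runs map_fun
# and its result (or its wrapped failure) is sent to every rank. Assumes
# ``import time`` and a wrapper instance ``dw`` as sketched above.
#
#   ts = dw.broadcast("timestamp", map_fun=time.time)
#   # every rank observes the same float; a coordinator failure raises everywhere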


def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard:
    if index.offset is None:
        raise ValueError(
            f"Cannot lookup {index.fqn} since it is a ShardedTensor and no offset was provided"
        )

    shards = tensor.local_shards()
    # index fast path
    if index.index is not None:
        if (
            len(shards) > index.index
            and torch.Size(shards[index.index].metadata.shard_offsets)
            == index.offset
        ):
            return shards[index.index]

    for shard in shards:
        if torch.Size(shard.metadata.shard_offsets) == index.offset:
            return shard
    raise ValueError(
        f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'"
    )


def find_tensor_shard(
    tensor: torch.Tensor, index: MetadataIndex
) -> torch.Tensor:
    if isinstance(tensor, ShardedTensor):
        return _find_shard(tensor, index).tensor
    if index.offset is not None:
        # special case looking up a tensor by origin
        if index.offset == torch.Size([0] * len(tensor.size())):
            return tensor
        raise ValueError(
            f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'"
        )
    return tensor
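
# Hedged example: for a plain (non-sharded) tensor, only the all-zeros offset
# resolves, since the whole tensor starts at the origin. Constructing
# MetadataIndex positionally from an fqn and an offset is an assumption here.
#
#   t = torch.zeros(4, 4)
#   find_tensor_shard(t, MetadataIndex("w", torch.Size([0, 0])))  # returns t
#   find_tensor_shard(t, MetadataIndex("w", torch.Size([2, 0])))  # raises ValueError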


def find_state_dict_object(
    state_dict: STATE_DICT_TYPE, index: MetadataIndex
) -> Any:
    if index.fqn not in state_dict:
        raise ValueError(f"Could not find FQN: '{index.fqn}'")
    obj = state_dict[index.fqn]
    if isinstance(obj, torch.Tensor):
        return find_tensor_shard(obj, index)
    elif index.offset is not None:
        raise ValueError(
            f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'"
        )
    return obj
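
# Hedged example: tensors are resolved through find_tensor_shard; any other
# object (e.g. a step counter) is returned as-is as long as no offset was
# requested. The MetadataIndex construction below is an assumption.
#
#   sd = {"w": torch.zeros(2), "epoch": 3}
#   find_state_dict_object(sd, MetadataIndex("epoch"))    # -> 3
#   find_state_dict_object(sd, MetadataIndex("missing"))  # raises ValueError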


def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> List[int]:
    return [i_a + i_b for i_a, i_b in zip(a, b)]


def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> List[int]:
    return [i_a - i_b for i_a, i_b in zip(a, b)]
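
# Hedged note: plain per-dimension arithmetic, e.g. for shifting a shard's
# offsets by a chunk's starting coordinates (illustrative, not from this file):
#
#   _element_wise_add([0, 16], [8, 0])   # -> [8, 16]
#   _element_wise_sub([8, 16], [0, 16])  # -> [8, 0]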