common_fsdp.py 39 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084
  1. # Owner(s): ["oncall: distributed"]
  2. import itertools
  3. import sys
  4. from abc import ABC, abstractmethod
  5. from contextlib import suppress
  6. from copy import deepcopy
  7. from enum import auto, Enum
  8. from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
  9. from unittest import mock
  10. import torch
  11. import torch.distributed as dist
  12. import torch.nn as nn
  13. from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
  14. from torch.distributed.fsdp._common_utils import TrainingState
  15. from torch.distributed.fsdp.fully_sharded_data_parallel import (
  16. BackwardPrefetch,
  17. MixedPrecision,
  18. ShardingStrategy,
  19. )
  20. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  21. from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy, wrap
  22. from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
  23. from torch.nn.parallel.distributed import DistributedDataParallel as DDP
  24. from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS
  25. from torch.testing._internal.common_utils import FILE_SCHEMA, get_cycles_per_ms
  26. class FSDPInitMode(Enum):
  27. # No FSDP wrapping
  28. NO_FSDP = auto()
  29. # FSDP recursive wrapping
  30. RECURSIVE = auto()
  31. # TODO: FSDP non-recursive wrapping
  32. # NONRECURSIVE = auto()
  33. class CUDAInitMode(Enum):
  34. # Move model to CUDA before passing to the FSDP constructor
  35. CUDA_BEFORE = auto()
  36. # Move model to CUDA after passing to the FSDP constructor
  37. CUDA_AFTER = auto()
  38. # Keep on CPU
  39. CUDA_NEVER = auto()
  40. class FSDPTestModel(nn.Module, ABC):
  41. """This defines the interface expected from all models used commonly for
  42. FSDP unit tests."""
  43. @abstractmethod
  44. def get_input(self, device) -> Tuple[torch.Tensor, ...]:
  45. """Returns an input for the model as as tuple."""
  46. ...
  47. @abstractmethod
  48. def get_loss(self, input, output) -> torch.Tensor:
  49. """Returns the loss given the input and output."""
  50. ...
  51. @abstractmethod
  52. def run_backward(self, loss) -> None:
  53. """Runs the backward pass (e.g. including ``loss.backward()``)."""
  54. ...
  55. @staticmethod
  56. @abstractmethod
  57. def init(
  58. group: dist.ProcessGroup,
  59. fsdp_init_mode: FSDPInitMode,
  60. *init_args: Any,
  61. cuda_init_mode: CUDAInitMode,
  62. fsdp_kwargs: Optional[Dict[str, Any]] = None,
  63. deterministic: bool = False,
  64. **init_kwargs: Any,
  65. ) -> nn.Module:
  66. """Initializes an instance of this model."""
  67. ...
  68. def _assert_module_states(
  69. model: nn.Module,
  70. process_group: dist.ProcessGroup,
  71. assert_fn: Callable,
  72. ):
  73. """
  74. All-gathers module states across ranks and calls ``assert_fn`` on each pair
  75. of corresponding states from rank 0 and a nonzero rank. For example, if
  76. ``assert_fn`` is ``self.assertEqual()``, then this checks that all module
  77. states are equal across ranks.
  78. """
  79. # Include names for debugging convenience
  80. named_module_states = [
  81. (param_name, param.detach().cpu())
  82. for param_name, param in model.named_parameters()
  83. ]
  84. named_module_states += [
  85. (buffer_name, buffer.detach().cpu())
  86. for buffer_name, buffer in model.named_buffers()
  87. ]
  88. world_size = dist.get_world_size(process_group)
  89. olist = [None for _ in range(world_size)]
  90. dist.all_gather_object(olist, named_module_states, group=process_group)
  91. rank0_states = olist[0]
  92. for state in olist[1:]:
  93. for (_, p1), (_, p2) in zip(rank0_states, state):
  94. assert_fn(p1, p2)
  95. def _zero_model(
  96. model: nn.Module,
  97. zero_buffers: bool = False,
  98. summon_full=True,
  99. ):
  100. """Zeros the parameters and optionally buffers of ``model`` in place."""
  101. ctx = FSDP.summon_full_params(model) if summon_full else suppress()
  102. with ctx:
  103. for param in model.parameters():
  104. with torch.no_grad():
  105. param.zero_()
  106. if zero_buffers:
  107. for buffer in model.buffers():
  108. with torch.no_grad():
  109. buffer.zero_()
  110. def _get_state_dict(model, cpu_offload=False, half=False):
  111. if not cpu_offload:
  112. model = model.cuda()
  113. if half:
  114. model.half()
  115. return model.state_dict()
  116. def subtest_name(test_name_mapping, *args):
  117. return "_".join(
  118. [test_name_mapping[str(s)] if s is not None else "none" for s in args]
  119. )
  120. def _broadcast_state_dict(rank, state_dict):
  121. # For non-FSDP roots, some parts of the model state on rank 0 may
  122. # not be on CPU, so we move everything to CPU to avoid issues like:
  123. # https://github.com/pytorch/pytorch/issues/77113.
  124. for param_name, param in state_dict.items():
  125. if param.device != torch.device("cpu"):
  126. state_dict[param_name] = param.cpu()
  127. olist = [state_dict if rank == 0 else None]
  128. dist.broadcast_object_list(olist)
  129. state_dict = olist[0]
  130. # Ensure that the state is on CUDA
  131. for param_name in state_dict.keys():
  132. state_dict[param_name] = state_dict[param_name].cuda()
  133. return state_dict
  134. def get_full_params(model: nn.Module, recurse: bool = True):
  135. """
  136. Returns the full unsharded parameters of ``model``. Any FSDP-managed
  137. parameters offloaded to CPU are moved to GPU in the returned list.
  138. Args:
  139. recurse (bool): If ``False``, only unshards the parameters immediate to
  140. ``model``; if ``True``, recurses through the module hierarchy
  141. rooted at ``model``.
  142. """
  143. with FSDP.summon_full_params(model, recurse=recurse):
  144. return deepcopy(list(model.parameters()))
  145. def _maybe_cuda(model: nn.Module, move_to_cuda: bool):
  146. return model.cuda() if move_to_cuda else model
  147. def _maybe_wrap_fsdp(model: nn.Module, wrap_fsdp: bool, *args, **kwargs):
  148. return model if not wrap_fsdp else FSDP(model, *args, **kwargs)
  149. class DummyProcessGroup:
  150. def __init__(self, rank: int, size: int):
  151. self._rank = rank
  152. self._size = size
  153. def rank(self) -> int:
  154. return self._rank
  155. def size(self) -> int:
  156. return self._size
  157. def allreduce(self, *args, **kwargs):
  158. dist_wait = mock.Mock()
  159. def get_future():
  160. future = torch.futures.Future()
  161. future.set_result(1)
  162. return future
  163. dist_wait.get_future = get_future
  164. return dist_wait
  165. class TransformerWithSharedParams(FSDPTestModel):
  166. def __init__(
  167. self,
  168. group: dist.ProcessGroup,
  169. cuda_init_mode: CUDAInitMode,
  170. add_bn: bool,
  171. deterministic: bool,
  172. ):
  173. super().__init__()
  174. self.rank = group.rank()
  175. self.world_size = group.size()
  176. if deterministic:
  177. torch.manual_seed(0)
  178. d_vocab = 23
  179. d_model = 16
  180. self.embed_tokens = nn.Embedding(d_vocab, d_model)
  181. self.transformer = nn.Transformer(
  182. d_model=d_model,
  183. num_encoder_layers=2,
  184. num_decoder_layers=2,
  185. dim_feedforward=8,
  186. dropout=0.1,
  187. )
  188. self.output_proj = nn.Linear(d_model, d_vocab)
  189. # share the embedding and output projection weights
  190. self.output_proj.weight = self.embed_tokens.weight
  191. self.register_buffer(
  192. "vocab_bias", self.embed_tokens.weight.new_ones((d_model,))
  193. )
  194. self.register_buffer(
  195. "long_buffer",
  196. torch.zeros_like(self.vocab_bias, dtype=torch.long),
  197. ) # type: ignore[arg-type]
  198. self.bs = 2
  199. self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
  200. if cuda_init_mode == CUDAInitMode.CUDA_BEFORE:
  201. self = self.cuda()
  202. if deterministic:
  203. self.eval()
  204. def get_input(self, device):
  205. torch.manual_seed(1 + self.rank) # keep everything deterministic
  206. src = torch.arange(12, device=device).view(6, self.bs) # T x B
  207. tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs) # T x B
  208. return (src, tgt)
  209. def forward(self, src_ids, tgt_ids):
  210. src = self.embed_tokens(src_ids)
  211. src = src + self.vocab_bias + self.long_buffer.type_as(src) # type: ignore[operator]
  212. tgt = self.embed_tokens(tgt_ids)
  213. tgt = self.bn(tgt)
  214. x = self.transformer(src, tgt)
  215. return self.output_proj(x)
  216. def get_loss(self, input, output):
  217. _, tgt = input
  218. return nn.functional.cross_entropy(
  219. output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum"
  220. )
  221. def run_backward(self, loss):
  222. loss.backward()
  223. @staticmethod
  224. def init(
  225. group: dist.ProcessGroup,
  226. fsdp_init_mode: FSDPInitMode,
  227. cuda_init_mode: CUDAInitMode,
  228. fsdp_kwargs: Optional[Dict[str, Any]] = None,
  229. deterministic: bool = False,
  230. add_bn: bool = True,
  231. ) -> Union[nn.Module, FSDP]:
  232. """
  233. Initializes a :class:`TransformerWithSharedParams` instance.
  234. Args:
  235. fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
  236. any modules with FSDP. If ``RECURSIVE``, then wraps with
  237. top-level FSDP. By default, the top-level FSDP uses the
  238. ``ModuleWrapPolicy`` for encoder and decoder layers, but a
  239. different auto wrap policy may be specified via
  240. ``fsdp_kwargs``.
  241. cuda_init_mode (CUDAInitMode): Determines model movement to CUDA.
  242. fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
  243. forwarded to the FSDP constructor.
  244. deterministic (bool): Whether to make the model deterministic
  245. across constructions.
  246. add_bn (bool): Whether to include batch norm in the model.
  247. """
  248. if fsdp_kwargs is None:
  249. fsdp_kwargs = {}
  250. if fsdp_init_mode == FSDPInitMode.NO_FSDP:
  251. if isinstance(group, tuple):
  252. pg = group[0]
  253. else:
  254. pg = group
  255. return TransformerWithSharedParams(
  256. pg, cuda_init_mode, add_bn, deterministic
  257. )
  258. elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
  259. # Default to the `ModuleWrapPolicy`
  260. if "auto_wrap_policy" not in fsdp_kwargs:
  261. auto_wrap_policy = ModuleWrapPolicy(
  262. {
  263. TransformerEncoderLayer,
  264. TransformerDecoderLayer,
  265. }
  266. )
  267. else:
  268. auto_wrap_policy = fsdp_kwargs.pop("auto_wrap_policy")
  269. if (
  270. "sharding_strategy" in fsdp_kwargs
  271. and fsdp_kwargs["sharding_strategy"]
  272. in {ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2}
  273. and not isinstance(group, tuple)
  274. ):
  275. fsdp_pg = None
  276. else:
  277. fsdp_pg = group
  278. if isinstance(group, tuple):
  279. tformer_pg = group[0]
  280. else:
  281. tformer_pg = group
  282. fsdp_model = FSDP(
  283. TransformerWithSharedParams(
  284. tformer_pg, cuda_init_mode, add_bn, deterministic
  285. ),
  286. fsdp_pg,
  287. auto_wrap_policy=auto_wrap_policy,
  288. **fsdp_kwargs,
  289. )
  290. if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
  291. fsdp_model = fsdp_model.cuda()
  292. return fsdp_model
  293. raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
  294. def get_ignored_modules(self):
  295. return [self.transformer]
  296. class NestedWrappedModule(FSDPTestModel):
  297. def __init__(
  298. self,
  299. group: dist.ProcessGroup,
  300. wrap_fsdp: bool,
  301. cuda_init_mode: CUDAInitMode,
  302. deterministic: bool,
  303. **fsdp_kwargs,
  304. ):
  305. super().__init__()
  306. self.rank = group.rank()
  307. self.world_size = group.size()
  308. move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE
  309. def _maybe_wrap(layer):
  310. if wrap_fsdp:
  311. return FSDP(layer, group, **fsdp_kwargs)
  312. return layer
  313. if deterministic:
  314. torch.manual_seed(0)
  315. self.module = nn.Sequential(
  316. _maybe_cuda(nn.Linear(8, 4), move_to_cuda),
  317. _maybe_wrap(
  318. nn.Sequential(
  319. _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
  320. _maybe_cuda(nn.Linear(16, 16), move_to_cuda),
  321. ),
  322. ),
  323. _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)),
  324. _maybe_cuda(nn.Linear(4, 8), move_to_cuda),
  325. )
  326. def get_input(self, device):
  327. torch.manual_seed(1 + self.rank) # keep everything deterministic
  328. return (torch.rand(4, 8, device=device),)
  329. def forward(self, x):
  330. return self.module(x)
  331. def get_loss(self, input, output):
  332. loss = output.sum()
  333. return loss
  334. def run_backward(self, loss):
  335. loss.backward()
  336. @staticmethod
  337. def init(
  338. group: dist.ProcessGroup,
  339. fsdp_init_mode: FSDPInitMode,
  340. cuda_init_mode: CUDAInitMode,
  341. fsdp_kwargs: Optional[Dict[str, Any]] = None,
  342. deterministic: bool = False,
  343. ) -> nn.Module:
  344. """
  345. Initializes a :class:`NestedWrappedModule` instance.
  346. Args:
  347. fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
  348. any modules with FSDP. If ``RECURSIVE``, then wraps some nested
  349. modules with FSDP but not the top-level module. The model may
  350. later be wrapped with a top-level FSDP external to this method
  351. if desired.
  352. cuda_init_mode (CUDAInitMode): Determines model movement to CUDA.
  353. fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
  354. forwarded to the FSDP constructor.
  355. deterministic (bool): Whether to make the model deterministic
  356. across constructions.
  357. """
  358. if fsdp_kwargs is None:
  359. fsdp_kwargs = {}
  360. if fsdp_init_mode == FSDPInitMode.NO_FSDP:
  361. return NestedWrappedModule(
  362. group,
  363. wrap_fsdp=False,
  364. cuda_init_mode=cuda_init_mode,
  365. deterministic=deterministic,
  366. )
  367. elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
  368. # Does not wrap with top-level FSDP
  369. fsdp_model = NestedWrappedModule(
  370. group,
  371. wrap_fsdp=True,
  372. cuda_init_mode=cuda_init_mode,
  373. deterministic=deterministic,
  374. **fsdp_kwargs,
  375. )
  376. if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
  377. fsdp_model = fsdp_model.cuda()
  378. return fsdp_model
  379. raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
  380. class AlwaysWrapNestedWrappedModule(NestedWrappedModule):
  381. @staticmethod
  382. def init(
  383. group: dist.ProcessGroup,
  384. fsdp_init_mode: FSDPInitMode,
  385. cuda_init_mode: CUDAInitMode,
  386. fsdp_kwargs: Optional[Dict[str, Any]] = None,
  387. deterministic: bool = False,
  388. ):
  389. """
  390. Initializes a :class:`NestedWrappedModule` instance, but unlike
  391. :meth:`NestedWrappedModule.init`, for the ``RECURSIVE`` init mode, this
  392. wraps with top-level FSDP and the ``always_wrap_policy()`` auto wrap
  393. policy.
  394. """
  395. super_ = super(AlwaysWrapNestedWrappedModule, AlwaysWrapNestedWrappedModule)
  396. model = super_.init(
  397. group=group,
  398. fsdp_init_mode=FSDPInitMode.NO_FSDP,
  399. cuda_init_mode=cuda_init_mode,
  400. fsdp_kwargs=fsdp_kwargs,
  401. deterministic=deterministic,
  402. )
  403. if fsdp_init_mode == FSDPInitMode.NO_FSDP:
  404. return model
  405. elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
  406. fsdp_model = FSDP(model, auto_wrap_policy=always_wrap_policy, **fsdp_kwargs)
  407. if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
  408. fsdp_model = fsdp_model.cuda()
  409. return fsdp_model
  410. class ModuleWithDelay(FSDPTestModel):
  411. """This class wraps a :class:`FSDPTestModel` to optionally add a delay
  412. after computing the loss and/or before the gradient reduction."""
  413. def __init__(
  414. self,
  415. module: nn.Module,
  416. delay_after_loss_ms: int,
  417. delay_before_reduction_ms: int,
  418. ):
  419. super().__init__()
  420. self.delay_after_loss_ms = delay_after_loss_ms
  421. self.delay_before_reduction_ms = delay_before_reduction_ms
  422. self.module = module
  423. def get_input(self, device):
  424. return self.module.get_input(device)
  425. def forward(self, x):
  426. return self.module(x)
  427. def get_loss(self, input, output):
  428. loss = self.module.get_loss(input, output)
  429. if self.delay_after_loss_ms > 0:
  430. torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms()))
  431. return loss
  432. def run_backward(self, loss):
  433. orig_reduce_scatter = torch.distributed.reduce_scatter_tensor
  434. def _delayed_reduce_scatter(*args, **kwargs):
  435. if self.delay_before_reduction_ms > 0:
  436. torch.cuda._sleep(
  437. int(self.delay_before_reduction_ms * get_cycles_per_ms())
  438. )
  439. return orig_reduce_scatter(*args, **kwargs)
  440. with mock.patch(
  441. "torch.distributed.reduce_scatter_tensor", _delayed_reduce_scatter
  442. ):
  443. self.module.run_backward(loss)
  444. @staticmethod
  445. def init(
  446. module_class: Type[FSDPTestModel],
  447. *model_args: Any,
  448. delay_after_loss_ms: int,
  449. delay_before_reduction_ms: int,
  450. **model_kwargs: Any,
  451. ):
  452. """
  453. Args:
  454. module_class (Type[FSDPTestModel]): Wrapped module class to which
  455. to add delays.
  456. model_args: Positional arguments forwarded to the ``module_class``
  457. ``init()``.
  458. delay_after_loss_ms (int): Delay after computing the loss/before
  459. the optimizer step (in ms).
  460. delay_before_reduction_ms (int): Delay before reduce-scattering
  461. gradients (in ms).
  462. model_kwargs: Keyword arguments forwarded to the ``module_class``
  463. ``init()``.
  464. """
  465. return ModuleWithDelay(
  466. module_class.init(*model_args, **model_kwargs),
  467. delay_after_loss_ms,
  468. delay_before_reduction_ms,
  469. )
  470. class NestedWrappedModuleWithDelay(ModuleWithDelay):
  471. @staticmethod
  472. def init(
  473. group: dist.ProcessGroup,
  474. fsdp_init_mode: FSDPInitMode,
  475. cuda_init_mode: CUDAInitMode = CUDAInitMode.CUDA_AFTER,
  476. fsdp_kwargs: Optional[Dict[str, Any]] = None,
  477. deterministic: bool = False,
  478. delay_after_loss_ms: int = 0,
  479. delay_before_reduction_ms: int = 0,
  480. ):
  481. return super(NestedWrappedModuleWithDelay, NestedWrappedModuleWithDelay).init(
  482. NestedWrappedModule,
  483. group=group,
  484. fsdp_init_mode=fsdp_init_mode,
  485. cuda_init_mode=cuda_init_mode,
  486. fsdp_kwargs=fsdp_kwargs,
  487. deterministic=deterministic,
  488. delay_after_loss_ms=delay_after_loss_ms,
  489. delay_before_reduction_ms=delay_before_reduction_ms,
  490. )
  491. class DummyDDP(nn.Module):
  492. def __init__(self, module):
  493. super().__init__()
  494. self.module = module
  495. def forward(self, *args, **kwargs):
  496. return self.module(*args, **kwargs)
  497. class MixtureOfExperts(NestedWrappedModule):
  498. def __init__(
  499. self,
  500. group: dist.ProcessGroup,
  501. wrap_fsdp: bool,
  502. cuda_init_mode: CUDAInitMode,
  503. delay_before_free_ms: int,
  504. deterministic: bool,
  505. **fsdp_kwargs,
  506. ):
  507. super().__init__(
  508. group=group,
  509. wrap_fsdp=wrap_fsdp,
  510. cuda_init_mode=cuda_init_mode,
  511. deterministic=deterministic,
  512. )
  513. self.group = group
  514. self.delay_before_free_ms = delay_before_free_ms
  515. self.wrap_fsdp = wrap_fsdp
  516. self.move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE
  517. if deterministic:
  518. # Give each rank different expert parameters
  519. torch.manual_seed(42 + self.rank)
  520. d_expert = 23
  521. d_shared = 12
  522. d_input = 8
  523. expert = _maybe_cuda(nn.Linear(d_expert, d_shared), self.move_to_cuda)
  524. self.num_expert_params = sum([p.numel() for p in expert.parameters()])
  525. for p in expert.parameters():
  526. p.expert = True # type: ignore[attr-defined]
  527. if deterministic:
  528. # Keep all other parameters the same across ranks
  529. torch.manual_seed(0)
  530. shared = _maybe_cuda(nn.Linear(d_shared, d_expert), self.move_to_cuda)
  531. if wrap_fsdp:
  532. # we create a process group of size 1 for the expert params
  533. expert_group = torch.distributed.new_group(
  534. [group.rank()]
  535. ) # world size 1 means no shard
  536. expert = FSDP(expert, expert_group, **fsdp_kwargs) # type: ignore[assignment]
  537. shared = FSDP(shared, group, **fsdp_kwargs) # type: ignore[assignment]
  538. self.module = nn.Sequential(
  539. _maybe_cuda(nn.Linear(d_input, d_shared), self.move_to_cuda),
  540. shared,
  541. expert,
  542. _maybe_cuda(nn.Linear(d_shared, d_input), self.move_to_cuda),
  543. )
  544. def forward(self, x):
  545. if self.delay_before_free_ms > 0:
  546. expert = self.module[2]
  547. if isinstance(expert, FSDP):
  548. orig_reshard = torch.distributed.fsdp._runtime_utils._reshard
  549. def _delayed_reshard(*args, **kwargs):
  550. torch.cuda._sleep(
  551. int(self.delay_before_free_ms * get_cycles_per_ms())
  552. )
  553. return orig_reshard(*args, **kwargs)
  554. # This patch covers any `import torch..._reshard` uses.
  555. with mock.patch(
  556. "torch.distributed.fsdp._runtime_utils._reshard", _delayed_reshard
  557. ):
  558. return self.module(x)
  559. return self.module(x)
  560. def run_backward(self, loss):
  561. loss.backward()
  562. # Manually reduce gradients if not wrapped in FullyShardedDataParallel
  563. if not self.wrap_fsdp:
  564. with torch.no_grad():
  565. for p in self.parameters():
  566. if hasattr(p, "expert"):
  567. continue # these params don't need grad reduction
  568. p.grad.div_(self.world_size)
  569. torch.distributed.all_reduce(p.grad, group=self.group)
  570. @staticmethod
  571. def init(
  572. group: dist.ProcessGroup,
  573. fsdp_init_mode: FSDPInitMode,
  574. cuda_init_mode: CUDAInitMode,
  575. fsdp_kwargs: Optional[Dict[str, Any]] = None,
  576. deterministic: bool = False,
  577. delay_before_free_ms: int = 0,
  578. ):
  579. """
  580. Initializes a :class:`MixtureOfExperts` instance.
  581. Args:
  582. fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
  583. any modules with FSDP. If ``RECURSIVE``, then wraps some nested
  584. modules with FSDP, including the expert and shared layers, but
  585. not the top-level module. The model may later be wrapped with a
  586. top-level FSDP external to this method if desired.
  587. cuda_init_mode (CUDAInitMode): Determines model movement to CUDA.
  588. fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
  589. forwarded to the FSDP constructor.
  590. deterministic (bool): Whether to make the model deterministic
  591. across constructions.
  592. delay_before_free_ms (int): Delay before resharding expert
  593. parameters in the forward pass (in ms).
  594. """
  595. if fsdp_kwargs is None:
  596. fsdp_kwargs = {}
  597. if fsdp_init_mode == FSDPInitMode.NO_FSDP:
  598. return MixtureOfExperts(
  599. group,
  600. wrap_fsdp=False,
  601. cuda_init_mode=cuda_init_mode,
  602. delay_before_free_ms=delay_before_free_ms,
  603. deterministic=deterministic,
  604. )
  605. elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
  606. # Does not wrap with top-level FSDP
  607. fsdp_model = MixtureOfExperts(
  608. group,
  609. wrap_fsdp=True,
  610. cuda_init_mode=cuda_init_mode,
  611. delay_before_free_ms=delay_before_free_ms,
  612. deterministic=deterministic,
  613. **fsdp_kwargs,
  614. )
  615. if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
  616. fsdp_model = fsdp_model.cuda()
  617. return fsdp_model
  618. raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
  619. class FSDPTest(MultiProcessTestCase):
  620. def setUp(self):
  621. super().setUp()
  622. self._spawn_processes()
  623. @property
  624. def world_size(self):
  625. return torch.cuda.device_count() if torch.cuda.is_available() else 4
  626. @property
  627. def process_group(self):
  628. return dist.distributed_c10d._get_default_group()
  629. @property
  630. def init_method(self):
  631. return "{}{file_name}".format(FILE_SCHEMA, file_name=self.file_name)
  632. def _check_cpu_offload(self, fsdp_model, cpu_offload):
  633. self.assertEqual(cpu_offload, fsdp_model.cpu_offload)
  634. def _check_backward_prefetch(self, fsdp_model, backward_prefetch):
  635. self.assertEqual(backward_prefetch, fsdp_model.backward_prefetch)
  636. def _check_forward_prefetch(self, fsdp_model, forward_prefetch):
  637. self.assertEqual(forward_prefetch, fsdp_model.forward_prefetch)
  638. def run_subtests(
  639. self,
  640. subtest_config: Dict[str, List[Any]],
  641. test_fn: Callable,
  642. *test_args,
  643. **test_kwargs: Any,
  644. ):
  645. """
  646. Runs a test function given by ``test_fn`` as a subtest according to the
  647. configurations specified by ``subtest_config``. This amortizes the
  648. costly setup overhead (including process spawn and initializing the
  649. process group) over the subtests.
  650. Args:
  651. subtest_config (Dict[str, List[Any]]): A mapping from subtest
  652. keyword argument name to a list of its possible values.
  653. test_fn (Callable): A callable that runs the actual test.
  654. test_args: Positional arguments to pass to ``test_fn``.
  655. test_kwargs: Keyword arguments to pass to ``test_fn``.
  656. """
  657. # Convert the config mapping to a list to have a fixed order
  658. subtest_config_items: List[Tuple[str, List[Any]]] = list(subtest_config.items())
  659. subtest_config_keys: List[str] = [item[0] for item in subtest_config_items]
  660. subtest_config_values: List[List[Any]] = [
  661. item[1] for item in subtest_config_items
  662. ]
  663. for values in itertools.product(*subtest_config_values):
  664. # Map keyword to chosen value
  665. subtest_kwargs = {
  666. kwarg: value for kwarg, value in zip(subtest_config_keys, values)
  667. }
  668. with self.subTest(**subtest_kwargs):
  669. test_fn(*test_args, **test_kwargs, **subtest_kwargs)
  670. dist.barrier()
  671. @classmethod
  672. def _run(cls, rank, test_name, file_name, pipe):
  673. self = cls(test_name)
  674. self.rank = rank
  675. self.file_name = file_name
  676. print(f"dist init r={self.rank}, world={self.world_size}")
  677. # Specify gloo backend to make 'init_process_group()' succeed,
  678. # Actual tests will be skipped if there is no enough GPUs.
  679. backend = "nccl" if torch.cuda.is_available() else "gloo"
  680. try:
  681. dist.init_process_group(
  682. init_method=self.init_method,
  683. backend=backend,
  684. world_size=int(self.world_size),
  685. rank=self.rank,
  686. )
  687. except RuntimeError as e:
  688. if "recompile" in e.args[0]:
  689. sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)
  690. raise
  691. if torch.cuda.is_available() and torch.cuda.device_count():
  692. torch.cuda.set_device(self.rank % torch.cuda.device_count())
  693. # Execute barrier prior to running test to ensure that every process
  694. # has finished initialization and that the following test
  695. # immediately exiting due to a skip doesn't cause flakiness.
  696. dist.barrier()
  697. self.run_test(test_name, pipe)
  698. dist.barrier()
  699. dist.destroy_process_group()
  700. sys.exit(0)
  701. def _train_for_several_steps(
  702. self,
  703. model: nn.Module,
  704. num_steps: int,
  705. autocast: bool,
  706. lr: float = 0.01,
  707. fsdp_cpu_offload: Optional[CPUOffload] = None,
  708. save_model: bool = False,
  709. mixed_precision: Optional[MixedPrecision] = None,
  710. enable_sharded_grad_scaler: bool = False,
  711. use_pure_fp16: bool = False,
  712. ):
  713. cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params
  714. model_device = next(model.parameters()).device
  715. sharded_grad_scaler = ShardedGradScaler(enabled=enable_sharded_grad_scaler)
  716. # use SGD with momentum instead of Adam, since Adam is scale invariant
  717. # and this makes it bad for tests
  718. optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
  719. for _ in range(num_steps):
  720. optim.zero_grad()
  721. with torch.cuda.amp.autocast(enabled=autocast):
  722. # Inputs always cuda regardless of cpu offloading, or model.device
  723. input = model.module.get_input(torch.device("cuda"))
  724. if use_pure_fp16 or (mixed_precision and not isinstance(model, FSDP)):
  725. if isinstance(input, torch.Tensor):
  726. input = input.half()
  727. else:
  728. input = tuple(x.half() for x in input)
  729. output = model(*input)
  730. # Post-forward, if CPU offloading model param should be on CPU.
  731. if cpu_offload_params and isinstance(model, FSDP):
  732. for p in model.parameters():
  733. # Params should always be on CPU
  734. self.assertEqual(p.device, torch.device("cpu"))
  735. loss = model.module.get_loss(input, output).to(model_device)
  736. loss = sharded_grad_scaler.scale(loss)
  737. if not mixed_precision and not use_pure_fp16:
  738. assert (
  739. loss.dtype == torch.float32
  740. ), "loss data type should be float32, as the original \
  741. parameter data type is float32."
  742. else:
  743. if use_pure_fp16:
  744. self.assertEqual(loss.dtype, torch.float16)
  745. # FSDP loss is fp16, DDP AMP loss is fp32
  746. elif isinstance(model, FSDP):
  747. self.assertEqual(loss.dtype, mixed_precision.param_dtype)
  748. else:
  749. self.assertEqual(loss.dtype, torch.float32)
  750. model.module.run_backward(loss)
  751. # Post-backward, if CPU offloading model params should be on CPU.
  752. if cpu_offload_params and isinstance(model, FSDP):
  753. for p in model.parameters():
  754. # Params should always be on CPU
  755. self.assertEqual(p.device, torch.device("cpu"))
  756. # Unscale the gradients and step
  757. sharded_grad_scaler.step(optim)
  758. # Update the scale factor
  759. sharded_grad_scaler.update()
  760. # if save_model, simulate save + load.
  761. if save_model:
  762. state_dict = {k: v.clone() for k, v in model.state_dict().items()}
  763. # Zero params, if save/load state_dict did not work properly, this
  764. # would break the parity test with DDP.
  765. _zero_model(model)
  766. model.load_state_dict(state_dict)
  767. if isinstance(model, FSDP):
  768. model._assert_state(TrainingState.IDLE)
  769. return loss.detach()
  770. def _test_fsdp_parity(
  771. self,
  772. model_class: Type[FSDPTestModel],
  773. fsdp_init_mode: FSDPInitMode,
  774. cuda_init_mode: CUDAInitMode,
  775. ref_init_fn: Optional[Callable] = None,
  776. num_iters: int = 2,
  777. save_model: bool = True,
  778. cpu_offload: CPUOffload = CPUOffload(),
  779. backward_prefetch: Optional[BackwardPrefetch] = None,
  780. sharding_strategy: Optional[ShardingStrategy] = None,
  781. mixed_precision: Optional[MixedPrecision] = None,
  782. forward_prefetch: bool = False,
  783. use_orig_params: bool = False,
  784. enable_sharded_grad_scaler: bool = False,
  785. use_pure_fp16: bool = False,
  786. init_kwargs: Optional[Dict[str, Any]] = None,
  787. **fsdp_kwargs,
  788. ):
  789. """
  790. Tests FSDP training against a reference, which defaults to DDP but
  791. may be customized with ``ref_init_fn``.
  792. Args:
  793. model_class (Type[FSDPTestModel]): A model class that inherits from
  794. ``FSDPTestModel``, which defines the expected interface.
  795. fsdp_init_mode (FSDPInitMode): The mode to initialize the
  796. FSDP-wrapped model. This should not be ``NO_FSDP``.
  797. ref_init_fn (Optional[Callable]): A callable to invoke that wraps a
  798. non-wrapped model to construct the reference model, where this
  799. wrapper should provide data parallel semantics. If ``None``,
  800. then the callable defaults to the DDP constructor.
  801. """
  802. assert (
  803. fsdp_init_mode != FSDPInitMode.NO_FSDP
  804. ), "Expects an FSDP init mode that wraps with FSDP"
  805. if init_kwargs is None:
  806. init_kwargs = {}
  807. lr = 1e-2
  808. rank = self.process_group.rank()
  809. # Establish reference behavior with DDP
  810. model = model_class.init(
  811. self.process_group,
  812. FSDPInitMode.NO_FSDP,
  813. CUDAInitMode.CUDA_BEFORE,
  814. deterministic=True,
  815. **init_kwargs,
  816. )
  817. if ref_init_fn is None:
  818. ref_model = DDP(model, device_ids=[rank], output_device=rank)
  819. else:
  820. ref_model = ref_init_fn(model)
  821. if use_pure_fp16:
  822. ref_model = ref_model.half()
  823. ref_loss = self._train_for_several_steps(
  824. ref_model,
  825. num_iters,
  826. autocast=mixed_precision is not None,
  827. lr=lr,
  828. fsdp_cpu_offload=cpu_offload,
  829. mixed_precision=mixed_precision,
  830. enable_sharded_grad_scaler=enable_sharded_grad_scaler,
  831. use_pure_fp16=use_pure_fp16,
  832. )
  833. ddp_params = list(ref_model.parameters())
  834. # Check against FSDP behavior
  835. fsdp_kwargs.update(
  836. {
  837. "cpu_offload": cpu_offload,
  838. "backward_prefetch": backward_prefetch,
  839. "sharding_strategy": sharding_strategy,
  840. "mixed_precision": mixed_precision,
  841. "forward_prefetch": forward_prefetch,
  842. "use_orig_params": use_orig_params,
  843. }
  844. )
  845. try:
  846. fsdp_model = model_class.init(
  847. self.process_group,
  848. fsdp_init_mode,
  849. cuda_init_mode,
  850. fsdp_kwargs,
  851. deterministic=True,
  852. **init_kwargs,
  853. )
  854. except Exception as e:
  855. raise ValueError(f"Initializing {model_class} raised error {str(e)}") from e
  856. if not isinstance(fsdp_model, FSDP):
  857. # Enforce that we wrap with top-level FSDP since we are comparing
  858. # assuming a data parallel reference and some test models may not
  859. # do so in their `init()` method
  860. fsdp_model = FSDP(fsdp_model, self.process_group, **fsdp_kwargs)
  861. if use_pure_fp16:
  862. # Change the model parameter dtype after FSDP initialization
  863. fsdp_model = fsdp_model.half()
  864. if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
  865. fsdp_model = fsdp_model.cuda()
  866. offload_params = cpu_offload is not None and cpu_offload.offload_params
  867. # Offloading parameters with `CUDA_AFTER` should raise an error during
  868. # lazy initialization due to the parameter devices not being CPU;
  869. # otherwise, all parameter devices should be CPU
  870. expects_device_error = (
  871. offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER
  872. )
  873. expects_cpu_device = (
  874. offload_params and cuda_init_mode != CUDAInitMode.CUDA_AFTER
  875. )
  876. if expects_cpu_device:
  877. cpu_device = torch.device("cpu")
  878. for param in fsdp_model.parameters():
  879. self.assertEqual(param.device, cpu_device)
  880. context = (
  881. self.assertRaisesRegex(
  882. RuntimeError,
  883. "An FSDP-managed module with parameter CPU offloading enabled "
  884. "has parameters on cuda",
  885. )
  886. if expects_device_error
  887. else suppress()
  888. )
  889. with context:
  890. fsdp_loss = self._train_for_several_steps(
  891. fsdp_model,
  892. num_iters,
  893. autocast=False,
  894. lr=lr,
  895. fsdp_cpu_offload=cpu_offload,
  896. save_model=save_model,
  897. mixed_precision=mixed_precision,
  898. enable_sharded_grad_scaler=enable_sharded_grad_scaler,
  899. use_pure_fp16=use_pure_fp16,
  900. )
  901. # No need to check for parameter and loss parity if expecting an error
  902. if expects_device_error:
  903. return
  904. # Check parameter devices are CPU if offloading to CPU before calling
  905. # `get_full_params()`, which will cast the parameters to FP32
  906. if offload_params:
  907. for param in fsdp_model.parameters():
  908. self.assertEqual(param.device, cpu_device)
  909. fsdp_loss = fsdp_loss.cuda()
  910. fsdp_unsharded_params = get_full_params(fsdp_model)
  911. # Do not check dtype since the reference DDP loss may not be the same
  912. # dtype as the FSDP loss in the case of mixed precision
  913. torch.testing.assert_close(ref_loss, fsdp_loss, check_dtype=False)
  914. # Do not check for parameter parity if using mixed precision since (1)
  915. # the DDP parameters are in FP16 (from `half()`) while the FSDP
  916. # parameters are in FP32 (from `summon_full_params()`) and (2) DDP runs
  917. # the optimizer in FP16 while FSDP runs it in FP32
  918. # TODO: Disable checking the parameters for pure FP16 due to floating
  919. # point inaccuracy. Note that this means that the backward pass is not
  920. # checked: https://github.com/pytorch/pytorch/issues/90784
  921. if mixed_precision is None and not use_pure_fp16:
  922. self.assertEqual(
  923. ddp_params,
  924. fsdp_unsharded_params,
  925. exact_device=True,
  926. msg="FSDP did not match DDP",
  927. )
  928. class SkipModule(nn.Module):
  929. def __init__(self):
  930. super().__init__()
  931. self.lin = nn.Linear(10, 10, bias=False)
  932. def forward(self, x):
  933. return self.lin(x)
  934. class NestedLinear(nn.Module):
  935. def __init__(self, fsdp_wrap):
  936. super().__init__()
  937. if fsdp_wrap:
  938. self.nested_linear = wrap(nn.Linear(10, 10, bias=False).cuda())
  939. else:
  940. self.nested_linear = nn.Linear(10, 10, bias=False).cuda()
  941. def forward(self, x):
  942. return self.nested_linear(x)
  943. class SkipModel(nn.Module):
  944. def __init__(self, double_nest):
  945. super().__init__()
  946. self.linear = nn.Linear(10, 10, bias=False).cuda()
  947. self.linear_skip = SkipModule().cuda()
  948. self.nested_linear = wrap(NestedLinear(fsdp_wrap=double_nest))
  949. def forward(self, x):
  950. x = self.linear(x)
  951. x = self.linear_skip(x)
  952. x = self.nested_linear(x)
  953. return x