replicate.py

from typing import List, Tuple

import torch
import torch.nn as nn

from . import _ddp
from .contract import _get_registry, contract


@contract()
def replicate(
    module: nn.Module,  # NOTE: contract now supports single module only
    **kwargs,
) -> nn.Module:
    r"""Replicates a module

    Args:
        module (torch.nn.Module): module to replicate

    Example::
        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
        >>> module = nn.Linear(3, 3)
        >>> replicate(module)
    """
    _ReplicateState().mark_modules(module, **kwargs)
    return module
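
# Usage sketch (illustrative, not part of the original module): `replicate` is
# intended to be called after the default process group has been initialized,
# e.g. via torch.distributed.init_process_group. The backend, `rank`, and
# `world_size` below are placeholders that depend on how the training script is
# launched.
#
#     import torch.distributed as dist
#     dist.init_process_group("gloo", rank=rank, world_size=world_size)
#     model = replicate(nn.Linear(3, 3))
#     out = model(torch.randn(8, 3))
#     out.sum().backward()  # gradients are synchronized across ranks by DDP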


def _can_compose(module: nn.Module) -> bool:
    r"""Check if module is composable with the `replicate` API."""
    return "fully_shard" not in _get_registry(module)


class _ReplicateState:
    def __init__(self) -> None:
        self.modules: List[nn.Module] = []
        self.has_initialized: bool = False
        self._param_list: nn.ParameterList = nn.ParameterList()
        self.kwargs: dict = {}

    def mark_modules(self, *modules: nn.Module, **kwargs) -> None:
        for module in modules:
            if not _can_compose(module):
                raise AssertionError(
                    "Cannot apply `replicate()` on a Module already managed by `fully_shard`"
                )
            self.modules.append(module)
            replicate.state(module)._distributed_state = self
            replicate.state(module)._params_collected = False
            module.register_forward_pre_hook(self.forward_pre_hook)
            # TODO(@yhcharles): fix type error
            module.register_forward_hook(self.forward_post_hook)  # type: ignore[arg-type]
        self.kwargs = kwargs
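
    # Note: registering both forward hooks above means the underlying DDP state
    # is built lazily on the first forward pass (forward_pre_hook -> init_helper)
    # rather than eagerly at `replicate()` call time.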

    def _recursive_collect_params(self, module: nn.Module) -> None:
        # skip if managed by other APIs
        if not _can_compose(module):
            return
        # skip if module parameters already collected
        if hasattr(replicate.state(module), "_params_collected"):
            if replicate.state(module)._params_collected:
                return
            replicate.state(module)._params_collected = True

        self._param_list.extend(
            param for param in module.parameters(recurse=False) if param.requires_grad
        )
        for child in module.children():
            self._recursive_collect_params(child)

    def init_helper(self) -> None:
        if self.has_initialized:
            return

        self.has_initialized = True
        for module in self.modules:
            self._recursive_collect_params(module)
        self._ddp = _ddp.DistributedDataParallel(self._param_list, **self.kwargs)
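
    # All keyword arguments passed to `replicate()` are forwarded unchanged to
    # the internal DistributedDataParallel wrapper, which manages a single flat
    # parameter list collected from every marked module and its children.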

    def forward_pre_hook(
        self, module: nn.Module, input: Tuple[torch.Tensor, ...]
    ) -> None:
        self.init_helper()
        self._ddp.pre_forward()

    def forward_post_hook(
        self,
        module: nn.Module,
        input: Tuple[torch.Tensor],
        output: torch.Tensor,
    ) -> torch.Tensor:
        return self._ddp.post_forward(output)
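
# Composability sketch (assumes the composable `fully_shard` API from the same
# `torch.distributed._composable` package; `sharded_module` is a placeholder for
# any nn.Module): a module already managed by `fully_shard` cannot also be
# replicated, which `_can_compose` enforces via the registry attached by
# `@contract`.
#
#     from torch.distributed._composable import fully_shard
#     fully_shard(sharded_module)
#     replicate(sharded_module)  # raises AssertionError in mark_modules()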