from typing import List, Tuple

import torch
import torch.nn as nn

from . import _ddp
from .contract import _get_registry, contract
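

# Composable ``replicate()`` API: marks a module so that its parameters are
# kept in sync across processes. The underlying DistributedDataParallel
# wrapper is constructed lazily on the module's first forward pass.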
@contract()
def replicate(
    module: nn.Module,  # NOTE: contract now supports single module only
    **kwargs,
) -> nn.Module:
    r"""Replicates a module

    Args:
        module (torch.nn.Module): module to replicate

    Example::
        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
        >>> module = nn.Linear(3, 3)
        >>> replicate(module)
    """
    _ReplicateState().mark_modules(module, **kwargs)
    return module


def _can_compose(module: nn.Module) -> bool:
    r"""Check if a module is composable with the `replicate` API."""
    return "fully_shard" not in _get_registry(module)
class _ReplicateState:
    def __init__(self) -> None:
        self.modules: List[nn.Module] = []
        self.has_initialized: bool = False
        self._param_list: nn.ParameterList = nn.ParameterList()
        self.kwargs: dict = {}

    def mark_modules(self, *modules: nn.Module, **kwargs) -> None:
        for module in modules:
            if not _can_compose(module):
                raise AssertionError(
                    "Cannot apply `replicate()` on a Module already managed by `fully_shard`"
                )
            self.modules.append(module)
            replicate.state(module)._distributed_state = self
            replicate.state(module)._params_collected = False
            module.register_forward_pre_hook(self.forward_pre_hook)
            # TODO(@yhcharles): fix type error
            module.register_forward_hook(self.forward_post_hook)  # type: ignore[arg-type]
        self.kwargs = kwargs

    def _recursive_collect_params(self, module: nn.Module) -> None:
        # skip if managed by other APIs
        if not _can_compose(module):
            return

        # skip if module parameters already collected
        if hasattr(replicate.state(module), "_params_collected"):
            if replicate.state(module)._params_collected:
                return
            replicate.state(module)._params_collected = True

        self._param_list.extend(
            param for param in module.parameters(recurse=False) if param.requires_grad
        )
        for child in module.children():
            self._recursive_collect_params(child)
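
    # Lazily set up data parallelism on the first forward pass; subsequent
    # calls are no-ops thanks to the ``has_initialized`` guard.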
    def init_helper(self) -> None:
        if self.has_initialized:
            return

        self.has_initialized = True
        for module in self.modules:
            self._recursive_collect_params(module)
        self._ddp = _ddp.DistributedDataParallel(self._param_list, **self.kwargs)
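
    # Forward hooks registered in ``mark_modules``: the pre-hook performs the
    # one-time initialization and notifies the DDP wrapper before the forward,
    # and the post-hook hands the module output back to the DDP wrapper.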
    def forward_pre_hook(
        self, module: nn.Module, input: Tuple[torch.Tensor, ...]
    ) -> None:
        self.init_helper()
        self._ddp.pre_forward()

    def forward_post_hook(
        self,
        module: nn.Module,
        input: Tuple[torch.Tensor],
        output: torch.Tensor,
    ) -> torch.Tensor:
        return self._ddp.post_forward(output)