- #!/usr/bin/python3
- import collections
- import io
- import sys
- import types
- from typing import (
- Any,
- Callable,
- Dict,
- Iterator,
- List,
- Mapping,
- Optional,
- Set,
- Tuple,
- Type,
- TypeVar,
- Union,
- )
- import torch
- import torch.distributed.rpc as rpc
- from torch import Tensor, device, dtype, nn
- from torch.distributed.nn.jit import instantiator
- from torch.distributed import _remote_device
- from torch.distributed.rpc.internal import _internal_rpc_pickler
- from torch.nn import Module
- from torch.nn.parameter import Parameter
- from torch.utils.hooks import RemovableHandle
- __all__ = ["RemoteModule"]
- _grad_t = Union[Tuple[Tensor, ...], Tensor]
- # See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use
- # of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be
- # the type of the subclass, not the looser type of `Module`.
- T = TypeVar("T", bound="Module")
- _NON_SCRIPTABLE_REMOTE_MODULE_MODULE = (
- instantiator.instantiate_non_scriptable_remote_module_template()
- )
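- # Instantiated once at import time: this generated module provides the default (non-TorchScript)
- # ``forward``/``forward_async`` methods that are bound onto RemoteModule instances when no
- # ``_module_interface_cls`` is given.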
- _REMOTE_MODULE_PICKLED_ATTRIBUTES = (
- "on",
- "device",
- "is_device_map_set",
- "is_scriptable",
- "generated_methods",
- "module_rref",
- )
- _SerializedRemoteModule = collections.namedtuple("_SerializedRemoteModule", _REMOTE_MODULE_PICKLED_ATTRIBUTES) # type: ignore[misc]
- # These attributes are mostly from RemoteModule's parent class and are intentionally not pickled.
- # A new attribute of RemoteModule should be either in _REMOTE_MODULE_PICKLED_ATTRIBUTES
- # or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING.
- # Otherwise, it will not be pickled.
- _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING = (
- "training",
- "_parameters",
- "_buffers",
- "_non_persistent_buffers_set",
- "_backward_hooks",
- "_backward_pre_hooks",
- "_is_full_backward_hook",
- "_forward_hooks",
- "_forward_hooks_with_kwargs",
- "_forward_pre_hooks",
- "_forward_pre_hooks_with_kwargs",
- "_state_dict_hooks",
- "_state_dict_pre_hooks",
- "_load_state_dict_pre_hooks",
- "_load_state_dict_post_hooks",
- "_state_dict_pre_hooks",
- "_modules",
- # The two attributes below are generated methods, not available at pickling time.
- "forward_async",
- "forward",
- )
- # RPC handler.
- def _instantiate_template(module_interface_cls, enable_moving_cpu_tensors_to_cuda):
- instantiator.instantiate_scriptable_remote_module_template(
- module_interface_cls, enable_moving_cpu_tensors_to_cuda
- )
- def _create_module(module_cls, args, kwargs, device):
- module = module_cls(*args, **kwargs)
- if not isinstance(module, nn.Module):
- raise ValueError(
- "Expect `module_cls(*args, **kwargs)` returns an instance of <class nn.Module>, "
- f"but it returns an instance of {type(module)}."
- )
- module.to(device)
- return module
- def _create_module_with_interface(
- module_cls, args, kwargs, device, module_interface_cls
- ):
- module = _create_module(module_cls, args, kwargs, device)
- if module_interface_cls is not None:
- module = torch.jit.script(module)
- return rpc.RRef(module, module_interface_cls)
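- # Runs on the module's owner: wraps each parameter of the wrapped module in an RRef so that
- # callers (e.g. ``RemoteModule.remote_parameters``) can hand them to a DistributedOptimizer.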
- def _param_rrefs(module_rref, recurse) -> List[rpc.RRef[Parameter]]:
- ret: List[rpc.RRef[Parameter]] = []
- for param in module_rref.local_value().parameters(recurse):
- ret.append(rpc.RRef(param))
- return ret
- def _raise_not_supported(name: str) -> None:
- raise ValueError("Method ``{}`` not supported for RemoteModule".format(name))
- class _RemoteModule(nn.Module):
- def __new__(cls, *args, **kwargs):
- # Use __new__ for logging purposes.
- torch._C._log_api_usage_once("torch.distributed.nn.api.remote_module")
- return super(_RemoteModule, cls).__new__(cls)
- def __init__(
- self,
- remote_device: str,
- module_cls: Type[nn.Module],
- args: Optional[Tuple] = None,
- kwargs: Optional[Dict[str, Any]] = None,
- _module_interface_cls: Any = None,
- ):
- """
- A RemoteModule instance can only be created after RPC initialization.
- It creates a user-specified module on a specified remote node.
- It behaves like a regular ``nn.Module`` except that the ``forward`` method is
- executed on the remote node.
- It takes care of autograd recording to ensure the backward pass propagates
- gradients back to the corresponding remote module.
- It can be shared across processes using the `RPC framework <https://pytorch.org/docs/stable/rpc.html>`__,
- without incurring any overheads of copying the actual module,
- which is equivalent to an :class:`~torch.distributed.rpc.RRef`
- pointing to the remote module.
- The arguments of ``forward_async`` and ``forward`` are the same as
- the ``forward`` method of the module returned by the ``module_cls``.
- Apart from ``forward_async`` and ``forward``, no other methods of ``nn.Module`` are supported for now.
- In particular, to create a hybrid model, local modules should typically be
- created outside of remote modules rather than as submodules of any remote module (by calling ``add_module``).
- Hybrid Example:
- >>> class HybridModel(nn.Module):
- >>> def __init__(self):
- >>> nn.Module.__init__(self)
- >>> self.remote_embedding = RemoteModule(...)
- >>> self.local_linear = nn.Linear(...)
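- A ``forward`` for such a hybrid model can then chain the remote and the local parts
- (a minimal sketch; the attribute names follow the example above and the shapes are illustrative):
- >>> def forward(self, x):
- >>>     # The embedding lookup runs on the remote worker; the linear layer runs locally.
- >>>     emb = self.remote_embedding.forward(x)
- >>>     return self.local_linear(emb)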
- For example, if ``module_cls`` returns an instance of ``nn.Linear``
- whose ``forward`` method has the signature ``def forward(input: Tensor) -> Tensor:``,
- the generated ``RemoteModule`` will have 2 methods with the signatures
- ``def forward(input: Tensor) -> Tensor:`` and
- ``def forward_async(input: Tensor) -> Future[Tensor]:``.
- .. note::
- If the remote module is placed on a cuda device,
- any input CPU tensors will be automatically moved to the same cuda device,
- and GPU tensors are returned over the wire according to the device map of the remote worker on TensorPipe RPC backend.
- Args:
- remote_device (str): Device on the destination worker where we'd like to place this module.
- The device can be a local device or a remote device specified by one of the following remote
- formats:
- 1. "rank:<rank>/<device>" (ex: "rank:0/cuda:0").
- 2. "<worker_name>/<device>" (ex: "trainer0/cuda:0").
- In addition, the device field is optional; the default value is "cpu".
- module_cls (nn.Module): For example,
- >>> class MyModule(nn.Module):
- >>> def forward(input):
- >>> return input + 1
- >>>
- >>> module_cls = MyModule
- args (Sequence, optional): args to be passed to ``module_cls``.
- kwargs (Dict, optional): kwargs to be passed to ``module_cls``.
- _module_interface_cls (type, optional): The TorchScript interface type for the module
- to be created. The type object should be decorated by @torch.jit.interface.
- If not provided, the generated RemoteModule is not torchscript-able.
- Warning, this is an experimental API and susceptible to frequent changes.
- Returns:
- A remote module instance which wraps the :class:`~nn.Module` created by the
- user-provided ``module_cls``; it has a blocking ``forward`` method and an
- asynchronous ``forward_async`` method that returns a future of the ``forward`` call
- on the user-provided module on the remote side.
- Example::
- Run the following code in two different processes:
- >>> # xdoctest: +SKIP("distributed")
- >>> # On worker 0:
- >>> import torch
- >>> import torch.distributed.rpc as rpc
- >>> from torch import nn, Tensor
- >>> from torch.distributed.nn.api.remote_module import RemoteModule
- >>>
- >>> rpc.init_rpc("worker0", rank=0, world_size=2)
- >>> remote_linear_module = RemoteModule(
- >>> "worker1/cpu", nn.Linear, args=(20, 30),
- >>> )
- >>> input = torch.randn(128, 20)
- >>> ret_fut = remote_linear_module.forward_async(input)
- >>> ret = ret_fut.wait()
- >>> rpc.shutdown()
- >>> # On worker 1:
- >>> import torch
- >>> import torch.distributed.rpc as rpc
- >>>
- >>> rpc.init_rpc("worker1", rank=1, world_size=2)
- >>> rpc.shutdown()
- """
- super().__init__()
- enable_moving_cpu_tensors_to_cuda = self._prepare_init(remote_device)
- # Default arguments preparation.
- args = args if args is not None else ()
- kwargs = kwargs if kwargs is not None else {}
- if _module_interface_cls is not None:
- # Users rely on this field to know if this generated RemoteModule is TorchScript-able.
- self.is_scriptable = True
- # Instantiate template on remote side.
- fut = rpc.rpc_async(
- self.on,
- _instantiate_template,
- (_module_interface_cls, enable_moving_cpu_tensors_to_cuda),
- )
- self._init_template(
- _module_interface_cls, enable_moving_cpu_tensors_to_cuda
- )
- # Create the module on the remote side.
- fut.wait() # Ensure remote_module_cls is available on remote side.
- # TODO: We need to change this to rpc.remote, and make it async (see the else branch below).
- # For that we need to be able to apply _module_interface_cls to the RRef returned by rpc.remote
- # See https://github.com/pytorch/pytorch/issues/58098 for more context.
- self.module_rref = rpc.rpc_sync(
- self.on,
- _create_module_with_interface,
- (module_cls, args, kwargs, self.device, _module_interface_cls),
- )
- else:
- self.is_scriptable = False
- self.generated_methods = (
- _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods
- )
- # Create the module on the remote side.
- self.module_rref = rpc.remote(
- self.on,
- _create_module,
- (module_cls, args, kwargs, self.device),
- )
- self._install_generated_methods()
- self._check_attribute_picklability()
- def remote_parameters(self, recurse: bool = True) -> List[rpc.RRef[Parameter]]:
- """
- Returns a list of :class:`~torch.distributed.rpc.RRef` pointing to the
- remote module's parameters. This can typically be used in conjunction
- with :class:`~torch.distributed.optim.DistributedOptimizer` (see the sketch below).
- Args:
- recurse (bool): if True, then returns parameters of the remote
- module and all submodules of the remote module. Otherwise,
- returns only parameters that are direct members of the
- remote module.
- Returns:
- A list of :class:`~torch.distributed.rpc.RRef` (``List[RRef[nn.Parameter]]``)
- to remote module's parameters.
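- Example::
- A minimal sketch (assuming an initialized RPC group and a ``RemoteModule`` named
- ``remote_linear_module`` as in the class-level example; names are illustrative) of passing
- the returned RRefs to a :class:`~torch.distributed.optim.DistributedOptimizer`:
- >>> # xdoctest: +SKIP("distributed")
- >>> import torch.distributed.autograd as dist_autograd
- >>> from torch.distributed.optim import DistributedOptimizer
- >>>
- >>> param_rrefs = remote_linear_module.remote_parameters()
- >>> dist_optim = DistributedOptimizer(torch.optim.SGD, param_rrefs, lr=0.05)
- >>> with dist_autograd.context() as context_id:
- >>>     loss = remote_linear_module.forward(torch.randn(128, 20)).sum()
- >>>     dist_autograd.backward(context_id, [loss])
- >>>     dist_optim.step(context_id)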
- """
- return rpc.rpc_sync(self.on, _param_rrefs, args=(self.module_rref, recurse))
- def get_module_rref(self) -> rpc.RRef[nn.Module]:
- """
- Returns an :class:`~torch.distributed.rpc.RRef` (``RRef[nn.Module]``)
- pointing to the remote module.
- """
- return self.module_rref
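- # Plain Python pickling is deliberately blocked; a RemoteModule can only be serialized by the
- # RPC pickler, via the custom reducer registered at the bottom of this file.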
- @torch.jit.export
- def __getstate__(self):
- raise RuntimeError(
- "Cannot pickle RemoteModule in python pickler. RemoteModule can only be pickled when using RPC"
- )
- @torch.jit.export
- def __setstate__(self, state):
- raise RuntimeError(
- "Cannot unpickle RemoteModule in python pickler. RemoteModule can only be unpickled when using RPC"
- )
- def register_buffer(
- self, name: str, tensor: Optional[Tensor], persistent: bool = True
- ) -> None:
- _raise_not_supported(self.register_buffer.__name__)
- def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
- _raise_not_supported(self.register_parameter.__name__)
- def add_module(self, name: str, module: Optional[Module]) -> None:
- _raise_not_supported(self.add_module.__name__)
- def apply(self: T, fn: Callable[[Module], None]) -> T: # type: ignore[return]
- _raise_not_supported(self.apply.__name__)
- def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return]
- _raise_not_supported(self.cuda.__name__)
- def ipu(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return]
- _raise_not_supported(self.ipu.__name__)
- def xpu(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return]
- _raise_not_supported(self.xpu.__name__)
- def cpu(self: T) -> T: # type: ignore[return]
- _raise_not_supported(self.cpu.__name__)
- def type(self: T, dst_type: Union[dtype, str]) -> T: # type: ignore[return]
- _raise_not_supported(self.type.__name__)
- def float(self: T) -> T: # type: ignore[return]
- _raise_not_supported(self.float.__name__)
- def double(self: T) -> T: # type: ignore[return]
- _raise_not_supported(self.double.__name__)
- def half(self: T) -> T: # type: ignore[return]
- _raise_not_supported(self.half.__name__)
- def bfloat16(self: T) -> T: # type: ignore[return]
- _raise_not_supported(self.bfloat16.__name__)
- def to(self, *args, **kwargs) -> T: # type: ignore[return]
- _raise_not_supported(self.to.__name__)
- def register_backward_hook( # type: ignore[return]
- self, hook: Callable[[Module, _grad_t, _grad_t], Union[None, _grad_t]]
- ) -> RemovableHandle:
- _raise_not_supported(self.register_backward_hook.__name__)
- def register_forward_pre_hook( # type: ignore[return]
- self,
- hook: Union[
- Callable[[T, Tuple[Any, ...]], Optional[Any]],
- Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]],
- ],
- prepend: bool = False,
- with_kwargs: bool = False,
- ) -> RemovableHandle:
- _raise_not_supported(self.register_forward_pre_hook.__name__)
- def register_forward_hook( # type: ignore[return]
- self,
- hook: Union[
- Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
- Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
- ],
- prepend: bool = False,
- with_kwargs: bool = False,
- ) -> RemovableHandle:
- _raise_not_supported(self.register_forward_hook.__name__)
- def state_dict(self, *args, **kwargs):
- _raise_not_supported(self.state_dict.__name__)
- def load_state_dict(
- self,
- state_dict: Mapping[str, Any],
- strict: bool = True,
- ):
- _raise_not_supported(self.load_state_dict.__name__)
- def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
- raise ValueError(
- "Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead."
- )
- def named_parameters( # type: ignore[return]
- self,
- prefix: str = "",
- recurse: bool = True,
- remove_duplicate: bool = True
- ) -> Iterator[Tuple[str, Parameter]]:
- _raise_not_supported(self.named_parameters.__name__)
- def buffers(self, recurse: bool = True) -> Iterator[Tensor]: # type: ignore[return]
- _raise_not_supported(self.buffers.__name__)
- def named_buffers( # type: ignore[return]
- self,
- prefix: str = "",
- recurse: bool = True,
- remove_duplicate: bool = True
- ) -> Iterator[Tuple[str, Tensor]]:
- _raise_not_supported(self.named_buffers.__name__)
- def children(self) -> Iterator[Module]: # type: ignore[return]
- _raise_not_supported(self.children.__name__)
- def named_children(self) -> Iterator[Tuple[str, Module]]: # type: ignore[return]
- _raise_not_supported(self.named_children.__name__)
- def modules(self) -> Iterator[Module]: # type: ignore[return]
- _raise_not_supported(self.modules.__name__)
- def named_modules(
- self,
- memo: Optional[Set[Module]] = None,
- prefix: str = "",
- remove_duplicate: bool = True,
- ):
- _raise_not_supported(self.named_modules.__name__)
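- # Unlike most ``nn.Module`` methods, ``train`` and ``eval`` are supported: the call is forwarded
- # to the remote module through the RRef's synchronous ``rpc_sync()`` proxy.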
- def train(self: T, mode: bool = True) -> T:
- return self.module_rref.rpc_sync().train() # type: ignore[operator, union-attr]
- def eval(self: T) -> T:
- return self.module_rref.rpc_sync().eval() # type: ignore[operator, union-attr]
- def requires_grad_(self: T, requires_grad: bool = True) -> T: # type: ignore[return]
- _raise_not_supported(self.requires_grad_.__name__)
- def zero_grad(self, set_to_none: bool = True) -> None:
- _raise_not_supported(self.zero_grad.__name__)
- def share_memory(self: T) -> T: # type: ignore[return]
- _raise_not_supported(self.share_memory.__name__)
- def extra_repr(self) -> str: # type: ignore[return]
- _raise_not_supported(self.extra_repr.__name__)
- def _prepare_init(self, remote_device_str: str) -> bool:
- """
- Prepares the initialization and returns whether to enable automatically moving CPU tensors to CUDA devices.
- """
- # Sanity check.
- assert rpc._is_current_rpc_agent_set(), "RemoteModule only works in RPC."
- remote_device = _remote_device(remote_device_str)
- self.on = remote_device.worker_name() if remote_device.worker_name() is not None else remote_device.rank()
- self.device = str(remote_device.device())
- agent = rpc._get_current_rpc_agent()
- # If the device map of the remote worker is set,
- # then enable moving any input CPU tensors to the same cuda device.
- self.is_device_map_set = bool(
- agent._get_device_map(agent.get_worker_info(self.on)) # type: ignore[arg-type]
- )
- # ``enable_moving_cpu_tensors_to_cuda`` is less strict than ``is_device_map_set``:
- # If ``enable_moving_cpu_tensors_to_cuda`` is true, but the device map is not set,
- # then any CPU tensors can still be moved to a cuda device to run forward,
- # but the output must be moved back to CPU before being sent over the wire.
- enable_moving_cpu_tensors_to_cuda = torch.device(self.device).type == "cuda"
- return enable_moving_cpu_tensors_to_cuda
- def _init_template(self, module_interface_cls, enable_moving_cpu_tensors_to_cuda):
- """
- Instantiates template on local side.
- """
- generated_module = instantiator.instantiate_scriptable_remote_module_template(
- module_interface_cls, enable_moving_cpu_tensors_to_cuda
- )
- self.generated_methods = generated_module._generated_methods
- def _check_attribute_picklability(self):
- """
- Checks if all the attributes have explicitly defined whether to be pickled (i.e., picklability).
- """
- for k in self.__dict__.keys():
- if (
- k not in _REMOTE_MODULE_PICKLED_ATTRIBUTES
- and k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING
- ):
- raise AttributeError(
- "Attribute {} must be either in ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` or "
- "``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``.".format(k)
- )
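- # Bind the template-generated ``forward``/``forward_async`` functions to this instance so that
- # calling them forwards the invocation to the remote module over RPC.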
- def _install_generated_methods(self):
- for method in self.generated_methods:
- method_name = method.__name__
- method = torch.jit.export(method)
- setattr(self, method_name, types.MethodType(method, self))
- @staticmethod
- def init_from_module_rref(
- remote_device: str,
- module_rref: rpc.RRef[nn.Module],
- _module_interface_cls: Any = None,
- ):
- """
- Besides the constructor, a RemoteModule instance can also be initialized given a module RRef.
- This alternate initialization method can be particularly useful if we want to create multiple
- RemoteModule instances that share the same underlying module and reduce memory consumption.
- Moreover, this provides a workaround for passing a script RemoteModule over RPC,
- which is not supported. The recommended way is as follows:
- 1. the sender creates a RemoteModule;
- 2. the sender sends its ``module_rref`` over RPC;
- 3. the receiver calls this method to initialize another RemoteModule using the same ``module_rref``.
- Example::
- Run the following code in two different processes:
- >>> # xdoctest: +SKIP("distributed")
- >>> # On worker 0:
- >>> import torch
- >>> import torch.distributed.rpc as rpc
- >>> from torch import nn, Tensor
- >>> from torch.distributed.nn.api.remote_module import RemoteModule
- >>>
- >>> rpc.init_rpc("worker0", rank=0, world_size=2)
- >>> remote_module = RemoteModule(
- >>> "worker1/cpu", nn.Linear, args=(20, 30),
- >>> )
- >>>
- >>> remote_module1 = rpc.rpc_sync(
- >>> "worker1/cpu",
- >>> RemoteModule.init_from_module_rref,
- >>> ("worker1/cpu", remote_module1.get_module_rref()),
- >>> )
- >>> rpc.shutdown()
- >>> # On worker 1:
- >>> import torch
- >>> import torch.distributed.rpc as rpc
- >>>
- >>> rpc.init_rpc("worker1", rank=1, world_size=2)
- >>> rpc.shutdown()
- Args:
- remote_device (str): Device on the destination worker where we'd like to place this module.
- The device can be a local device or a remote device specified by one of the following remote
- formats:
- 1. "rank:<rank>/<device>" (ex: "rank:0/cuda:0").
- 2. "<worker_name>/<device>" (ex: "trainer0/cuda:0").
- In addition, the device field is optional; the default value is "cpu".
- module_rref (RRef[nn.Module]): The module reference shared by both the caller and
- the created remote module.
- _module_interface_cls (type, optional): The TorchScript interface type for the module
- to be created. The type object should be decorated by @torch.jit.interface.
- If not provided, the generated RemoteModule is not torchscript-able.
- Warning, this is an experimental API and susceptible to frequent changes.
- Returns:
- A remote module instance which wraps the :class:`~nn.Module` created by the
- user-provided ``module_rref``; it has a blocking ``forward`` method and an
- asynchronous ``forward_async`` method that returns a future of the ``forward`` call
- on the user-provided module on the remote side.
- """
- # NOTE: if a new attribute is added to this class, also need to add it
- # to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` for pickling/unpickling.
- remote_module = object.__new__(RemoteModule)
- enable_moving_cpu_tensors_to_cuda = remote_module._prepare_init(remote_device)
- if _module_interface_cls is not None:
- # Users rely on this field to know if this generated RemoteModule is TorchScript-able.
- remote_module.is_scriptable = True
- remote_module._init_template(
- _module_interface_cls, enable_moving_cpu_tensors_to_cuda
- )
- else:
- remote_module.is_scriptable = False
- remote_module.generated_methods = (
- _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods
- )
- remote_module.module_rref = module_rref
- remote_module._install_generated_methods()
- remote_module._check_attribute_picklability()
- return remote_module
- class RemoteModule(_RemoteModule):
- """
- A RemoteModule instance can only be created after RPC initialization.
- It creates a user-specified module on a specified remote node.
- It behaves like a regular ``nn.Module`` except that the ``forward`` method is
- executed on the remote node.
- It takes care of autograd recording to ensure the backward pass propagates
- gradients back to the corresponding remote module.
- It generates two methods ``forward_async`` and ``forward`` based on the
- signature of the ``forward`` method of ``module_cls``. ``forward_async``
- runs asynchronously and returns a Future. The arguments of ``forward_async``
- and ``forward`` are the same as the ``forward`` method of the module
- returned by the ``module_cls``.
- For example, if ``module_cls`` returns an instance of ``nn.Linear``
- whose ``forward`` method has the signature ``def forward(input: Tensor) -> Tensor:``,
- the generated ``RemoteModule`` will have 2 methods with the signatures:
- | ``def forward(input: Tensor) -> Tensor:``
- | ``def forward_async(input: Tensor) -> Future[Tensor]:``
- Args:
- remote_device (str): Device on the destination worker where we'd like to place this module.
- The format should be "<workername>/<device>", where the device field can be parsed as torch.device type.
- E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
- In addition, the device field is optional; the default value is "cpu".
- module_cls (nn.Module): Class for the module to be created remotely. For example,
- >>> class MyModule(nn.Module):
- >>> def forward(input):
- >>> return input + 1
- >>>
- >>> module_cls = MyModule
- args (Sequence, optional): args to be passed to ``module_cls``.
- kwargs (Dict, optional): kwargs to be passed to ``module_cls``.
- Returns:
- A remote module instance which wraps the :class:`~nn.Module` created by the
- user-provided ``module_cls``; it has a blocking ``forward`` method and an
- asynchronous ``forward_async`` method that returns a future of the ``forward`` call
- on the user-provided module on the remote side.
- Example::
- Run the following code in two different processes:
- >>> # xdoctest: +SKIP("distributed")
- >>> # On worker 0:
- >>> import torch
- >>> import torch.distributed.rpc as rpc
- >>> from torch import nn, Tensor
- >>> from torch.distributed.nn.api.remote_module import RemoteModule
- >>>
- >>> rpc.init_rpc("worker0", rank=0, world_size=2)
- >>> remote_linear_module = RemoteModule(
- >>> "worker1/cpu", nn.Linear, args=(20, 30),
- >>> )
- >>> input = torch.randn(128, 20)
- >>> ret_fut = remote_linear_module.forward_async(input)
- >>> ret = ret_fut.wait()
- >>> rpc.shutdown()
- >>> # On worker 1:
- >>> import torch
- >>> import torch.distributed.rpc as rpc
- >>>
- >>> rpc.init_rpc("worker1", rank=1, world_size=2)
- >>> rpc.shutdown()
- Furthermore, a more practical example that is combined with
- `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__ (DDP)
- can be found in this `tutorial <https://pytorch.org/tutorials/advanced/rpc_ddp_tutorial.html>`__.
- """
- def __init__(
- self,
- remote_device: str,
- module_cls: Type[nn.Module],
- args: Optional[Tuple] = None,
- kwargs: Optional[Dict[str, Any]] = None,
- ):
- super().__init__(remote_device, module_cls, args, kwargs)
- def _remote_module_receiver(
- *remote_module_pickled_attrs,
- ):
- """
- Deserializes a RemoteModule.
- """
- serialized_remote_module = _SerializedRemoteModule._make(
- remote_module_pickled_attrs
- )
- m = object.__new__(RemoteModule)
- m.__dict__.update(serialized_remote_module._asdict())
- # Unpickling the attribute `module_rref` must invoke RRef's `_deserialize()` method.
- m.module_rref = rpc.PyRRef._deserialize(m.module_rref)
- # Install generated methods when unpickled.
- for method in m.generated_methods:
- method_name = method.__name__
- method = torch.jit.export(method)
- setattr(m, method_name, types.MethodType(method, m))
- return m
- def _remote_module_reducer(remote_module):
- """
- Serializes a RemoteModule.
- """
- pickled_attrs = {}
- for k, v in remote_module.__dict__.items():
- # Pickling the attribute `module_rref` must invoke RRef's `_serialize()` method.
- if k == "module_rref":
- pickled_attrs[k] = v._serialize()
- elif k in _REMOTE_MODULE_PICKLED_ATTRIBUTES:
- pickled_attrs[k] = v
- # Check that every attribute that is not pickled is listed in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING.
- elif k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING:
- print(
- "The new attribute ``{}`` of RemoteModule is ignored during RPC pickling. "
- "To pickle this attribute, please add it to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES``. "
- "Otherwise, please explicitly add it to ``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``.".format(
- k
- ),
- file=sys.stderr,
- )
- return (
- _remote_module_receiver,
- tuple(pickled_attrs.values()),
- )
- def _recursive_script_module_receiver(
- recursive_script_module_serialized,
- ):
- """
- Deserializes a RecursiveScriptModule that does not contain a script RemoteModule.
- """
- f = io.BytesIO(recursive_script_module_serialized)
- m = torch.jit.load(f)
- return m
- def _recursive_script_module_reducer(recursive_script_module):
- """
- Serializes a RecursiveScriptModule that does not contain a script RemoteModule,
- and raises an error otherwise.
- """
- if hasattr(recursive_script_module._c, "module_rref"):
- raise RuntimeError(
- "Passing a script RemoteModule over RPC is not supported. Please create a RemoteModule in the sender, "
- "send the `module_rref` to the receiver, and create a new instance on the receiver end by passing this `module_rref`."
- )
- f = io.BytesIO()
- torch.jit.save(recursive_script_module, f)
- return (_recursive_script_module_receiver, (f.getvalue(),))
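- # Register the custom reducers with the internal RPC pickler so that RemoteModule and
- # RecursiveScriptModule instances are serialized correctly when sent over RPC
- # (the plain Python pickler path is unsupported; see ``__getstate__`` above).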
- _internal_rpc_pickler._register_reducer(RemoteModule, _remote_module_reducer)
- _internal_rpc_pickler._register_reducer(
- torch.jit.RecursiveScriptModule, _recursive_script_module_reducer
- )