1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374 |
- from functools import partial
- from . import functions
- from . import rpc_async
- import torch
- from .constants import UNSET_RPC_TIMEOUT
- from torch.futures import Future
def _local_invoke(rref, func_name, args, kwargs):
    """Call method ``func_name`` on the object that ``rref`` holds locally.

    ``_invoke_rpc`` ships this function through ``rpc_api`` to
    ``rref.owner()``, so it runs where ``rref.local_value()`` is available.

    Args:
        rref: RRef whose local value is the call target.
        func_name: Name of the attribute to look up on that value.
        args: Positional arguments for the call, as a sequence.
        kwargs: Keyword arguments for the call, as a mapping.

    Returns:
        Whatever the looked-up callable returns.
    """
    target = rref.local_value()
    method = getattr(target, func_name)
    return method(*args, **kwargs)
@functions.async_execution
def _local_invoke_async_execution(rref, func_name, args, kwargs):
    """Call ``func_name`` on ``rref``'s local value, wrapped by ``async_execution``.

    The body matches ``_local_invoke``. ``_invoke_rpc`` picks this variant
    when the target method on the RRef's type carries the
    ``_wrapped_async_rpc_function`` attribute.

    NOTE(review): presumably the target then returns a Future that the
    ``async_execution`` wrapper lets the RPC layer wait on — confirm against
    ``functions.async_execution``.

    Args:
        rref: RRef whose local value is the call target.
        func_name: Name of the attribute to look up on that value.
        args: Positional arguments for the call, as a sequence.
        kwargs: Keyword arguments for the call, as a mapping.

    Returns:
        Whatever the looked-up callable returns.
    """
    owner_value = rref.local_value()
    bound = getattr(owner_value, func_name)
    return bound(*args, **kwargs)
def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
    """Invoke ``func_name`` on the object referenced by ``rref`` via ``rpc_api``.

    The RRef's type is fetched first (non-blocking). Methods carrying the
    ``_wrapped_async_rpc_function`` attribute are dispatched through
    ``_local_invoke_async_execution``; all other methods go through
    ``_local_invoke``. ScriptModule types skip the attribute check.

    Args:
        rref: The RRef whose owner runs the method.
        rpc_api: RPC entry point, called as
            ``rpc_api(rref.owner(), invoke_fn, args=..., timeout=...)``.
        func_name: Name of the method to call on the referenced object.
        timeout: Forwarded to both ``rref._get_type`` and ``rpc_api``.
        *args: Positional arguments for the remote method.
        **kwargs: Keyword arguments for the remote method.

    Returns:
        If ``rpc_api`` is ``rpc_async``: a new ``Future`` that is completed
        with the remote call's result, or with any exception raised while
        resolving the type or running the call. Otherwise: whatever
        ``rpc_api`` returns, after blocking until the type is known.
    """

    def _rref_type_cont(rref_fut):
        # value() re-raises if fetching the type failed.
        rref_type = rref_fut.value()

        invoke_fn = _local_invoke
        # Bypass ScriptModules when checking for async function attribute.
        is_script_module = issubclass(rref_type, torch.jit.ScriptModule) or issubclass(
            rref_type, torch._C.ScriptModule
        )
        if not is_script_module:
            func = getattr(rref_type, func_name)
            if hasattr(func, "_wrapped_async_rpc_function"):
                invoke_fn = _local_invoke_async_execution

        return rpc_api(
            rref.owner(),
            invoke_fn,
            args=(rref, func_name, args, kwargs),
            timeout=timeout,
        )

    rref_fut = rref._get_type(timeout=timeout, blocking=False)

    # Any API other than rpc_async: wait for the type, then issue the call.
    # Identity check: we are asking "is this the rpc_async function object".
    if rpc_api is not rpc_async:
        rref_fut.wait()
        return _rref_type_cont(rref_fut)

    # rpc_async returns a Future[T]. Issuing it from inside `rref_fut.then`
    # would make `then` return a nested Future[Future[T]]. Instead, hand back
    # a fresh Future and complete it with the inner call's outcome.
    result: Future = Future()

    def _complete_op(fut):
        # Forward the remote call's value (or its exception) into `result`.
        try:
            result.set_result(fut.value())
        except BaseException as ex:
            result.set_exception(ex)

    def _wrap_rref_type_cont(fut):
        # A failure resolving the type or issuing the call must still
        # complete `result`, otherwise callers would wait forever.
        try:
            _rref_type_cont(fut).then(_complete_op)
        except BaseException as ex:
            result.set_exception(ex)

    rref_fut.then(_wrap_rref_type_cont)
    return result
class RRefProxy:
    """Proxy that turns attribute access into RPC calls on an RRef's referent.

    ``proxy.some_method`` returns a callable. Calling it runs ``some_method``
    on the object referenced by ``rref``, issued through ``rpc_api`` with
    this proxy's timeout. This class is used entirely from C++ (see
    python_rpc_handler.cpp).

    Args:
        rref: The RRef whose referenced object receives the calls.
        rpc_api: RPC entry point used to issue each call.
        timeout: Timeout forwarded with every call; defaults to
            ``UNSET_RPC_TIMEOUT``.
    """

    def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT):
        self.rref = rref
        self.rpc_api = rpc_api
        self.rpc_timeout = timeout

    def __getattr__(self, func_name):
        # __getattr__ only runs when normal lookup fails. If one of the
        # proxy's own fields is missing (instance created without __init__,
        # e.g. via cls.__new__ during copy/pickle reconstruction), reading it
        # below would re-enter __getattr__ until RecursionError. Fail with a
        # normal AttributeError instead.
        if func_name in ("rref", "rpc_api", "rpc_timeout"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {func_name!r}"
            )
        return partial(_invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout)
|