_async.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. """Async API
  2. This module contains the API for parallelism in TorchScript, notably:
  3. * torch.jit.fork
  4. * torch.jit.wait
  5. This is not intended to be imported directly; please use the exposed
  6. functionalities in `torch.jit`.
  7. """
  8. import torch
  9. from torch.utils import set_module
  10. from torch.jit._builtins import _register_builtin
  11. from torch._jit_internal import Future
  12. set_module(Future, "torch.jit")
  13. def fork(func, *args, **kwargs):
  14. r"""
  15. Creates an asynchronous task executing `func` and a reference to the value
  16. of the result of this execution. `fork` will return immediately,
  17. so the return value of `func` may not have been computed yet. To force completion
  18. of the task and access the return value invoke `torch.jit.wait` on the Future. `fork` invoked
  19. with a `func` which returns `T` is typed as `torch.jit.Future[T]`. `fork` calls can be arbitrarily
  20. nested, and may be invoked with positional and keyword arguments.
  21. Asynchronous execution will only occur when run in TorchScript. If run in pure python,
  22. `fork` will not execute in parallel. `fork` will also not execute in parallel when invoked
  23. while tracing, however the `fork` and `wait` calls will be captured in the exported IR Graph.
  24. .. warning::
  25. `fork` tasks will execute non-deterministically. We recommend only spawning
  26. parallel fork tasks for pure functions that do not modify their inputs,
  27. module attributes, or global state.
  28. Args:
  29. func (callable or torch.nn.Module): A Python function or `torch.nn.Module`
  30. that will be invoked. If executed in TorchScript, it will execute asynchronously,
  31. otherwise it will not. Traced invocations of fork will be captured in the IR.
  32. ``*args``, ``**kwargs``: arguments to invoke `func` with.
  33. Returns:
  34. `torch.jit.Future[T]`: a reference to the execution of `func`. The value `T`
  35. can only be accessed by forcing completion of `func` through `torch.jit.wait`.
  36. Example (fork a free function):
  37. .. code-block:: python
  38. import torch
  39. from torch import Tensor
  40. def foo(a : Tensor, b : int) -> Tensor:
  41. return a + b
  42. def bar(a):
  43. fut : torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
  44. return torch.jit.wait(fut)
  45. script_bar = torch.jit.script(bar)
  46. input = torch.tensor(2)
  47. # only the scripted version executes asynchronously
  48. assert script_bar(input) == bar(input)
  49. # trace is not run asynchronously, but fork is captured in IR
  50. graph = torch.jit.trace(bar, (input,)).graph
  51. assert "fork" in str(graph)
  52. Example (fork a module method):
  53. .. code-block:: python
  54. import torch
  55. from torch import Tensor
  56. class AddMod(torch.nn.Module):
  57. def forward(self, a: Tensor, b : int):
  58. return a + b
  59. class Mod(torch.nn.Module):
  60. def __init__(self):
  61. super(self).__init__()
  62. self.mod = AddMod()
  63. def forward(self, input):
  64. fut = torch.jit.fork(self.mod, a, b=2)
  65. return torch.jit.wait(fut)
  66. input = torch.tensor(2)
  67. mod = Mod()
  68. assert mod(input) == torch.jit.script(mod).forward(input)
  69. """
  70. return torch._C.fork(func, *args, **kwargs)
  71. def wait(future):
  72. r"""
  73. Forces completion of a `torch.jit.Future[T]` asynchronous task, returning the
  74. result of the task. See :func:`~fork` for docs and examples.
  75. Args:
  76. future (torch.jit.Future[T]): an asynchronous task reference, created through `torch.jit.fork`
  77. Returns:
  78. `T`: the return value of the the completed task
  79. """
  80. return torch._C.wait(future)
  81. _register_builtin(wait, "aten::wait")