# portal.py
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2019 Kakao Brain
  3. #
  4. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
  5. #
  6. # This source code is licensed under the BSD license found in the
  7. # LICENSE file in the root directory of this source tree.
  8. """Portal keeps a tensor in the pocket plane. The tensor becomes hidden to the
  9. autograd engine. The shared context of three functions (:class:`PortalBlue`,
  10. :class:`PortalOrange`, and :class:`PortalCopy`) out of the computation graph is
  11. one of the most important feature of :mod:`torchpipe.skip`.
  12. The metaphor is inspired by Portal™ from Valve.
  13. """
  14. from typing import List, Optional, Tuple
  15. import torch
  16. from torch import Tensor
  17. from ..copy import Context as CopyContext
  18. from ..copy import Copy
  19. from ..phony import get_phony
  20. from ..stream import AbstractStream, get_device
  21. __all__: List[str] = []
  22. class Portal:
  23. """A portal for a tensor."""
  24. def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None:
  25. self.put_tensor(tensor, tensor_life)
  26. self.grad: Optional[Tensor] = None
  27. def blue(self) -> Tensor:
  28. """Creates a :class:`PortalBlue` which hides the underlying tensor from
  29. the autograd engine.
  30. Join the returning phony to the main lane of the autograd graph to
  31. assure the correct backpropagation::
  32. PortalBlue --+
  33. |
  34. ---------- Join --
  35. """
  36. tensor = self.use_tensor()
  37. if tensor is None:
  38. return get_phony(torch.device("cpu"), requires_grad=False)
  39. return PortalBlue.apply(self, tensor)
  40. def orange(self, phony: Tensor) -> Optional[Tensor]:
  41. """Creates a :class:`PortalOrange` which retrieves the hidden tensor
  42. without losing ability of backpropagation.
  43. Give a phony forked from the main lane of an autograd graph::
  44. +-- PortalOrange --+
  45. | |
  46. -- Fork --------- f(a, b) --
  47. """
  48. self.check_tensor_life()
  49. if self.tensor is None:
  50. return self.use_tensor()
  51. return PortalOrange.apply(self, phony)
  52. def copy(self, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,) -> Tensor:
  53. """Copies the hidden tensor by a :class:`PortalCopy`.
  54. Give a phony and use the returning phony to keep backpropagation::
  55. +-- PortalCopy --+
  56. | |
  57. -- Fork ---------- Join --
  58. """
  59. if self.tensor is None:
  60. return get_phony(torch.device("cpu"), requires_grad=False)
  61. return PortalCopy.apply(self, prev_stream, next_stream, phony)
  62. def check_tensor_life(self) -> None:
  63. if self.tensor_life <= 0:
  64. raise RuntimeError("tensor in portal has been removed")
  65. def put_tensor(self, tensor: Optional[Tensor], tensor_life: int) -> None:
  66. """Stores a tensor into this portal."""
  67. # [Life of Tensor through Portal]
  68. #
  69. # The tensor can be retrieved by use_tensor() up to 'tensor_life'
  70. # times. When the life becomes 0, the tensor will be deleted for
  71. # deallocation in CUDA memory.
  72. #
  73. # The below events participate in a tensor through a portal.
  74. # Note that [x] denotes the events which call use_tensor():
  75. #
  76. # 1. [x] blue()
  77. # 2. [ ] PortalBlue.forward
  78. # 3. [ ] copy()
  79. # 4. [ ] PortalCopy.forward
  80. # 5. [ ] orange()
  81. # 6. [x] PortalOrange.forward
  82. # - - - - - - - - - - - - - - - - - - - - - - - - - - -
  83. # 7. [ ] orange() (recomputed)
  84. # 8. [x] PortalOrange.forward (recomputed)
  85. # 9. [ ] PortalOrange.backward
  86. # 10. [ ] PortalCopy.backward
  87. # 11. [x] blue() (recomputed)
  88. # 12. [ ] PortalBlue.forward (recomputed)
  89. # 13. [ ] PortalBlue.backward
  90. #
  91. self.tensor_life = tensor_life
  92. if tensor_life > 0:
  93. self.tensor = tensor
  94. else:
  95. self.tensor = None
  96. def use_tensor(self) -> Optional[Tensor]:
  97. """Retrieves the underlying tensor and decreases the tensor life. When
  98. the life becomes 0, it the tensor will be removed.
  99. """
  100. self.check_tensor_life()
  101. tensor = self.tensor
  102. self.tensor_life -= 1
  103. if self.tensor_life <= 0:
  104. self.tensor = None
  105. return tensor
  106. def put_grad(self, grad: Tensor) -> None:
  107. """Stores a gradient into this portal."""
  108. self.grad = grad
  109. def use_grad(self) -> Tensor:
  110. """Retrieves and removes the underlying gradient. The gradient is
  111. always ephemeral.
  112. """
  113. if self.grad is None:
  114. raise RuntimeError("grad in portal has been removed or never set")
  115. grad = self.grad
  116. self.grad = None
  117. return grad
# Common interface between :class:`PortalBlue`, :class:`PortalOrange`, and
# :class:`PortalCopy`.
class Context(CopyContext):
    # The portal shared by the paired portal autograd functions; each
    # Function's forward() stores it here and its backward() reads it back.
    portal: Portal
  122. class PortalBlue(torch.autograd.Function):
  123. """Hides a tensor from the autograd engine by a :class:`Portal`."""
  124. @staticmethod
  125. # type: ignore[override]
  126. def forward(
  127. ctx: Context,
  128. portal: Portal,
  129. # This tensor must be retrieved by portal.use_tensor().
  130. tensor: Tensor,
  131. ) -> Tensor:
  132. ctx.portal = portal
  133. phony = get_phony(tensor.device, requires_grad=False)
  134. return phony.detach()
  135. @staticmethod
  136. # type: ignore[override]
  137. def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, Tensor]:
  138. # The paired PortalOrange should keep the gradient.
  139. grad = ctx.portal.use_grad()
  140. return None, grad
  141. class PortalOrange(torch.autograd.Function):
  142. """Retrieves the hidden tensor from a :class:`Portal`."""
  143. @staticmethod
  144. # type: ignore[override]
  145. def forward(ctx: Context, portal: Portal, phony: Tensor) -> Tensor:
  146. ctx.portal = portal
  147. tensor = portal.use_tensor()
  148. assert tensor is not None
  149. return tensor.detach()
  150. @staticmethod
  151. def backward(ctx: Context, grad: Tensor) -> Tuple[None, None]: # type: ignore[override]
  152. # The paired PortalBlue will use the gradient.
  153. ctx.portal.put_grad(grad)
  154. return None, None
  155. class PortalCopy(torch.autograd.Function):
  156. """Copies the hidden tensor in a :class:`Portal`. It replaces the hidden
  157. tensor with copied one.
  158. """
  159. @staticmethod
  160. # type: ignore[override]
  161. def forward(
  162. ctx: Context, portal: Portal, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,
  163. ) -> Tensor:
  164. ctx.portal = portal
  165. assert portal.tensor is not None
  166. (portal.tensor,) = Copy.forward(ctx, prev_stream, next_stream, portal.tensor)
  167. phony = get_phony(get_device(next_stream), requires_grad=False)
  168. return phony.detach()
  169. @staticmethod
  170. # type: ignore[override]
  171. def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, None, None, None]:
  172. portal = ctx.portal
  173. assert portal.grad is not None
  174. _, _, portal.grad = Copy.backward(ctx, portal.grad)
  175. return None, None, None, None