import os
from collections import namedtuple
from typing import Any

import torch

from .grad_mode import _DecoratorContextManager

__all__ = [
    "UnpackedDualTensor",
    "enter_dual_level",
    "exit_dual_level",
    "make_dual",
    "unpack_dual",
    "dual_level",
]

# Global variable used to make the Python API simpler to use
_current_level = -1


def enter_dual_level():
    r"""Function that can be used to enter a new forward grad level.

    This level can be used to make and unpack dual Tensors to compute
    forward gradients.

    This function also updates the current level that is used by default
    by the other functions in this API.
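
    As an illustrative sketch (assuming ``fn``, ``x`` and ``v`` are defined
    elsewhere), the manual-level workflow looks like this; the
    :class:`dual_level` context manager below is the preferred way to achieve
    the same thing:

    Example::

        >>> # xdoctest: +SKIP("Undefined variables")
        >>> level = enter_dual_level()
        >>> inp = make_dual(x, v, level=level)
        >>> out = fn(inp)
        >>> y, jvp = unpack_dual(out, level=level)
        >>> exit_dual_level(level=level)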
  15. """
  16. global _current_level
  17. new_level = torch._C._enter_dual_level()
  18. if new_level != _current_level + 1:
  19. raise RuntimeError("Entering a new forward AD level but the current level "
  20. "is not valid. Make sure you did not modified it directly.")
  21. _current_level = new_level
  22. return new_level


def exit_dual_level(*, level=None):
    r"""Function that can be used to exit a forward grad level.

    This function deletes all the gradients associated with this
    level. Only deleting the latest entered level is allowed.

    This function also updates the current level that is used by default
    by the other functions in this API.
    """
    global _current_level
    if level is None:
        level = _current_level
    if level != _current_level:
        raise RuntimeError("Trying to exit a forward AD level that was not the last one "
                           "that was created. This is not supported.")
    torch._C._exit_dual_level(level=level)
    _current_level = level - 1


def make_dual(tensor, tangent, *, level=None):
    r"""Associates a tensor value with a forward gradient, the tangent, to create a
    "dual tensor", which is used to compute forward AD gradients.

    The result is a new tensor aliased to :attr:`tensor` with :attr:`tangent` embedded
    as an attribute as-is if it has the same storage layout, or copied otherwise.
    The tangent attribute can be recovered with :func:`unpack_dual`.
    This function is backward differentiable.

    Given a function `f` whose Jacobian is `J`, it allows one to compute the
    Jacobian-vector product (`jvp`) between `J` and a given vector `v` as follows.

    Example::

        >>> # xdoctest: +SKIP("Undefined variables")
        >>> with dual_level():
        ...     inp = make_dual(x, v)
        ...     out = f(inp)
        ...     y, jvp = unpack_dual(out)

    Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__
    for detailed steps on how to use this API.
    """
    # See NOTE: [forward-mode AD decompositions mechanism]
    #
    # Importing decompositions_for_jvp from torch._decomp registers
    # decompositions for jvp with the jit registry.
    #
    # FIXME: We specify that __debug__ must be True because
    # if Python is run with -OO or -O flags (i.e., __debug__ is False), we encounter the
    # following error:
    #
    # Return value was annotated as having type Tuple[NoneType, NoneType] but is actually of
    # type Tuple[Tensor, Tensor]:
    #   File ".../torch/_decomp/__init__.py", line 1585
    #         else:
    #             buffer = z
    #         return min - torch.log1p(z), buffer
    #         ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    if os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__:
        from torch._decomp import decompositions_for_jvp  # noqa: F401

    if level is None:
        level = _current_level

    if level < 0:
        raise RuntimeError("Trying to create a dual Tensor for forward AD but no level "
                           "exists, make sure to enter_dual_level() first.")

    if not (tensor.is_floating_point() or tensor.is_complex()):
        raise ValueError(f"Expected primal to be floating point or complex, but got: {tensor.dtype}")
    if not (tangent.is_floating_point() or tangent.is_complex()):
        raise ValueError(f"Expected tangent to be floating point or complex, but got: {tangent.dtype}")

    return torch._VF._make_dual(tensor, tangent, level=level)


_UnpackedDualTensor = namedtuple('_UnpackedDualTensor', ['primal', 'tangent'])


class UnpackedDualTensor(_UnpackedDualTensor):
    r"""Namedtuple returned by :func:`unpack_dual` containing the primal and tangent components of the dual tensor.

    See :func:`unpack_dual` for more details."""
    pass


def unpack_dual(tensor, *, level=None):
    r"""Unpacks a "dual tensor" to get both its Tensor value and its forward AD gradient.

    The result is a namedtuple ``(primal, tangent)`` where ``primal`` is a view of
    :attr:`tensor`'s primal and ``tangent`` is :attr:`tensor`'s tangent as-is.
    Neither of these tensors is a dual tensor of level :attr:`level`.

    This function is backward differentiable.

    Example::

        >>> # xdoctest: +SKIP("Undefined variables")
        >>> with dual_level():
        ...     inp = make_dual(x, x_t)
        ...     out = f(inp)
        ...     y, jvp = unpack_dual(out)
        ...     jvp = unpack_dual(out).tangent

    Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__
    for detailed steps on how to use this API.
    """
    if level is None:
        level = _current_level

    if level < 0:
        return UnpackedDualTensor(tensor, None)

    primal, dual = torch._VF._unpack_dual(tensor, level=level)
    return UnpackedDualTensor(primal, dual)


class dual_level(_DecoratorContextManager):
    r"""Context-manager that enables forward AD. All forward AD computation must
    be performed in a ``dual_level`` context.

    .. note::

        The ``dual_level`` context appropriately enters and exits the dual level to
        control the current forward AD level, which is used by default by the other
        functions in this API.

        We currently don't plan to support nested ``dual_level`` contexts, so
        only a single forward AD level is supported. To compute higher-order
        forward grads, one can use :func:`torch.func.jvp`.

    Example::

        >>> # xdoctest: +SKIP("Undefined variables")
        >>> x = torch.tensor([1.0])
        >>> x_t = torch.tensor([1.0])
        >>> with dual_level():
        ...     inp = make_dual(x, x_t)
        ...     # Do computations with inp
        ...     out = your_fn(inp)
        ...     _, grad = unpack_dual(out)
        >>> grad is None
        False
        >>> # After exiting the level, the grad is deleted
        >>> _, grad_after = unpack_dual(out)
        >>> grad_after is None
        True

    Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__
    for detailed steps on how to use this API.
    """

    def __enter__(self):
        return enter_dual_level()

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        exit_dual_level()


# Private helper functions
_is_fwd_grad_enabled = torch._C._is_fwd_grad_enabled


# Private helper function to enable or disable fwd grad.
# If you're a user and want to use this, please file an issue to discuss the use case.
class _set_fwd_grad_enabled(_DecoratorContextManager):
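    r"""Context manager that enables or disables forward-mode AD while it is active.

    This is a private helper; see the comment above before relying on it.
    A minimal illustrative sketch of how it behaves (not a public API):

    Example::

        >>> # xdoctest: +SKIP("Private helper, illustrative only")
        >>> with _set_fwd_grad_enabled(False):
        ...     # forward grad is disabled inside this block; the previous
        ...     # state is restored when the block exits
        ...     assert not _is_fwd_grad_enabled()
    """
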
    def __init__(self, mode: bool) -> None:
        # The flag is flipped eagerly here (rather than in __enter__) and the
        # previous state is saved so that __exit__ can restore it.
        self.prev = _is_fwd_grad_enabled()
        torch._C._set_fwd_grad_enabled(mode)

    def __enter__(self) -> None:
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_fwd_grad_enabled(self.prev)