"""
``torch.autograd`` provides classes and functions implementing automatic
differentiation of arbitrary scalar valued functions. It requires minimal
changes to the existing code - you only need to declare :class:`Tensor` s
for which gradients should be computed with the ``requires_grad=True`` keyword.
As of now, we only support autograd for floating point :class:`Tensor` types (
half, float, double and bfloat16) and complex :class:`Tensor` types (cfloat, cdouble).
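
A minimal usage sketch (illustrative only; ``x`` and ``y`` are placeholder
tensors, not names defined by this module)::

    >>> x = torch.randn(3, requires_grad=True)
    >>> y = (x * x).sum()
    >>> y.backward()                  # accumulates dy/dx into x.grad
    >>> torch.allclose(x.grad, 2 * x)
    True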
"""
import torch
import warnings

from torch.types import _TensorOrTensors, _size
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast

from .variable import Variable
from .function import Function, NestedIOFunction
from .gradcheck import gradcheck, gradgradcheck
from .grad_mode import (
    no_grad, enable_grad, set_grad_enabled, inference_mode, set_multithreading_enabled, _force_original_view_tracking,
    _unsafe_preserve_version_counter
)
from .anomaly_mode import detect_anomaly, set_detect_anomaly
from ..overrides import has_torch_function, handle_torch_function, is_tensor_like
from . import functional
from . import forward_ad
from . import graph
from .. import _vmap_internals

__all__ = ['Variable', 'Function', 'backward', 'grad_mode']

_OptionalTensor = Optional[torch.Tensor]
_ShapeorNestedShape = Union[_size, Sequence[_size], torch.Tensor]


def _calculate_shape(output: torch.Tensor, grad: torch.Tensor,
                     is_grads_batched: bool) -> Tuple[_ShapeorNestedShape, _ShapeorNestedShape]:
    # is_same_size ensures that both tensors are either nested or non nested
    if output.is_nested:
        if is_grads_batched:
            raise RuntimeError("Batched grads are not supported with Nested Tensor.")
        out_shape = output._nested_tensor_size()
        grad_shape = grad._nested_tensor_size()
        return out_shape, grad_shape

    reg_out_shape = output.shape
    reg_grad_shape = grad.shape if not is_grads_batched else grad.shape[1:]
    return reg_out_shape, reg_grad_shape


def _make_grads(outputs: Sequence[torch.Tensor], grads: Sequence[_OptionalTensor],
                is_grads_batched: bool) -> Tuple[_OptionalTensor, ...]:
    # Validates user-supplied gradients against ``outputs`` (shapes and complex-ness
    # must match) and creates an implicit ``ones_like`` gradient for any scalar
    # output whose gradient was passed as ``None``.
    new_grads: List[_OptionalTensor] = []
    for out, grad in zip(outputs, grads):
        if isinstance(grad, torch.Tensor):
            first_grad = grad if not is_grads_batched else grad[0]
            if not torch.is_same_size(out, first_grad):
                out_shape, grad_shape = _calculate_shape(out, first_grad, is_grads_batched)
                if is_grads_batched:
                    raise RuntimeError("If `is_grads_batched=True`, we interpret the first "
                                       "dimension of each grad_output as the batch dimension. "
                                       "The sizes of the remaining dimensions are expected to match "
                                       "the shape of corresponding output, but a mismatch "
                                       "was detected: grad_output["
                                       + str(grads.index(grad)) + "] has a shape of "
                                       + str(grad_shape) + " and output["
                                       + str(outputs.index(out)) + "] has a shape of "
                                       + str(out_shape) + ". "
                                       "If you only want some tensors in `grad_output` to be considered "
                                       "batched, consider using vmap.")
                else:
                    raise RuntimeError("Mismatch in shape: grad_output["
                                       + str(grads.index(grad)) + "] has a shape of "
                                       + str(grad_shape) + " and output["
                                       + str(outputs.index(out)) + "] has a shape of "
                                       + str(out_shape) + ".")
            if out.dtype.is_complex != grad.dtype.is_complex:
                raise RuntimeError("For complex Tensors, both grad_output and output"
                                   " are required to have the same dtype."
                                   " Mismatch in dtype: grad_output["
                                   + str(grads.index(grad)) + "] has a dtype of "
                                   + str(grad.dtype) + " and output["
                                   + str(outputs.index(out)) + "] has a dtype of "
                                   + str(out.dtype) + ".")
            new_grads.append(grad)
        elif grad is None:
            if out.requires_grad:
                if out.numel() != 1:
                    raise RuntimeError("grad can be implicitly created only for scalar outputs")
                new_grads.append(torch.ones_like(out, memory_format=torch.preserve_format))
            else:
                new_grads.append(None)
        else:
            raise TypeError("gradients can be either Tensors or None, but got " +
                            type(grad).__name__)
    return tuple(new_grads)


def _tensor_or_tensors_to_tuple(tensors: Optional[_TensorOrTensors], length: int) -> Tuple[_OptionalTensor, ...]:
    if tensors is None:
        return (None, ) * length
    if isinstance(tensors, torch.Tensor):
        return (tensors, )
    return tuple(tensors)


def backward(
    tensors: _TensorOrTensors,
    grad_tensors: Optional[_TensorOrTensors] = None,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    grad_variables: Optional[_TensorOrTensors] = None,
    inputs: Optional[_TensorOrTensors] = None,
) -> None:
    r"""Computes the sum of gradients of given tensors with respect to graph
    leaves.

    The graph is differentiated using the chain rule. If any of ``tensors``
    are non-scalar (i.e. their data has more than one element) and require
    gradient, then the Jacobian-vector product would be computed; in this
    case the function additionally requires specifying ``grad_tensors``.
    It should be a sequence of matching length, that contains the "vector"
    in the Jacobian-vector product, usually the gradient of the differentiated
    function w.r.t. corresponding tensors (``None`` is an acceptable value for
    all tensors that don't need gradient tensors).

    This function accumulates gradients in the leaves - you might need to zero
    ``.grad`` attributes or set them to ``None`` before calling it.
    See :ref:`Default gradient layouts<default-grad-layouts>`
    for details on the memory layout of accumulated gradients.

    .. note::
        Using this method with ``create_graph=True`` will create a reference cycle
        between the parameter and its gradient which can cause a memory leak.
        We recommend using ``autograd.grad`` when creating the graph to avoid this.
        If you have to use this function, make sure to reset the ``.grad`` fields of your
        parameters to ``None`` after use to break the cycle and avoid the leak.

    .. note::
        If you run any forward ops, create ``grad_tensors``, and/or call ``backward``
        in a user-specified CUDA stream context, see
        :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.

    .. note::
        When ``inputs`` are provided and a given input is not a leaf,
        the current implementation will call its grad_fn (even though it is not
        strictly needed to get these gradients). It is an implementation detail
        on which the user should not rely.
        See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
    Args:
        tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be
            computed.
        grad_tensors (Sequence[Tensor or None] or Tensor, optional): The "vector" in
            the Jacobian-vector product, usually gradients w.r.t. each element of
            corresponding tensors. None values can be specified for scalar Tensors or
            ones that don't require grad. If a None value would be acceptable for all
            grad_tensors, then this argument is optional.
        retain_graph (bool, optional): If ``False``, the graph used to compute the grad
            will be freed. Note that in nearly all cases setting this option to ``True``
            is not needed and often can be worked around in a much more efficient
            way. Defaults to the value of ``create_graph``.
        create_graph (bool, optional): If ``True``, graph of the derivative will
            be constructed, allowing to compute higher order derivative products.
            Defaults to ``False``.
        inputs (Sequence[Tensor] or Tensor, optional): Inputs w.r.t. which the gradient
            will be accumulated into ``.grad``. All other Tensors will be ignored. If
            not provided, the gradient is accumulated into all the leaf Tensors that
            were used to compute :attr:`tensors`.
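
    Example (an illustrative sketch; ``w`` and ``x`` below are placeholder
    tensors, not part of this module)::

        >>> w = torch.randn(3, requires_grad=True)
        >>> x = torch.randn(3)
        >>> loss = (w * x).sum()
        >>> torch.autograd.backward([loss])   # same as loss.backward()
        >>> torch.allclose(w.grad, x)
        True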
    """
    if torch._C._are_functorch_transforms_active():
        raise RuntimeError(
            "backward() called inside a functorch transform. This is not "
            "supported, please use functorch.grad or functorch.vjp instead "
            "or call backward() outside of functorch transforms.")

    if grad_variables is not None:
        warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
        if grad_tensors is None:
            grad_tensors = grad_variables
        else:
            raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) "
                               "arguments both passed to backward(). Please only "
                               "use 'grad_tensors'.")
    if inputs is not None and len(inputs) == 0:
        raise RuntimeError("'inputs' argument to backward() cannot be empty.")

    tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
    inputs = (inputs,) if isinstance(inputs, torch.Tensor) else \
        tuple(inputs) if inputs is not None else tuple()

    grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
    grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
    if retain_graph is None:
        retain_graph = create_graph
    # The reason we repeat the same comment below is that
    # some Python versions print out the first line of a multi-line function
    # call in the traceback and some print out the last line
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
        tensors, grad_tensors_, retain_graph, create_graph, inputs,
        allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass


def grad(
    outputs: _TensorOrTensors,
    inputs: _TensorOrTensors,
    grad_outputs: Optional[_TensorOrTensors] = None,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    only_inputs: bool = True,
    allow_unused: bool = False,
    is_grads_batched: bool = False
) -> Tuple[torch.Tensor, ...]:
    r"""Computes and returns the sum of gradients of outputs with respect to
    the inputs.

    ``grad_outputs`` should be a sequence of length matching ``outputs``
    containing the "vector" in the vector-Jacobian product, usually the pre-computed
    gradients w.r.t. each of the outputs. If an output doesn't require_grad,
    then the gradient can be ``None``.

    .. note::
        If you run any forward ops, create ``grad_outputs``, and/or call ``grad``
        in a user-specified CUDA stream context, see
        :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.

    .. note::
        ``only_inputs`` argument is deprecated and is ignored now (defaults to ``True``).
        To accumulate gradient for other parts of the graph, please use
        ``torch.autograd.backward``.
    Args:
        outputs (sequence of Tensor): outputs of the differentiated function.
        inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be
            returned (and not accumulated into ``.grad``).
        grad_outputs (sequence of Tensor): The "vector" in the vector-Jacobian product.
            Usually gradients w.r.t. each output. None values can be specified for scalar
            Tensors or ones that don't require grad. If a None value would be acceptable
            for all ``grad_outputs``, then this argument is optional. Default: None.
        retain_graph (bool, optional): If ``False``, the graph used to compute the grad
            will be freed. Note that in nearly all cases setting this option to ``True``
            is not needed and often can be worked around in a much more efficient
            way. Defaults to the value of ``create_graph``.
        create_graph (bool, optional): If ``True``, graph of the derivative will
            be constructed, allowing to compute higher order derivative products.
            Default: ``False``.
        allow_unused (bool, optional): If ``False``, specifying inputs that were not
            used when computing outputs (and therefore their grad is always zero)
            is an error. Defaults to ``False``.
        is_grads_batched (bool, optional): If ``True``, the first dimension of each
            tensor in ``grad_outputs`` will be interpreted as the batch dimension.
            Instead of computing a single vector-Jacobian product, we compute a
            batch of vector-Jacobian products, one for each "vector" in the batch
            (see the example below). We use the vmap prototype feature as the backend
            to vectorize calls to the autograd engine so that this computation can be
            performed in a single call. This should lead to performance improvements
            when compared to manually looping and performing backward multiple times.
            Note that due to this feature being experimental, there may be performance
            cliffs. Please use ``torch._C._debug_only_display_vmap_fallback_warnings(True)``
            to show any performance warnings and file an issue on GitHub if warnings exist
            for your use case. Defaults to ``False``.
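
    Example (a sketch of typical usage; ``x`` and ``y`` are placeholder tensors)::

        >>> x = torch.randn(5, requires_grad=True)
        >>> y = x ** 2
        >>> # vector-Jacobian product with an explicit "vector" of ones
        >>> (gx,) = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y),
        ...                             retain_graph=True)
        >>> torch.allclose(gx, 2 * x)
        True
        >>> # batched "vectors": each basis vector selects one row of the Jacobian
        >>> (jac,) = torch.autograd.grad(y, x, grad_outputs=torch.eye(5),
        ...                              is_grads_batched=True)
        >>> torch.allclose(jac, torch.diag(2 * x))
        True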
    """
    t_outputs = cast(Tuple[torch.Tensor, ...], (outputs,) if is_tensor_like(outputs) else tuple(outputs))
    t_inputs = cast(Tuple[torch.Tensor, ...], (inputs,) if is_tensor_like(inputs) else tuple(inputs))
    overridable_args = t_outputs + t_inputs
    if has_torch_function(overridable_args):
        return handle_torch_function(
            grad,
            overridable_args,
            t_outputs,
            t_inputs,
            grad_outputs=grad_outputs,
            retain_graph=retain_graph,
            create_graph=create_graph,
            only_inputs=only_inputs,
            allow_unused=allow_unused,
            is_grads_batched=is_grads_batched,
        )

    if not only_inputs:
        warnings.warn("only_inputs argument is deprecated and is ignored now "
                      "(defaults to True). To accumulate gradient for other "
                      "parts of the graph, please use torch.autograd.backward.")

    grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(t_outputs))
    grad_outputs_ = _make_grads(t_outputs, grad_outputs_, is_grads_batched=is_grads_batched)
    if retain_graph is None:
        retain_graph = create_graph
    # The reason we repeat the same comment several times below is that
    # some Python versions print out the first line of multi-line function
    # calls in the traceback and some print out the last line
    if is_grads_batched:
        def vjp(gO):
            return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
                t_outputs, gO, retain_graph, create_graph, t_inputs,
                allow_unused, accumulate_grad=False)  # Calls into the C++ engine to run the backward pass
        return _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(grad_outputs_)
    else:
        return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
            t_outputs, grad_outputs_, retain_graph, create_graph, t_inputs,
            allow_unused, accumulate_grad=False)  # Calls into the C++ engine to run the backward pass


# This function applies in case of gradient checkpointing for memory
# optimization. Currently, gradient checkpointing is supported only if the
# execution engine is invoked through torch.autograd.backward() and its
# inputs argument is not passed. It is not supported for torch.autograd.grad().
# This is because if inputs are specified, the gradient won't be calculated for
# anything else, e.g. model parameters like weights and biases.
#
# This function returns whether checkpointing is valid, i.e. whether the engine
# was entered via torch.autograd.backward (valid) or via torch.autograd.grad
# (not valid). The implementation works by maintaining a thread-local variable
# in torch/csrc/autograd/engine.cpp which looks at the NodeTask in the stack and,
# before a NodeTask is executed in evaluate_function, checks whether reentrant
# backwards is imperative or not.
# See https://github.com/pytorch/pytorch/pull/4594 for more discussion/context
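#
# A hedged illustration of the constraint above (``module`` and ``inp`` are
# placeholders; ``torch.utils.checkpoint.checkpoint`` is the reentrant
# checkpointing helper assumed here):
#
#     from torch.utils.checkpoint import checkpoint
#     out = checkpoint(module, inp)   # forward segment is recomputed during backward
#     out.sum().backward()            # OK: engine entered via backward() without `inputs`
#     # torch.autograd.grad(out.sum(), (inp,))  # not supported with this checkpointing path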
def _is_checkpoint_valid():
    return Variable._execution_engine.is_checkpoint_valid()


def variable(*args, **kwargs):
    raise RuntimeError("torch.autograd.variable(...) is deprecated, use torch.tensor(...) instead")


# Monkey patching variable.Variable to fix FX codegen. FX generates a call by roughly doing
# f"{fn.__module__}.{fn.__name__}(...)". This yields torch.autograd.variable.Variable(...) in the
# output of an FX graph. Unfortunately the module name torch.autograd.variable is shadowed by the
# deprecated function - variable(...).
variable.Variable = Variable  # type: ignore[attr-defined]

if not torch._C._autograd_init():
    raise RuntimeError("autograd initialization failed")

# Import all native method/classes
from torch._C._autograd import (
    _add_metadata_json,
    _disable_profiler,
    _disable_profiler_legacy,
    _enable_profiler,
    _enable_profiler_legacy,
    _enable_record_function,
    _kineto_step,
    _KinetoEvent,
    _pop_saved_tensors_default_hooks,
    _prepare_profiler,
    _profiler_enabled,
    _ProfilerResult,
    _push_saved_tensors_default_hooks,
    _record_function_with_args_enter,
    _record_function_with_args_exit,
    _set_empty_test_observer,
    _supported_activities,
    DeviceType,
    kineto_available,
    ProfilerEvent,
    SavedTensor,
)

from torch._C._profiler import ProfilerActivity, ProfilerConfig, ProfilerState

from . import profiler


def _register_py_tensor_class_for_device(device, cls):
    if not isinstance(cls, type):
        raise RuntimeError("cls isn't a typeinfo object")
    torch._C._register_py_class_for_device(device, cls)