functional.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028
  1. import torch
  2. from typing import Tuple, List
  3. from . import forward_ad as fwAD
  4. from torch._vmap_internals import _vmap
  5. __all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"]
  6. # Utility functions
  7. def _as_tuple_nocheck(x):
  8. if isinstance(x, tuple):
  9. return x
  10. elif isinstance(x, list):
  11. return tuple(x)
  12. else:
  13. return x,
  14. def _as_tuple(inp, arg_name=None, fn_name=None):
  15. # Ensures that inp is a tuple of Tensors
  16. # Returns whether or not the original inp was a tuple and the tupled version of the input
  17. if arg_name is None and fn_name is None:
  18. return _as_tuple_nocheck(inp)
  19. is_inp_tuple = True
  20. if not isinstance(inp, tuple):
  21. inp = (inp,)
  22. is_inp_tuple = False
  23. for i, el in enumerate(inp):
  24. if not isinstance(el, torch.Tensor):
  25. if is_inp_tuple:
  26. raise TypeError("The {} given to {} must be either a Tensor or a tuple of Tensors but the"
  27. " value at index {} has type {}.".format(arg_name, fn_name, i, type(el)))
  28. else:
  29. raise TypeError("The {} given to {} must be either a Tensor or a tuple of Tensors but the"
  30. " given {} has type {}.".format(arg_name, fn_name, arg_name, type(el)))
  31. return is_inp_tuple, inp
  32. def _tuple_postprocess(res, to_unpack):
  33. # Unpacks a potentially nested tuple of Tensors
  34. # to_unpack should be a single boolean or a tuple of two booleans.
  35. # It is used to:
  36. # - invert _as_tuple when res should match the inp given to _as_tuple
  37. # - optionally remove nesting of two tuples created by multiple calls to _as_tuple
  38. if isinstance(to_unpack, tuple):
  39. assert len(to_unpack) == 2
  40. if not to_unpack[1]:
  41. res = tuple(el[0] for el in res)
  42. if not to_unpack[0]:
  43. res = res[0]
  44. else:
  45. if not to_unpack:
  46. res = res[0]
  47. return res
  48. def _grad_preprocess(inputs, create_graph, need_graph):
  49. # Preprocess the inputs to make sure they require gradient
  50. # inputs is a tuple of Tensors to preprocess
  51. # create_graph specifies if the user wants gradients to flow back to the Tensors in inputs
  52. # need_graph specifies if we internally want gradients to flow back to the Tensors in res
  53. # Note that we *always* create a new Tensor object to be able to see the difference between
  54. # inputs given as arguments and the same Tensors automatically captured by the user function.
  55. # Check this issue for more details on how that can happen: https://github.com/pytorch/pytorch/issues/32576
  56. res = []
  57. for inp in inputs:
  58. if create_graph and inp.requires_grad:
  59. # Create at least a new Tensor object in a differentiable way
  60. if not inp.is_sparse:
  61. # Use .view_as() to get a shallow copy
  62. res.append(inp.view_as(inp))
  63. else:
  64. # We cannot use view for sparse Tensors so we clone
  65. res.append(inp.clone())
  66. else:
  67. res.append(inp.detach().requires_grad_(need_graph))
  68. return tuple(res)
  69. def _grad_postprocess(inputs, create_graph):
  70. # Postprocess the generated Tensors to avoid returning Tensors with history when the user did not
  71. # request it.
  72. if isinstance(inputs[0], torch.Tensor):
  73. if not create_graph:
  74. return tuple(inp.detach() for inp in inputs)
  75. else:
  76. return inputs
  77. else:
  78. return tuple(_grad_postprocess(inp, create_graph) for inp in inputs)
  79. def _validate_v(v, other, is_other_tuple):
  80. # This assumes that other is the correct shape, and v should match
  81. # Both are assumed to be tuples of Tensors
  82. if len(other) != len(v):
  83. if is_other_tuple:
  84. raise RuntimeError("v is a tuple of invalid length: should be {} but got {}.".format(len(other), len(v)))
  85. else:
  86. raise RuntimeError("The given v should contain a single Tensor.")
  87. for idx, (el_v, el_other) in enumerate(zip(v, other)):
  88. if el_v.size() != el_other.size():
  89. prepend = ""
  90. if is_other_tuple:
  91. prepend = "Entry {} in ".format(idx)
  92. raise RuntimeError("{}v has invalid size: should be {} but got {}.".format(
  93. prepend, el_other.size(), el_v.size()))
  94. def _check_requires_grad(inputs, input_type, strict):
  95. # Used to make all the necessary checks to raise nice errors in strict mode.
  96. if not strict:
  97. return
  98. if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]:
  99. raise RuntimeError("Invalid input_type to _check_requires_grad")
  100. for i, inp in enumerate(inputs):
  101. if inp is None:
  102. # This can only be reached for grad_inputs.
  103. raise RuntimeError("The output of the user-provided function is independent of input {}."
  104. " This is not allowed in strict mode.".format(i))
  105. if not inp.requires_grad:
  106. if input_type == "hessian":
  107. raise RuntimeError("The hessian of the user-provided function with respect to input {}"
  108. " is independent of the input. This is not allowed in strict mode."
  109. " You should ensure that your function is thrice differentiable and that"
  110. " the hessian depends on the inputs.".format(i))
  111. elif input_type == "jacobian":
  112. raise RuntimeError("While computing the hessian, found that the jacobian of the user-provided"
  113. " function with respect to input {} is independent of the input. This is not"
  114. " allowed in strict mode. You should ensure that your function is twice"
  115. " differentiable and that the jacobian depends on the inputs (this would be"
  116. " violated by a linear function for example).".format(i))
  117. elif input_type == "grad_inputs":
  118. raise RuntimeError("The gradient with respect to input {} is independent of the inputs of the"
  119. " user-provided function. This is not allowed in strict mode.".format(i))
  120. else:
  121. raise RuntimeError("Output {} of the user-provided function does not require gradients."
  122. " The outputs must be computed in a differentiable manner from the input"
  123. " when running in strict mode.".format(i))
  124. def _autograd_grad(outputs, inputs, grad_outputs=None, create_graph=False, retain_graph=None, is_grads_batched=False):
  125. # Version of autograd.grad that accepts `None` in outputs and do not compute gradients for them.
  126. # This has the extra constraint that inputs has to be a tuple
  127. assert isinstance(outputs, tuple)
  128. if grad_outputs is None:
  129. grad_outputs = (None,) * len(outputs)
  130. assert isinstance(grad_outputs, tuple)
  131. assert len(outputs) == len(grad_outputs)
  132. new_outputs: Tuple[torch.Tensor, ...] = tuple()
  133. new_grad_outputs: Tuple[torch.Tensor, ...] = tuple()
  134. for out, grad_out in zip(outputs, grad_outputs):
  135. if out is not None and out.requires_grad:
  136. new_outputs += (out,)
  137. new_grad_outputs += (grad_out,)
  138. if len(new_outputs) == 0:
  139. # No differentiable output, we don't need to call the autograd engine
  140. return (None,) * len(inputs)
  141. else:
  142. return torch.autograd.grad(new_outputs, inputs, new_grad_outputs, allow_unused=True,
  143. create_graph=create_graph, retain_graph=retain_graph,
  144. is_grads_batched=is_grads_batched)
  145. def _fill_in_zeros(grads, refs, strict, create_graph, stage):
  146. # Used to detect None in the grads and depending on the flags, either replace them
  147. # with Tensors full of 0s of the appropriate size based on the refs or raise an error.
  148. # strict and create graph allow us to detect when it is appropriate to raise an error
  149. # stage gives us information of which backward call we consider to give good error message
  150. if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
  151. raise RuntimeError("Invalid stage argument '{}' to _fill_in_zeros".format(stage))
  152. res: Tuple[torch.Tensor, ...] = tuple()
  153. for i, grads_i in enumerate(grads):
  154. if grads_i is None:
  155. if strict:
  156. if stage == "back":
  157. raise RuntimeError("The output of the user-provided function is independent of "
  158. "input {}. This is not allowed in strict mode.".format(i))
  159. elif stage == "back_trick":
  160. raise RuntimeError("The gradient with respect to the input is independent of entry {}"
  161. " in the grad_outputs when using the double backward trick to compute"
  162. " forward mode gradients. This is not allowed in strict mode.".format(i))
  163. elif stage == "double_back":
  164. raise RuntimeError("The jacobian of the user-provided function is independent of "
  165. "input {}. This is not allowed in strict mode.".format(i))
  166. else:
  167. raise RuntimeError("The hessian of the user-provided function is independent of "
  168. "entry {} in the grad_jacobian. This is not allowed in strict "
  169. "mode as it prevents from using the double backward trick to "
  170. "replace forward mode AD.".format(i))
  171. grads_i = torch.zeros_like(refs[i])
  172. else:
  173. if strict and create_graph and not grads_i.requires_grad:
  174. if "double" not in stage:
  175. raise RuntimeError("The jacobian of the user-provided function is independent of "
  176. "input {}. This is not allowed in strict mode when create_graph=True.".format(i))
  177. else:
  178. raise RuntimeError("The hessian of the user-provided function is independent of "
  179. "input {}. This is not allowed in strict mode when create_graph=True.".format(i))
  180. res += (grads_i,)
  181. return res
  182. # Public API
  183. def vjp(func, inputs, v=None, create_graph=False, strict=False):
  184. r"""Function that computes the dot product between a vector ``v`` and the
  185. Jacobian of the given function at the point given by the inputs.
  186. Args:
  187. func (function): a Python function that takes Tensor inputs and returns
  188. a tuple of Tensors or a Tensor.
  189. inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
  190. v (tuple of Tensors or Tensor): The vector for which the vector
  191. Jacobian product is computed. Must be the same size as the output
  192. of ``func``. This argument is optional when the output of ``func``
  193. contains a single element and (if it is not provided) will be set
  194. as a Tensor containing a single ``1``.
  195. create_graph (bool, optional): If ``True``, both the output and result
  196. will be computed in a differentiable way. Note that when ``strict``
  197. is ``False``, the result can not require gradients or be
  198. disconnected from the inputs. Defaults to ``False``.
  199. strict (bool, optional): If ``True``, an error will be raised when we
  200. detect that there exists an input such that all the outputs are
  201. independent of it. If ``False``, we return a Tensor of zeros as the
  202. vjp for said inputs, which is the expected mathematical value.
  203. Defaults to ``False``.
  204. Returns:
  205. output (tuple): tuple with:
  206. func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
  207. vjp (tuple of Tensors or Tensor): result of the dot product with
  208. the same shape as the inputs.
  209. Example:
  210. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
  211. >>> def exp_reducer(x):
  212. ... return x.exp().sum(dim=1)
  213. >>> inputs = torch.rand(4, 4)
  214. >>> v = torch.ones(4)
  215. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  216. >>> vjp(exp_reducer, inputs, v)
  217. (tensor([5.7817, 7.2458, 5.7830, 6.7782]),
  218. tensor([[1.4458, 1.3962, 1.3042, 1.6354],
  219. [2.1288, 1.0652, 1.5483, 2.5035],
  220. [2.2046, 1.1292, 1.1432, 1.3059],
  221. [1.3225, 1.6652, 1.7753, 2.0152]]))
  222. >>> vjp(exp_reducer, inputs, v, create_graph=True)
  223. (tensor([5.7817, 7.2458, 5.7830, 6.7782], grad_fn=<SumBackward1>),
  224. tensor([[1.4458, 1.3962, 1.3042, 1.6354],
  225. [2.1288, 1.0652, 1.5483, 2.5035],
  226. [2.2046, 1.1292, 1.1432, 1.3059],
  227. [1.3225, 1.6652, 1.7753, 2.0152]], grad_fn=<MulBackward0>))
  228. >>> def adder(x, y):
  229. ... return 2 * x + 3 * y
  230. >>> inputs = (torch.rand(2), torch.rand(2))
  231. >>> v = torch.ones(2)
  232. >>> vjp(adder, inputs, v)
  233. (tensor([2.4225, 2.3340]),
  234. (tensor([2., 2.]), tensor([3., 3.])))
  235. """
  236. with torch.enable_grad():
  237. is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp")
  238. inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
  239. outputs = func(*inputs)
  240. is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "vjp")
  241. _check_requires_grad(outputs, "outputs", strict=strict)
  242. if v is not None:
  243. _, v = _as_tuple(v, "v", "vjp")
  244. v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
  245. _validate_v(v, outputs, is_outputs_tuple)
  246. else:
  247. if len(outputs) != 1 or outputs[0].nelement() != 1:
  248. raise RuntimeError("The vector v can only be None if the "
  249. "user-provided function returns "
  250. "a single Tensor with a single element.")
  251. enable_grad = True if create_graph else torch.is_grad_enabled()
  252. with torch.set_grad_enabled(enable_grad):
  253. grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph)
  254. vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back")
  255. # Cleanup objects and return them to the user
  256. outputs = _grad_postprocess(outputs, create_graph)
  257. vjp = _grad_postprocess(vjp, create_graph)
  258. return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(vjp, is_inputs_tuple)
  259. def jvp(func, inputs, v=None, create_graph=False, strict=False):
  260. r"""Function that computes the dot product between the Jacobian of
  261. the given function at the point given by the inputs and a vector ``v``.
  262. Args:
  263. func (function): a Python function that takes Tensor inputs and returns
  264. a tuple of Tensors or a Tensor.
  265. inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
  266. v (tuple of Tensors or Tensor): The vector for which the Jacobian
  267. vector product is computed. Must be the same size as the input of
  268. ``func``. This argument is optional when the input to ``func``
  269. contains a single element and (if it is not provided) will be set
  270. as a Tensor containing a single ``1``.
  271. create_graph (bool, optional): If ``True``, both the output and result
  272. will be computed in a differentiable way. Note that when ``strict``
  273. is ``False``, the result can not require gradients or be
  274. disconnected from the inputs. Defaults to ``False``.
  275. strict (bool, optional): If ``True``, an error will be raised when we
  276. detect that there exists an input such that all the outputs are
  277. independent of it. If ``False``, we return a Tensor of zeros as the
  278. jvp for said inputs, which is the expected mathematical value.
  279. Defaults to ``False``.
  280. Returns:
  281. output (tuple): tuple with:
  282. func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
  283. jvp (tuple of Tensors or Tensor): result of the dot product with
  284. the same shape as the output.
  285. Note:
  286. ``autograd.functional.jvp`` computes the jvp by using the backward of
  287. the backward (sometimes called the double backwards trick). This is not
  288. the most performant way of computing the jvp. Please consider using
  289. :func:`torch.func.jvp` or the
  290. :ref:`low-level forward-mode AD API <forward-mode-ad>` instead.
  291. Example:
  292. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
  293. >>> def exp_reducer(x):
  294. ... return x.exp().sum(dim=1)
  295. >>> inputs = torch.rand(4, 4)
  296. >>> v = torch.ones(4, 4)
  297. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  298. >>> jvp(exp_reducer, inputs, v)
  299. (tensor([6.3090, 4.6742, 7.9114, 8.2106]),
  300. tensor([6.3090, 4.6742, 7.9114, 8.2106]))
  301. >>> jvp(exp_reducer, inputs, v, create_graph=True)
  302. (tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SumBackward1>),
  303. tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SqueezeBackward1>))
  304. >>> def adder(x, y):
  305. ... return 2 * x + 3 * y
  306. >>> inputs = (torch.rand(2), torch.rand(2))
  307. >>> v = (torch.ones(2), torch.ones(2))
  308. >>> jvp(adder, inputs, v)
  309. (tensor([2.2399, 2.5005]),
  310. tensor([5., 5.]))
  311. """
  312. with torch.enable_grad():
  313. is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp")
  314. inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
  315. if v is not None:
  316. _, v = _as_tuple(v, "v", "jvp")
  317. v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
  318. _validate_v(v, inputs, is_inputs_tuple)
  319. else:
  320. if len(inputs) != 1 or inputs[0].nelement() != 1:
  321. raise RuntimeError("The vector v can only be None if the input to "
  322. "the user-provided function is a single Tensor "
  323. "with a single element.")
  324. outputs = func(*inputs)
  325. is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "jvp")
  326. _check_requires_grad(outputs, "outputs", strict=strict)
  327. # The backward is linear so the value of grad_outputs is not important as
  328. # it won't appear in the double backward graph. We only need to ensure that
  329. # it does not contain inf or nan.
  330. grad_outputs = tuple(torch.zeros_like(out, requires_grad=True) for out in outputs)
  331. grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True)
  332. _check_requires_grad(grad_inputs, "grad_inputs", strict=strict)
  333. if create_graph:
  334. with torch.enable_grad():
  335. grad_res = _autograd_grad(grad_inputs, grad_outputs, v, create_graph=create_graph)
  336. jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
  337. else:
  338. grad_res = _autograd_grad(grad_inputs, grad_outputs, v, create_graph=create_graph)
  339. jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
  340. # Cleanup objects and return them to the user
  341. outputs = _grad_postprocess(outputs, create_graph)
  342. jvp = _grad_postprocess(jvp, create_graph)
  343. return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(jvp, is_outputs_tuple)
  344. def _construct_standard_basis_for(tensors: Tuple[torch.Tensor, ...], tensor_numels: Tuple[int, ...]) -> Tuple[torch.Tensor, ...]:
  345. # This function:
  346. # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
  347. # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
  348. # - Each chunk corresponds to one tensor. The chunk has the same dtype and
  349. # device as the tensor
  350. #
  351. # For example, with tensor_numels = [1, 2, 1], this function returns:
  352. # ( tensor([[1], tensor([[0, 0], tensor([[0],
  353. # [0], [1, 0], [0],
  354. # [0], [0, 1], [0],
  355. # [0]]) , [0, 0]]) , [1]]) )
  356. #
  357. # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
  358. # Precondition: tensors always has at least one element.
  359. #
  360. # See NOTE: [Computing jacobian with vmap and grad for multiple tensors]
  361. # for context behind this function. All the pre-conditions are guarded for
  362. # in torch.autograd.functional.jacobian.
  363. assert len(tensors) == len(tensor_numels)
  364. assert len(tensors) > 0
  365. total_numel = sum(tensor_numels)
  366. chunks = tuple(tensor.new_zeros(total_numel, tensor_numel)
  367. for tensor, tensor_numel in zip(tensors, tensor_numels))
  368. diag_start_idx = 0
  369. for chunk, numel in zip(chunks, tensor_numels):
  370. chunk.diagonal(diag_start_idx).fill_(1)
  371. diag_start_idx -= numel
  372. return chunks
  373. def _jacfwd(func, inputs, strict=False, vectorize=False):
  374. if strict:
  375. raise RuntimeError('torch.autograd.functional.jacobian: `strict=True` '
  376. 'and `strategy="forward-mode"` are not supported together (yet). '
  377. 'Please either set `strict=False` or '
  378. '`strategy="reverse-mode"`.')
  379. is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
  380. output_info = []
  381. if vectorize:
  382. # See NOTE: [Computing jacobian with vmap and grad for multiple outputs]
  383. input_numels = tuple(input.numel() for input in inputs)
  384. # Step 1: Prepare tangents
  385. tangents = _construct_standard_basis_for(inputs, input_numels)
  386. # Step 2: Compute vmap over computation with dual tensors
  387. def jvp(tangents):
  388. with fwAD.dual_level():
  389. dual_inputs = tuple(
  390. fwAD.make_dual(input, tangent.view_as(input)) for input, tangent in zip(inputs, tangents))
  391. _is_outputs_tuple, dual_outputs = _as_tuple(func(*dual_inputs), "outputs")
  392. output_info.append(_is_outputs_tuple)
  393. jv = []
  394. primal_outs = []
  395. for dual_out in dual_outputs:
  396. primal, tangent = fwAD.unpack_dual(dual_out)
  397. primal_outs.append(primal)
  398. if tangent is not None:
  399. jv.append(tangent)
  400. else:
  401. jv.append(torch.zeros_like(primal))
  402. output_info.append(primal_outs)
  403. return tuple(jv)
  404. outputs_before_split = _vmap(jvp)(tangents)
  405. is_outputs_tuple, outputs = output_info
  406. # Step 3: for each of the output tangents, split along dim 0
  407. jacobian_input_output = []
  408. for jac, output_i in zip(outputs_before_split, outputs):
  409. jacobian_output_i_output = []
  410. for jac, input_j in zip(jac.split(input_numels, dim=0), inputs):
  411. # We need to transpose the Jacobian because in forward AD, the
  412. # batch dimension represents that of the inputs
  413. jacobian_input_i_output_j = jac.permute(*range(1, jac.ndim), 0) \
  414. .reshape(tuple([*output_i.shape, *input_j.shape])) # noqa: C409
  415. jacobian_output_i_output.append(jacobian_input_i_output_j)
  416. jacobian_input_output.append(jacobian_output_i_output)
  417. # Omit [Step 4] because everything is already transposed w/ forward AD
  418. return _tuple_postprocess(jacobian_input_output, (is_outputs_tuple, is_inputs_tuple))
  419. else:
  420. raise NotImplementedError("Computing Jacobian using forward-AD or forward-over-reverse Hessian is"
  421. "only implemented for `vectorize=True`.")
  422. def jacobian(func, inputs, create_graph=False, strict=False, vectorize=False, strategy="reverse-mode"):
  423. r"""Function that computes the Jacobian of a given function.
  424. Args:
  425. func (function): a Python function that takes Tensor inputs and returns
  426. a tuple of Tensors or a Tensor.
  427. inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
  428. create_graph (bool, optional): If ``True``, the Jacobian will be
  429. computed in a differentiable manner. Note that when ``strict`` is
  430. ``False``, the result can not require gradients or be disconnected
  431. from the inputs. Defaults to ``False``.
  432. strict (bool, optional): If ``True``, an error will be raised when we
  433. detect that there exists an input such that all the outputs are
  434. independent of it. If ``False``, we return a Tensor of zeros as the
  435. jacobian for said inputs, which is the expected mathematical value.
  436. Defaults to ``False``.
  437. vectorize (bool, optional): This feature is experimental.
  438. Please consider using :func:`torch.func.jacrev` or
  439. :func:`torch.func.jacfwd` instead if you are looking for something
  440. less experimental and more performant.
  441. When computing the jacobian, usually we invoke
  442. ``autograd.grad`` once per row of the jacobian. If this flag is
  443. ``True``, we perform only a single ``autograd.grad`` call with
  444. ``batched_grad=True`` which uses the vmap prototype feature.
  445. Though this should lead to performance improvements in many cases,
  446. because this feature is still experimental, there may be performance
  447. cliffs. See :func:`torch.autograd.grad`'s ``batched_grad`` parameter for
  448. more information.
  449. strategy (str, optional): Set to ``"forward-mode"`` or ``"reverse-mode"`` to
  450. determine whether the Jacobian will be computed with forward or reverse
  451. mode AD. Currently, ``"forward-mode"`` requires ``vectorized=True``.
  452. Defaults to ``"reverse-mode"``. If ``func`` has more outputs than
  453. inputs, ``"forward-mode"`` tends to be more performant. Otherwise,
  454. prefer to use ``"reverse-mode"``.
  455. Returns:
  456. Jacobian (Tensor or nested tuple of Tensors): if there is a single
  457. input and output, this will be a single Tensor containing the
  458. Jacobian for the linearized inputs and output. If one of the two is
  459. a tuple, then the Jacobian will be a tuple of Tensors. If both of
  460. them are tuples, then the Jacobian will be a tuple of tuple of
  461. Tensors where ``Jacobian[i][j]`` will contain the Jacobian of the
  462. ``i``\th output and ``j``\th input and will have as size the
  463. concatenation of the sizes of the corresponding output and the
  464. corresponding input and will have same dtype and device as the
  465. corresponding input. If strategy is ``forward-mode``, the dtype will be
  466. that of the output; otherwise, the input.
  467. Example:
  468. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
  469. >>> def exp_reducer(x):
  470. ... return x.exp().sum(dim=1)
  471. >>> inputs = torch.rand(2, 2)
  472. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  473. >>> jacobian(exp_reducer, inputs)
  474. tensor([[[1.4917, 2.4352],
  475. [0.0000, 0.0000]],
  476. [[0.0000, 0.0000],
  477. [2.4369, 2.3799]]])
  478. >>> jacobian(exp_reducer, inputs, create_graph=True)
  479. tensor([[[1.4917, 2.4352],
  480. [0.0000, 0.0000]],
  481. [[0.0000, 0.0000],
  482. [2.4369, 2.3799]]], grad_fn=<ViewBackward>)
  483. >>> def exp_adder(x, y):
  484. ... return 2 * x.exp() + 3 * y
  485. >>> inputs = (torch.rand(2), torch.rand(2))
  486. >>> jacobian(exp_adder, inputs)
  487. (tensor([[2.8052, 0.0000],
  488. [0.0000, 3.3963]]),
  489. tensor([[3., 0.],
  490. [0., 3.]]))
  491. """
  492. assert strategy in ("forward-mode", "reverse-mode"), (
  493. 'Expected strategy to be either "forward-mode" or "reverse-mode". Hint: If your '
  494. 'function has more outputs than inputs, "forward-mode" tends to be more performant. '
  495. 'Otherwise, prefer to use "reverse-mode".')
  496. if strategy == "forward-mode":
  497. if create_graph:
  498. raise NotImplementedError('torch.autograd.functional.jacobian: `create_graph=True` '
  499. 'and `strategy="forward-mode"` are not supported together (yet). '
  500. 'Please either set `create_graph=False` or '
  501. '`strategy="reverse-mode"`.')
  502. return _jacfwd(func, inputs, strict, vectorize)
  503. with torch.enable_grad():
  504. is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
  505. inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
  506. outputs = func(*inputs)
  507. is_outputs_tuple, outputs = _as_tuple(outputs,
  508. "outputs of the user-provided function",
  509. "jacobian")
  510. _check_requires_grad(outputs, "outputs", strict=strict)
  511. if vectorize:
  512. if strict:
  513. raise RuntimeError('torch.autograd.functional.jacobian: `strict=True` '
  514. 'and `vectorized=True` are not supported together. '
  515. 'Please either set `strict=False` or '
  516. '`vectorize=False`.')
  517. # NOTE: [Computing jacobian with vmap and grad for multiple outputs]
  518. #
  519. # Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3).
  520. # It turns out we can compute the jacobian of this function with a single
  521. # call to autograd.grad by using vmap over the correct grad_outputs.
  522. #
  523. # Firstly, one way to compute the jacobian is to stack x**2 and x.sum()
  524. # into a 4D vector. E.g., use g(x) = torch.stack([x**2, x.sum()])
  525. #
  526. # To get the first row of the jacobian, we call
  527. # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0]))
  528. # To get the 2nd row of the jacobian, we call
  529. # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0]))
  530. # and so on.
  531. #
  532. # Using vmap, we can vectorize all 4 of these computations into one by
  533. # passing the standard basis for R^4 as the grad_output.
  534. # vmap(partial(autograd.grad, g(x), x))(torch.eye(4)).
  535. #
  536. # Now, how do we compute the jacobian *without stacking the output*?
  537. # We can just split the standard basis across the outputs. So to
  538. # compute the jacobian of f(x), we'd use
  539. # >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))
  540. # The grad_outputs looks like the following:
  541. # ( torch.tensor([[1, 0, 0],
  542. # [0, 1, 0],
  543. # [0, 0, 1],
  544. # [0, 0, 0]]),
  545. # torch.tensor([[0],
  546. # [0],
  547. # [0],
  548. # [1]]) )
  549. #
  550. # But we're not done yet!
  551. # >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...)))
  552. # returns a Tensor of shape [4, 3]. We have to remember to split the
  553. # jacobian of shape [4, 3] into two:
  554. # - one of shape [3, 3] for the first output
  555. # - one of shape [ 3] for the second output
  556. # Step 1: Construct grad_outputs by splitting the standard basis
  557. output_numels = tuple(output.numel() for output in outputs)
  558. grad_outputs = _construct_standard_basis_for(outputs, output_numels)
  559. flat_outputs = tuple(output.reshape(-1) for output in outputs)
  560. # Step 2: Call vmap + autograd.grad
  561. def vjp(grad_output):
  562. vj = list(_autograd_grad(flat_outputs, inputs, grad_output, create_graph=create_graph, is_grads_batched=True))
  563. for el_idx, vj_el in enumerate(vj):
  564. if vj_el is not None:
  565. continue
  566. vj[el_idx] = torch.zeros_like(inputs[el_idx]).expand((sum(output_numels),) + inputs[el_idx].shape)
  567. return tuple(vj)
  568. jacobians_of_flat_output = vjp(grad_outputs)
  569. # Step 3: The returned jacobian is one big tensor per input. In this step,
  570. # we split each Tensor by output.
  571. jacobian_input_output = []
  572. for jac, input_i in zip(jacobians_of_flat_output, inputs):
  573. jacobian_input_i_output = []
  574. for jac, output_j in zip(jac.split(output_numels, dim=0), outputs):
  575. jacobian_input_i_output_j = jac.view(output_j.shape + input_i.shape)
  576. jacobian_input_i_output.append(jacobian_input_i_output_j)
  577. jacobian_input_output.append(jacobian_input_i_output)
  578. # Step 4: Right now, `jacobian` is a List[List[Tensor]].
  579. # The outer List corresponds to the number of inputs,
  580. # the inner List corresponds to the number of outputs.
  581. # We need to exchange the order of these and convert to tuples
  582. # before returning.
  583. jacobian_output_input = tuple(zip(*jacobian_input_output))
  584. jacobian_output_input = _grad_postprocess(jacobian_output_input, create_graph)
  585. return _tuple_postprocess(jacobian_output_input, (is_outputs_tuple, is_inputs_tuple))
  586. jacobian: Tuple[torch.Tensor, ...] = tuple()
  587. for i, out in enumerate(outputs):
  588. # mypy complains that expression and variable have different types due to the empty list
  589. jac_i: Tuple[List[torch.Tensor]] = tuple([] for _ in range(len(inputs))) # type: ignore[assignment]
  590. for j in range(out.nelement()):
  591. vj = _autograd_grad((out.reshape(-1)[j],), inputs,
  592. retain_graph=True, create_graph=create_graph)
  593. for el_idx, (jac_i_el, vj_el, inp_el) in enumerate(zip(jac_i, vj, inputs)):
  594. if vj_el is not None:
  595. if strict and create_graph and not vj_el.requires_grad:
  596. msg = ("The jacobian of the user-provided function is "
  597. "independent of input {}. This is not allowed in "
  598. "strict mode when create_graph=True.".format(i))
  599. raise RuntimeError(msg)
  600. jac_i_el.append(vj_el)
  601. else:
  602. if strict:
  603. msg = ("Output {} of the user-provided function is "
  604. "independent of input {}. This is not allowed in "
  605. "strict mode.".format(i, el_idx))
  606. raise RuntimeError(msg)
  607. jac_i_el.append(torch.zeros_like(inp_el))
  608. jacobian += (tuple(torch.stack(jac_i_el, dim=0).view(out.size() # type: ignore[operator]
  609. + inputs[el_idx].size()) for (el_idx, jac_i_el) in enumerate(jac_i)), )
  610. jacobian = _grad_postprocess(jacobian, create_graph)
  611. return _tuple_postprocess(jacobian, (is_outputs_tuple, is_inputs_tuple))
  612. def hessian(func, inputs, create_graph=False, strict=False, vectorize=False, outer_jacobian_strategy="reverse-mode"):
  613. r"""Function that computes the Hessian of a given scalar function.
  614. Args:
  615. func (function): a Python function that takes Tensor inputs and returns
  616. a Tensor with a single element.
  617. inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
  618. create_graph (bool, optional): If ``True``, the Hessian will be computed in
  619. a differentiable manner. Note that when ``strict`` is ``False``, the result can not
  620. require gradients or be disconnected from the inputs.
  621. Defaults to ``False``.
  622. strict (bool, optional): If ``True``, an error will be raised when we detect that there exists an input
  623. such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the
  624. hessian for said inputs, which is the expected mathematical value.
  625. Defaults to ``False``.
  626. vectorize (bool, optional): This feature is experimental.
  627. Please consider using :func:`torch.func.hessian`
  628. instead if you are looking for something less experimental and more performant.
  629. When computing the hessian, usually we invoke
  630. ``autograd.grad`` once per row of the hessian. If this flag is
  631. ``True``, we use the vmap prototype feature as the backend to
  632. vectorize calls to ``autograd.grad`` so we only invoke it once
  633. instead of once per row. This should lead to performance
  634. improvements in many use cases, however, due to this feature
  635. being incomplete, there may be performance cliffs. Please
  636. use `torch._C._debug_only_display_vmap_fallback_warnings(True)`
  637. to show any performance warnings and file us issues if
  638. warnings exist for your use case. Defaults to ``False``.
  639. outer_jacobian_strategy (str, optional): The Hessian is computed by
  640. computing the Jacobian of a Jacobian. The inner Jacobian is always
  641. computed in reverse-mode AD. Setting strategy to ``"forward-mode"``
  642. or ``"reverse-mode"`` determines whether the outer Jacobian will be
  643. computed with forward or reverse mode AD. Currently, computing the outer
  644. Jacobian in ``"forward-mode"`` requires ``vectorized=True``. Defaults
  645. to ``"reverse-mode"``.
  646. Returns:
  647. Hessian (Tensor or a tuple of tuple of Tensors): if there is a single input,
  648. this will be a single Tensor containing the Hessian for the input.
  649. If it is a tuple, then the Hessian will be a tuple of tuples where
  650. ``Hessian[i][j]`` will contain the Hessian of the ``i``\th input
  651. and ``j``\th input with size the sum of the size of the ``i``\th input plus
  652. the size of the ``j``\th input. ``Hessian[i][j]`` will have the same
  653. dtype and device as the corresponding ``i``\th input.
  654. Example:
  655. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
  656. >>> def pow_reducer(x):
  657. ... return x.pow(3).sum()
  658. >>> inputs = torch.rand(2, 2)
  659. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  660. >>> hessian(pow_reducer, inputs)
  661. tensor([[[[5.2265, 0.0000],
  662. [0.0000, 0.0000]],
  663. [[0.0000, 4.8221],
  664. [0.0000, 0.0000]]],
  665. [[[0.0000, 0.0000],
  666. [1.9456, 0.0000]],
  667. [[0.0000, 0.0000],
  668. [0.0000, 3.2550]]]])
  669. >>> hessian(pow_reducer, inputs, create_graph=True)
  670. tensor([[[[5.2265, 0.0000],
  671. [0.0000, 0.0000]],
  672. [[0.0000, 4.8221],
  673. [0.0000, 0.0000]]],
  674. [[[0.0000, 0.0000],
  675. [1.9456, 0.0000]],
  676. [[0.0000, 0.0000],
  677. [0.0000, 3.2550]]]], grad_fn=<ViewBackward>)
  678. >>> def pow_adder_reducer(x, y):
  679. ... return (2 * x.pow(2) + 3 * y.pow(2)).sum()
  680. >>> inputs = (torch.rand(2), torch.rand(2))
  681. >>> hessian(pow_adder_reducer, inputs)
  682. ((tensor([[4., 0.],
  683. [0., 4.]]),
  684. tensor([[0., 0.],
  685. [0., 0.]])),
  686. (tensor([[0., 0.],
  687. [0., 0.]]),
  688. tensor([[6., 0.],
  689. [0., 6.]])))
  690. """
  691. is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hessian")
  692. assert outer_jacobian_strategy in ("forward-mode", "reverse-mode"), (
  693. 'Expected strategy to be either "forward-mode" or "reverse-mode".')
  694. def ensure_single_output_function(*inp):
  695. out = func(*inp)
  696. is_out_tuple, t_out = _as_tuple(out, "outputs of the user-provided function", "hessian")
  697. _check_requires_grad(t_out, "outputs", strict=strict)
  698. if is_out_tuple or not isinstance(out, torch.Tensor):
  699. raise RuntimeError("The function given to hessian should return a single Tensor")
  700. if out.nelement() != 1:
  701. raise RuntimeError("The Tensor returned by the function given to hessian should contain a single element")
  702. return out.squeeze()
  703. def jac_func(*inp):
  704. if outer_jacobian_strategy == "forward-mode":
  705. # _grad_preprocess requires create_graph=True and input to require_grad
  706. # or else the input will be detached
  707. inp = tuple(t.requires_grad_(True) for t in inp)
  708. jac = jacobian(ensure_single_output_function, inp, create_graph=True)
  709. _check_requires_grad(jac, "jacobian", strict=strict)
  710. return jac
  711. res = jacobian(jac_func, inputs, create_graph=create_graph, strict=strict, vectorize=vectorize,
  712. strategy=outer_jacobian_strategy)
  713. return _tuple_postprocess(res, (is_inputs_tuple, is_inputs_tuple))
  714. def vhp(func, inputs, v=None, create_graph=False, strict=False):
  715. r"""Function that computes the dot product between a vector ``v`` and the
  716. Hessian of a given scalar function at the point given by the inputs.
  717. Args:
  718. func (function): a Python function that takes Tensor inputs and returns
  719. a Tensor with a single element.
  720. inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
  721. v (tuple of Tensors or Tensor): The vector for which the vector Hessian
  722. product is computed. Must be the same size as the input of
  723. ``func``. This argument is optional when ``func``'s input contains
  724. a single element and (if it is not provided) will be set as a
  725. Tensor containing a single ``1``.
  726. create_graph (bool, optional): If ``True``, both the output and result
  727. will be computed in a differentiable way. Note that when ``strict``
  728. is ``False``, the result can not require gradients or be
  729. disconnected from the inputs.
  730. Defaults to ``False``.
  731. strict (bool, optional): If ``True``, an error will be raised when we
  732. detect that there exists an input such that all the outputs are
  733. independent of it. If ``False``, we return a Tensor of zeros as the
  734. vhp for said inputs, which is the expected mathematical value.
  735. Defaults to ``False``.
  736. Returns:
  737. output (tuple): tuple with:
  738. func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
  739. vhp (tuple of Tensors or Tensor): result of the dot product with the
  740. same shape as the inputs.
  741. Example:
  742. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
  743. >>> def pow_reducer(x):
  744. ... return x.pow(3).sum()
  745. >>> inputs = torch.rand(2, 2)
  746. >>> v = torch.ones(2, 2)
  747. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  748. >>> vhp(pow_reducer, inputs, v)
  749. (tensor(0.5591),
  750. tensor([[1.0689, 1.2431],
  751. [3.0989, 4.4456]]))
  752. >>> vhp(pow_reducer, inputs, v, create_graph=True)
  753. (tensor(0.5591, grad_fn=<SumBackward0>),
  754. tensor([[1.0689, 1.2431],
  755. [3.0989, 4.4456]], grad_fn=<MulBackward0>))
  756. >>> def pow_adder_reducer(x, y):
  757. ... return (2 * x.pow(2) + 3 * y.pow(2)).sum()
  758. >>> inputs = (torch.rand(2), torch.rand(2))
  759. >>> v = (torch.zeros(2), torch.ones(2))
  760. >>> vhp(pow_adder_reducer, inputs, v)
  761. (tensor(4.8053),
  762. (tensor([0., 0.]),
  763. tensor([6., 6.])))
  764. """
  765. with torch.enable_grad():
  766. is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vhp")
  767. inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
  768. if v is not None:
  769. _, v = _as_tuple(v, "v", "vhp")
  770. v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
  771. _validate_v(v, inputs, is_inputs_tuple)
  772. else:
  773. if len(inputs) != 1 or inputs[0].nelement() != 1:
  774. raise RuntimeError("The vector v can only be None if the input to the user-provided function "
  775. "is a single Tensor with a single element.")
  776. outputs = func(*inputs)
  777. is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "vhp")
  778. _check_requires_grad(outputs, "outputs", strict=strict)
  779. if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
  780. raise RuntimeError("The function given to vhp should return a single Tensor")
  781. if outputs[0].nelement() != 1:
  782. raise RuntimeError("The Tensor returned by the function given to vhp should contain a single element")
  783. jac = _autograd_grad(outputs, inputs, create_graph=True)
  784. _check_requires_grad(jac, "jacobian", strict=strict)
  785. enable_grad = True if create_graph else torch.is_grad_enabled()
  786. with torch.set_grad_enabled(enable_grad):
  787. grad_res = _autograd_grad(jac, inputs, v, create_graph=create_graph)
  788. vhp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "double_back")
  789. outputs = _grad_postprocess(outputs, create_graph)
  790. vhp = _grad_postprocess(vhp, create_graph)
  791. return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(vhp, is_inputs_tuple)
  792. def hvp(func, inputs, v=None, create_graph=False, strict=False):
  793. r"""Function that computes the dot product between the Hessian of a given scalar
  794. function and a vector ``v`` at the point given by the inputs.
  795. Args:
  796. func (function): a Python function that takes Tensor inputs and returns
  797. a Tensor with a single element.
  798. inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
  799. v (tuple of Tensors or Tensor): The vector for which the Hessian vector
  800. product is computed. Must be the same size as the input of
  801. ``func``. This argument is optional when ``func``'s input contains
  802. a single element and (if it is not provided) will be set as a
  803. Tensor containing a single ``1``.
  804. create_graph (bool, optional): If ``True``, both the output and result will be
  805. computed in a differentiable way. Note that when ``strict`` is
  806. ``False``, the result can not require gradients or be disconnected
  807. from the inputs. Defaults to ``False``.
  808. strict (bool, optional): If ``True``, an error will be raised when we
  809. detect that there exists an input such that all the outputs are
  810. independent of it. If ``False``, we return a Tensor of zeros as the
  811. hvp for said inputs, which is the expected mathematical value.
  812. Defaults to ``False``.
  813. Returns:
  814. output (tuple): tuple with:
  815. func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
  816. hvp (tuple of Tensors or Tensor): result of the dot product with
  817. the same shape as the inputs.
  818. Example:
  819. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
  820. >>> def pow_reducer(x):
  821. ... return x.pow(3).sum()
  822. >>> inputs = torch.rand(2, 2)
  823. >>> v = torch.ones(2, 2)
  824. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  825. >>> hvp(pow_reducer, inputs, v)
  826. (tensor(0.1448),
  827. tensor([[2.0239, 1.6456],
  828. [2.4988, 1.4310]]))
  829. >>> hvp(pow_reducer, inputs, v, create_graph=True)
  830. (tensor(0.1448, grad_fn=<SumBackward0>),
  831. tensor([[2.0239, 1.6456],
  832. [2.4988, 1.4310]], grad_fn=<MulBackward0>))
  833. >>> def pow_adder_reducer(x, y):
  834. ... return (2 * x.pow(2) + 3 * y.pow(2)).sum()
  835. >>> inputs = (torch.rand(2), torch.rand(2))
  836. >>> v = (torch.zeros(2), torch.ones(2))
  837. >>> hvp(pow_adder_reducer, inputs, v)
  838. (tensor(2.3030),
  839. (tensor([0., 0.]),
  840. tensor([6., 6.])))
  841. Note:
  842. This function is significantly slower than `vhp` due to backward mode AD constraints.
  843. If your functions is twice continuously differentiable, then hvp = vhp.t(). So if you
  844. know that your function satisfies this condition, you should use vhp instead that is
  845. much faster with the current implementation.
  846. """
  847. with torch.enable_grad():
  848. is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hvp")
  849. inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
  850. if v is not None:
  851. _, v = _as_tuple(v, "v", "hvp")
  852. v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
  853. _validate_v(v, inputs, is_inputs_tuple)
  854. else:
  855. if len(inputs) != 1 or inputs[0].nelement() != 1:
  856. raise RuntimeError("The vector v can only be None if the input to the user-provided function "
  857. "is a single Tensor with a single element.")
  858. outputs = func(*inputs)
  859. is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "hvp")
  860. _check_requires_grad(outputs, "outputs", strict=strict)
  861. if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
  862. raise RuntimeError("The function given to hvp should return a single Tensor")
  863. if outputs[0].nelement() != 1:
  864. raise RuntimeError("The Tensor returned by the function given to hvp should contain a single element")
  865. jac = _autograd_grad(outputs, inputs, create_graph=True)
  866. _check_requires_grad(jac, "jacobian", strict=strict)
  867. grad_jac = tuple(torch.zeros_like(inp, requires_grad=True) for inp in inputs)
  868. double_back = _autograd_grad(jac, inputs, grad_jac, create_graph=True)
  869. _check_requires_grad(jac, "hessian", strict=strict)
  870. enable_grad = True if create_graph else torch.is_grad_enabled()
  871. with torch.set_grad_enabled(enable_grad):
  872. grad_res = _autograd_grad(double_back, grad_jac, v, create_graph=create_graph)
  873. hvp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "double_back_trick")
  874. outputs = _grad_postprocess(outputs, create_graph)
  875. hvp = _grad_postprocess(hvp, create_graph)
  876. return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(hvp, is_inputs_tuple)