# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Union, Tuple, List, Any, Optional
import torch
from functools import partial, wraps
import contextlib
from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map, tree_map_only
from torch.fx.experimental import const_fold
from torch.fx.experimental.proxy_tensor import make_fx
from .pytree_hacks import tree_map_, treespec_pprint
import torch.autograd.forward_ad as fwAD

from .vmap import vmap, doesnt_support_saved_tensors_hooks, get_chunk_sizes

from torch._C._functorch import (
    _wrap_for_grad,
    _unwrap_for_grad,
    _grad_increment_nesting,
    _grad_decrement_nesting,
    _jvp_increment_nesting,
    _jvp_decrement_nesting,
    _wrap_functional_tensor,
    _unwrap_functional_tensor,
    _func_decrement_nesting,
    _func_increment_nesting,
    _assert_wrapped_functional,
    _propagate_functional_input_mutation,
    set_inplace_requires_grad_allowed,
    get_inplace_requires_grad_allowed
)
from torch._functorch.utils import exposed_in

argnums_t = Union[int, Tuple[int, ...]]


@contextlib.contextmanager
def enable_inplace_requires_grad(enabled=True):
    prev_state = get_inplace_requires_grad_allowed()
    set_inplace_requires_grad_allowed(enabled)
    try:
        yield
    finally:
        set_inplace_requires_grad_allowed(prev_state)


def _create_differentiable(inps, level=None):
    def create_differentiable(x):
        if isinstance(x, torch.Tensor):
            with enable_inplace_requires_grad():
                return x.requires_grad_()
        raise ValueError(f'Thing passed to transform API must be Tensor, '
                         f'got {type(x)}')
    return tree_map(create_differentiable, inps)


def _undo_create_differentiable(inps, level=None):
    def unwrap_tensors(x):
        if isinstance(x, torch.Tensor):
            return _unwrap_for_grad(x, level)
        # TODO: Remove the following hack for namedtuples
        if isinstance(x, tuple):
            return tree_map(unwrap_tensors, tuple(x))

        raise RuntimeError(f"Expected tensors, got unsupported type {type(x)}")

    return tree_map(unwrap_tensors, inps)


def _is_differentiable(maybe_tensor):
    if not isinstance(maybe_tensor, torch.Tensor):
        return False
    return maybe_tensor.requires_grad


def _any_differentiable(tensor_or_tuple_of_tensors):
    flat_args, _ = tree_flatten(tensor_or_tuple_of_tensors)
    return any(tuple(map(_is_differentiable, flat_args)))


def _wrap_tensor_for_grad(maybe_tensor, level):
    if not isinstance(maybe_tensor, torch.Tensor):
        return maybe_tensor
    return _wrap_for_grad(maybe_tensor, level)


def _wrap_all_tensors(tensor_pytree, level):
    return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_pytree)


def _as_tuple(val):
    if isinstance(val, tuple):
        return val
    return (val,)


# Version of autograd.grad that handles outputs that don't depend on inputs
def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True):
    if grad_outputs is None:
        diff_outputs = tuple(out for out in outputs if out.requires_grad)
    else:
        result = tuple((out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad)
        if len(result) == 0:
            diff_outputs, grad_outputs = (), ()
        else:
            diff_outputs, grad_outputs = zip(*result)
    if len(diff_outputs) == 0:
        return tuple(torch.zeros_like(inp) for inp in inputs)
    grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
                                      retain_graph=retain_graph,
                                      create_graph=create_graph,
                                      allow_unused=True)
    grad_inputs = tuple(torch.zeros_like(inp) if gi is None else gi
                        for gi, inp in zip(grad_inputs, inputs))
    return grad_inputs
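
# A minimal usage sketch of `_autograd_grad` (illustrative, hedged): unlike
# torch.autograd.grad, outputs that don't require grad and inputs that are
# unused both produce zero tensors instead of errors or Nones. Assuming the
# tensors `x` and `y` below:
#
# >>> x = torch.randn(3, requires_grad=True)
# >>> y = torch.randn(3, requires_grad=True)
# >>> out = (x * 2).sum()                        # does not depend on y
# >>> gx, gy = _autograd_grad((out,), (x, y), (torch.tensor(1.),))
# >>> assert torch.allclose(gx, torch.full((3,), 2.))
# >>> assert torch.allclose(gy, torch.zeros(3))  # unused input -> zeros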


# NOTE [grad and vjp interaction with no_grad]
#
# def f(x):
#   with torch.no_grad():
#     c = x ** 2
#   return x - c
#
# The thing to consider is if enable_grad is on/off before grad gets called.
#
# Case 1: enable_grad is on.
# grad(f)(x)
# In this case, `grad` should respect the inner torch.no_grad.
#
# Case 2: enable_grad is off
# with torch.no_grad():
#   grad(f)(x)
# In this case, `grad` should respect the inner torch.no_grad, but not the
# outer one. This is because `grad` is a "function transform": its result
# should not depend on the result of a context manager outside of `f`.
#
# This gives us the following desired behavior:
# - (nested) grad transforms must obey torch.no_grad inside them
# - (nested) grad transforms should not obey torch.no_grad outside them
#
# To achieve this behavior, upon entering grad/vjp:
# - we save the current ("previous") is_grad_enabled (*)
# - we unconditionally enable grad.
#
# Inside DynamicLayerBackFallback, when we're temporarily popping `grad` layer
# off the stack:
# - if grad_mode is disabled, then we do nothing. (there is a torch.no_grad
#   active, all subsequent grad transforms must obey it).
# - if grad_mode is enabled, and the previous is_grad_enabled (*) is False,
#   then we temporarily restore the previous `is_grad_enabled`. This is
#   because we're crossing the boundary from a `grad` outside the
#   no_grad to a `grad` inside the no_grad.
#
# NB: vjp has some interesting behavior because the vjp's callable can be called
# under a different grad_mode than the forward computation...
#
# NB: forward-mode AD: forward-mode AD doesn't respect torch.no_grad, but
# it respects c10::AutoFwGradMode. We've implemented the same logic for
# our jvp transform (it will have special handling if FwGradMode is disabled).
# How do we increment and decrement the nesting? I don't think we can.
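
# A minimal sketch of the two cases above, assuming the toy function `f`
# below. Because `c` is computed under torch.no_grad, it is treated as a
# constant, so the gradient of f is 1 whether or not grad(f) itself is
# invoked under torch.no_grad (the outer no_grad is ignored):
#
# >>> x = torch.randn([])
# >>> def f(x):
# ...     with torch.no_grad():
# ...         c = x ** 2
# ...     return x - c
# >>> assert torch.allclose(torch.func.grad(f)(x), torch.ones_like(x))
# >>> with torch.no_grad():
# ...     assert torch.allclose(torch.func.grad(f)(x), torch.ones_like(x))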


@exposed_in("torch.func")
def vjp(func: Callable, *primals, has_aux: bool = False):
    """
    Standing for the vector-Jacobian product, returns a tuple containing the
    results of ``func`` applied to ``primals`` and a function that, when
    given ``cotangents``, computes the reverse-mode Jacobian of ``func`` with
    respect to ``primals`` times ``cotangents``.

    Args:
        func (Callable): A Python function that takes one or more arguments. Must
            return one or more Tensors.
        primals (Tensors): Positional arguments to ``func`` that must all be
            Tensors. The returned function will also be computing the
            derivative with respect to these arguments
        has_aux (bool): Flag indicating that ``func`` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            other auxiliary objects that will not be differentiated.
            Default: False.

    Returns:
        Returns a ``(output, vjp_fn)`` tuple containing the output of ``func``
        applied to ``primals`` and a function that computes the vjp of
        ``func`` with respect to all ``primals`` using the cotangents passed
        to the returned function. If ``has_aux is True``, then instead returns a
        ``(output, vjp_fn, aux)`` tuple.
        The returned ``vjp_fn`` function will return a tuple of each VJP.

    When used in simple cases, :func:`vjp` behaves the same as :func:`grad`

        >>> x = torch.randn([5])
        >>> f = lambda x: x.sin().sum()
        >>> (_, vjpfunc) = torch.func.vjp(f, x)
        >>> grad = vjpfunc(torch.tensor(1.))[0]
        >>> assert torch.allclose(grad, torch.func.grad(f)(x))

    However, :func:`vjp` can support functions with multiple outputs by
    passing in the cotangents for each of the outputs

        >>> x = torch.randn([5])
        >>> f = lambda x: (x.sin(), x.cos())
        >>> (_, vjpfunc) = torch.func.vjp(f, x)
        >>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
        >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

    :func:`vjp` can even support outputs being Python structs

        >>> x = torch.randn([5])
        >>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
        >>> (_, vjpfunc) = torch.func.vjp(f, x)
        >>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
        >>> vjps = vjpfunc(cotangents)
        >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

    The function returned by :func:`vjp` will compute the partials with
    respect to each of the ``primals``

        >>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
        >>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
        >>> cotangents = torch.randn([5, 5])
        >>> vjps = vjpfunc(cotangents)
        >>> assert len(vjps) == 2
        >>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
        >>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))

    ``primals`` are the positional arguments for ``f``. All kwargs use their
    default value

        >>> x = torch.randn([5])
        >>> def f(x, scale=4.):
        >>>     return x * scale
        >>>
        >>> (_, vjpfunc) = torch.func.vjp(f, x)
        >>> vjps = vjpfunc(torch.ones_like(x))
        >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))

    .. note::
        Using PyTorch ``torch.no_grad`` together with ``vjp``.
        Case 1: Using ``torch.no_grad`` inside a function:

            >>> def f(x):
            >>>     with torch.no_grad():
            >>>         c = x ** 2
            >>>     return x - c

        In this case, ``vjp(f)(x)`` will respect the inner ``torch.no_grad``.

        Case 2: Using ``vjp`` inside ``torch.no_grad`` context manager:

            >>> # xdoctest: +SKIP(failing)
            >>> with torch.no_grad():
            >>>     vjp(f)(x)

        In this case, ``vjp`` will respect the inner ``torch.no_grad``, but not the
        outer one. This is because ``vjp`` is a "function transform": its result
        should not depend on the result of a context manager outside of ``f``.
    """
    return _vjp_with_argnums(func, *primals, has_aux=has_aux)


@doesnt_support_saved_tensors_hooks
def _vjp_with_argnums(func: Callable, *primals, argnums: Optional[argnums_t] = None, has_aux: bool = False):
    # This is the same function as vjp but also accepts an argnums argument
    # All args are the same as vjp except for the added argument
    # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to.
    #   If None, computes the gradients with respect to all inputs (used for vjp). Default: None
    #
    # WARN: Users should NOT call this function directly and should just be calling vjp.
    # It is only separated so that inputs passed to jacrev but not differentiated get the correct wrappers.
    #
    # NOTE: All error messages are produced as if vjp was being called, even if this was called by jacrev
    #
    # Returns the same two elements as :func:`vjp` but the function returned, vjp_fn, returns a tuple of VJPs
    # for only the primal elements given by argnums.
    level = _grad_increment_nesting()
    try:
        # See NOTE [grad and vjp interaction with no_grad]
        with torch.enable_grad():
            primals = _wrap_all_tensors(primals, level)
            if argnums is None:
                diff_primals = _create_differentiable(primals, level)
            else:
                diff_primals = _slice_argnums(primals, argnums, as_tuple=False)
                tree_map_(partial(_create_differentiable, level=level), diff_primals)
            primals_out = func(*primals)

            if has_aux:
                if not (isinstance(primals_out, tuple) and len(primals_out) == 2):
                    raise RuntimeError(
                        "vjp(f, *primals): output of function f should be a tuple: (output, aux) "
                        "if has_aux is True"
                    )
                primals_out, aux = primals_out
                aux = _undo_create_differentiable(aux, level)

            flat_primals_out, primals_out_spec = tree_flatten(primals_out)
            assert_non_empty_tensor_output(flat_primals_out, 'vjp(f, *primals)')
            flat_diff_primals, primals_spec = tree_flatten(diff_primals)
            results = _undo_create_differentiable(primals_out, level)

            for primal_out in flat_primals_out:
                assert isinstance(primal_out, torch.Tensor)
                if primal_out.is_floating_point() or primal_out.is_complex():
                    continue
                raise RuntimeError("vjp(f, ...): All outputs of f must be "
                                   "floating-point or complex Tensors, got Tensor "
                                   f"with dtype {primal_out.dtype}")

        def wrapper(cotangents, retain_graph=True, create_graph=None):
            if create_graph is None:
                create_graph = torch.is_grad_enabled()
            flat_cotangents, cotangents_spec = tree_flatten(cotangents)
            if primals_out_spec != cotangents_spec:
                raise RuntimeError(
                    f'Expected pytree structure of cotangents to be the same '
                    f'as pytree structure of outputs to the function. '
                    f'cotangents: {treespec_pprint(cotangents_spec)}, '
                    f'primal output: {treespec_pprint(primals_out_spec)}')
            result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
                                    retain_graph=retain_graph, create_graph=create_graph)
            return tree_unflatten(result, primals_spec)

    finally:
        _grad_decrement_nesting()

    if has_aux:
        return results, wrapper, aux
    else:
        return results, wrapper


def _safe_zero_index(x):
    assert len(x) == 1
    return x[0]


# jacrev and jacfwd don't support complex functions
# Helper function to throw appropriate error.
def error_if_complex(func_name, args, is_input):
    flat_args, _ = tree_flatten(args)
    for idx, arg in enumerate(flat_args):
        if arg.dtype.is_complex:
            input_or_output = ("inputs" if is_input else "outputs")
            err_msg = (f"{func_name}: Expected all {input_or_output} "
                       f"to be real but received complex tensor at flattened input idx: {idx}")
            raise RuntimeError(err_msg)


@exposed_in("torch.func")
def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
           chunk_size: Optional[int] = None,
           _preallocate_and_copy=False):
    """
    Computes the Jacobian of ``func`` with respect to the arg(s) at index
    ``argnum`` using reverse mode autodiff

    .. note::
        Using :attr:`chunk_size=1` is equivalent to computing the jacobian
        row-by-row with a for-loop i.e. the constraints of :func:`vmap` are
        not applicable.

    Args:
        func (function): A Python function that takes one or more arguments,
            one of which must be a Tensor, and returns one or more Tensors
        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
            saying which arguments to get the Jacobian with respect to.
            Default: 0.
        has_aux (bool): Flag indicating that ``func`` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            auxiliary objects that will not be differentiated.
            Default: False.
        chunk_size (None or int): If None (default), use the maximum chunk size
            (equivalent to doing a single vmap over vjp to compute the jacobian).
            If 1, then compute the jacobian row-by-row with a for-loop.
            If not None, then compute the jacobian :attr:`chunk_size` rows at a time
            (equivalent to doing multiple vmap over vjp). If you run into memory issues computing
            the jacobian, please try to specify a non-None chunk_size.

    Returns:
        Returns a function that takes in the same inputs as ``func`` and
        returns the Jacobian of ``func`` with respect to the arg(s) at
        ``argnums``. If ``has_aux is True``, then the returned function
        instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
        is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.

    A basic usage with a pointwise, unary operation will give a diagonal array
    as the Jacobian

        >>> from torch.func import jacrev
        >>> x = torch.randn(5)
        >>> jacobian = jacrev(torch.sin)(x)
        >>> expected = torch.diag(torch.cos(x))
        >>> assert torch.allclose(jacobian, expected)

    If you would like to compute the output of the function as well as the
    jacobian of the function, use the ``has_aux`` flag to return the output
    as an auxiliary object:

        >>> from torch.func import jacrev
        >>> x = torch.randn(5)
        >>>
        >>> def f(x):
        >>>     return x.sin()
        >>>
        >>> def g(x):
        >>>     result = f(x)
        >>>     return result, result
        >>>
        >>> jacobian_f, f_x = jacrev(g, has_aux=True)(x)
        >>> assert torch.allclose(f_x, f(x))

    :func:`jacrev` can be composed with vmap to produce batched
    Jacobians:

        >>> from torch.func import jacrev, vmap
        >>> x = torch.randn(64, 5)
        >>> jacobian = vmap(jacrev(torch.sin))(x)
        >>> assert jacobian.shape == (64, 5, 5)

    Additionally, :func:`jacrev` can be composed with itself to produce
    Hessians

        >>> from torch.func import jacrev
        >>> def f(x):
        >>>     return x.sin().sum()
        >>>
        >>> x = torch.randn(5)
        >>> hessian = jacrev(jacrev(f))(x)
        >>> assert torch.allclose(hessian, torch.diag(-x.sin()))

    By default, :func:`jacrev` computes the Jacobian with respect to the first
    input. However, it can compute the Jacobian with respect to a different
    argument by using ``argnums``:

        >>> from torch.func import jacrev
        >>> def f(x, y):
        >>>     return x + y ** 2
        >>>
        >>> x, y = torch.randn(5), torch.randn(5)
        >>> jacobian = jacrev(f, argnums=1)(x, y)
        >>> expected = torch.diag(2 * y)
        >>> assert torch.allclose(jacobian, expected)

    Additionally, passing a tuple to ``argnums`` will compute the Jacobian
    with respect to multiple arguments

        >>> from torch.func import jacrev
        >>> def f(x, y):
        >>>     return x + y ** 2
        >>>
        >>> x, y = torch.randn(5), torch.randn(5)
        >>> jacobian = jacrev(f, argnums=(0, 1))(x, y)
        >>> expectedX = torch.diag(torch.ones_like(x))
        >>> expectedY = torch.diag(2 * y)
        >>> assert torch.allclose(jacobian[0], expectedX)
        >>> assert torch.allclose(jacobian[1], expectedY)

    .. note::
        Using PyTorch ``torch.no_grad`` together with ``jacrev``.
        Case 1: Using ``torch.no_grad`` inside a function:

            >>> def f(x):
            >>>     with torch.no_grad():
            >>>         c = x ** 2
            >>>     return x - c

        In this case, ``jacrev(f)(x)`` will respect the inner ``torch.no_grad``.

        Case 2: Using ``jacrev`` inside ``torch.no_grad`` context manager:

            >>> with torch.no_grad():
            >>>     jacrev(f)(x)

        In this case, ``jacrev`` will respect the inner ``torch.no_grad``, but not the
        outer one. This is because ``jacrev`` is a "function transform": its result
        should not depend on the result of a context manager outside of ``f``.
    """
    if not (chunk_size is None or chunk_size > 0):
        raise ValueError("jacrev: `chunk_size` should be greater than 0.")

    @wraps(func)
    def wrapper_fn(*args):
        error_if_complex("jacrev", args, is_input=True)
        vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
        if has_aux:
            output, vjp_fn, aux = vjp_out
        else:
            output, vjp_fn = vjp_out

        # See NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
        flat_output, output_spec = tree_flatten(output)

        error_if_complex("jacrev", flat_output, is_input=False)

        # NB: vjp already checks that all outputs are tensors
        # Step 1: Construct grad_outputs by splitting the standard basis
        flat_output_numels = tuple(out.numel() for out in flat_output)

        primals = _slice_argnums(args, argnums)
        flat_primals, primals_spec = tree_flatten(primals)

        def compute_jacobian_stacked():
            # Helper function to compute chunked Jacobian
            # The intermediate chunked calculation are only
            # scoped at this function level.
            chunked_results = []
            for flat_basis_chunk in _chunked_standard_basis_for_(flat_output,
                                                                 flat_output_numels,
                                                                 chunk_size=chunk_size):
                if chunk_size == 1:
                    # sanity check.
                    for t in flat_basis_chunk:
                        assert t.size(0) == 1

                    flat_basis_chunk = tree_map(lambda t: torch.squeeze(t, 0), flat_basis_chunk)

                basis = tree_unflatten(flat_basis_chunk, output_spec)

                if chunk_size == 1:
                    # Behaviour with `chunk_size=1` is same as `for-loop`
                    # i.e. user shouldn't deal with the limitations of vmap.
                    chunked_result = vjp_fn(basis)
                else:  # chunk_size is None or chunk_size != 1
                    chunked_result = vmap(vjp_fn)(basis)

                flat_results, _ = tree_flatten(chunked_result)

                if chunk_size == 1:
                    flat_results = tree_map(lambda t: torch.unsqueeze(t, 0), flat_results)

                chunked_results.append(flat_results)

            if len(chunked_results) == 1:
                # Short-circuit if we used a single chunk
                return chunked_results[0]

            # Concatenate chunks.
            flat_results = []
            # Iterate and concat the jacobians of different
            # inputs.
            for idx in range(len(flat_primals)):
                r = tuple(map(lambda r_: r_[idx], chunked_results))
                flat_results.append(torch.cat(r, 0))

            return flat_results

        def compute_jacobian_preallocate_and_copy():
            # Helper function to compute chunked Jacobian
            # The intermediate chunked calculation are only
            # scoped at this function level.
            out_vec_size = sum(flat_output_numels)

            # Don't pre-allocate if we have a single chunk.
            if not (chunk_size is None or chunk_size >= out_vec_size):
                stacked_results = [primal.new_zeros(out_vec_size, *primal.shape) for primal in flat_primals]

            for idx, flat_basis_chunk in enumerate(_chunked_standard_basis_for_(flat_output,
                                                                                flat_output_numels,
                                                                                chunk_size=chunk_size)):
                if chunk_size == 1:
                    # sanity check.
                    for t in flat_basis_chunk:
                        assert t.size(0) == 1

                    flat_basis_chunk = list(map(lambda t: torch.squeeze(t, 0), flat_basis_chunk))

                basis = tree_unflatten(flat_basis_chunk, output_spec)

                if chunk_size == 1:
                    # Behaviour with `chunk_size=1` is same as `for-loop`
                    # i.e. user shouldn't deal with the limitations of vmap.
                    chunked_result = vjp_fn(basis)
                else:  # chunk_size is None or chunk_size != 1
                    chunked_result = vmap(vjp_fn)(basis)

                flat_results, _ = tree_flatten(chunked_result)

                # Short-circuit if we have a single chunk.
                if chunk_size is None or chunk_size >= out_vec_size:
                    if chunk_size == 1:  # and out_vec_size == 1
                        # Since we squeezed the output dim
                        flat_results = tree_map(lambda t: torch.unsqueeze(t, 0), flat_results)

                    return flat_results

                for r, sr in zip(flat_results, stacked_results):
                    sr[idx * chunk_size: (idx + 1) * chunk_size].copy_(r)

            return stacked_results

        if _preallocate_and_copy:
            flat_jacobians_per_input = compute_jacobian_preallocate_and_copy()
        else:
            flat_jacobians_per_input = compute_jacobian_stacked()

        # Step 2: The returned jacobian is one big tensor per input. In this step,
        # we split each Tensor by output.
        flat_jacobians_per_input = [result.split(flat_output_numels, dim=0) for result in flat_jacobians_per_input]
        flat_input_flat_output = [
            tuple(split.view(out.shape + primal.shape)
                  for split, out in zip(splits, flat_output))
            for splits, primal in zip(flat_jacobians_per_input, flat_primals)
        ]

        # Step 3: Right now, `jacobian` is a List[List[Tensor]].
        # The outer List corresponds to the number of primals,
        # the inner List corresponds to the number of outputs.
        # We need to:
        # a. Exchange the order of the outer List and inner List
        # b. tree_unflatten the inner Lists (which correspond to the primals)
        # c. handle the argnums=int case
        # d. tree_unflatten the outer List (which corresponds to the outputs)
        flat_output_flat_input = tuple(zip(*flat_input_flat_output))

        flat_output_input = tuple(tree_unflatten(flat_input, primals_spec)
                                  for flat_input in flat_output_flat_input)

        if isinstance(argnums, int):
            flat_output_input = tuple(_safe_zero_index(flat_input)
                                      for flat_input in flat_output_input)
        output_input = tree_unflatten(flat_output_input, output_spec)

        if has_aux:
            return output_input, aux

        return output_input
    return wrapper_fn


# NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
#
# Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3).
# It turns out we can compute the jacobian of this function with a single
# call to autograd.grad by using vmap over the correct grad_outputs.
#
# Firstly, one way to compute the jacobian is to stack x**2 and x.sum()
# into a 4D vector. E.g., use g(x) = torch.stack([x**2, x.sum()])
#
# To get the first row of the jacobian, we call
# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0]))
# To get the 2nd row of the jacobian, we call
# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0]))
# and so on.
#
# Using vmap, we can vectorize all 4 of these computations into one by
# passing the standard basis for R^4 as the grad_output.
# vmap(partial(autograd.grad, g(x), x))(torch.eye(4)).
#
# Now, how do we compute the jacobian *without stacking the output*?
# We can just split the standard basis across the outputs. So to
# compute the jacobian of f(x), we'd use
# >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))
# The grad_outputs looks like the following:
# ( torch.tensor([[1, 0, 0],
#                 [0, 1, 0],
#                 [0, 0, 1],
#                 [0, 0, 0]]),
#   torch.tensor([[0],
#                 [0],
#                 [0],
#                 [1]]) )
#
# But we're not done yet!
# >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...)))
# returns a Tensor of shape [4, 3]. We have to remember to split the
# jacobian of shape [4, 3] into two:
# - one of shape [3, 3] for the first output
# - one of shape [   3] for the second output
def _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None):
    # This function:
    # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
    # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
    # - Each chunk corresponds to one tensor. The chunk has the same dtype and
    #   device as the tensor
    #
    # For example, with tensor_numels = [1, 2, 1], this function returns:
    # ( tensor([[1],     tensor([[0, 0],     tensor([[0],
    #           [0],             [1, 0],             [0],
    #           [0],             [0, 1],             [0],
    #           [0]])  ,         [0, 0]])  ,         [1]])  )
    #
    # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
    # Precondition: tensors always has at least one element.
    #
    # See NOTE: [Computing jacobian with vmap and grad for multiple tensors]
    # for context behind this function.
    # NOTE: Argument `chunk_size` is used to generate chunked basis instead of
    #       one huge basis matrix. `chunk_size` dictates the maximum size of the
    #       basis matrix along dim=0.
    assert len(tensors) == len(tensor_numels)
    assert len(tensors) > 0
    assert chunk_size is None or chunk_size > 0
    total_numel = sum(tensor_numels)
    if chunk_size and chunk_size < total_numel:
        chunk_numels = get_chunk_sizes(total_numel, chunk_size)
    else:  # chunk_size is None or chunk_size >= total_numel
        chunk_size = total_numel
        chunk_numels = [total_numel]

    diag_start_indices = (0, *torch.tensor(tensor_numels).cumsum(dim=0)[:-1].neg().unbind())

    for chunk_idx, total_numel in enumerate(chunk_numels):
        chunks = tuple(tensor.new_zeros(total_numel, tensor_numel)
                       for tensor, tensor_numel in zip(tensors, tensor_numels))

        for chunk, diag_start_idx in zip(chunks, diag_start_indices):
            chunk.diagonal(diag_start_idx + chunk_idx * chunk_size).fill_(1)
        chunks = tuple(chunk.view(total_numel, *tensor.shape)
                       for chunk, tensor in zip(chunks, tensors))
        yield chunks
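
# A small usage sketch of `_chunked_standard_basis_for_` (illustrative;
# `t1`/`t2` are placeholder names, and this assumes get_chunk_sizes splits
# 3 rows into chunks of [2, 1]). The 3x3 identity is yielded in two chunks,
# each chunk split per-tensor and reshaped to (chunk_rows, *tensor.shape):
#
# >>> t1, t2 = torch.zeros(2), torch.zeros(1)
# >>> first, second = list(_chunked_standard_basis_for_((t1, t2), (2, 1), chunk_size=2))
# >>> [c.shape for c in first]   # [torch.Size([2, 2]), torch.Size([2, 1])]
# >>> [c.shape for c in second]  # [torch.Size([1, 2]), torch.Size([1, 1])]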


def _construct_standard_basis_for(tensors, tensor_numels):
    for basis in _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None):
        return basis


def _validate_and_wrap_argnum(argnum, num_args):
    if not isinstance(argnum, int):
        raise RuntimeError(f'argnum must be int, got: {type(argnum)}')
    if argnum >= 0 and argnum < num_args:
        return argnum
    if argnum < 0 and argnum >= -num_args:
        return argnum + num_args
    raise RuntimeError(f'Got argnum={argnum}, but only {num_args} positional inputs')


def _check_unique_non_empty(argnums):
    if isinstance(argnums, tuple):
        if len(argnums) == 0:
            raise RuntimeError("argnums must be non-empty")
        if len(set(argnums)) != len(argnums):
            raise RuntimeError(f"argnums elements must be unique, got {argnums}")


def _replace_args(old_args, new_args, argnums):
    if isinstance(argnums, int):
        if len(new_args) != 1:
            raise RuntimeError(f'new_args should be of size 1, was of size {len(new_args)}')
        return tuple(new_args[0] if i == argnums else old_args[i] for i in range(len(old_args)))
    if isinstance(argnums, tuple):
        if len(new_args) != len(argnums):
            raise RuntimeError(
                "new_args should have the same size as argnums. "
                f"Argnums size {len(argnums)}, new_args size {len(new_args)}")

        def get_right_elem(i):
            return new_args[argnums.index(i)] if i in argnums else old_args[i]

        return tuple(get_right_elem(i) for i in range(len(old_args)))
    raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}')
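
# A hedged sketch of `_replace_args` semantics: only the positions named by
# `argnums` are replaced, the remaining old arguments are kept as-is.
#
# >>> old = ('a', 'b', 'c')
# >>> _replace_args(old, ('X',), 1)            # ('a', 'X', 'c')
# >>> _replace_args(old, ('X', 'Z'), (0, 2))   # ('X', 'b', 'Z')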


def _validate_and_wrap_argnums(argnums, num_args):
    if isinstance(argnums, int):
        return _validate_and_wrap_argnum(argnums, num_args)
    if isinstance(argnums, tuple):
        return tuple(_validate_and_wrap_argnum(argnum, num_args) for argnum in argnums)
    raise AssertionError("Should never get here")


def _slice_argnums(args, argnums, as_tuple=True):
    if not isinstance(argnums, int) and not isinstance(argnums, tuple):
        raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}')
    argnums = _validate_and_wrap_argnums(argnums, len(args))
    _check_unique_non_empty(argnums)
    if isinstance(argnums, int):
        if as_tuple:
            return (args[argnums],)
        else:
            return args[argnums]
    return tuple(args[i] for i in argnums)
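
# A hedged sketch of `_slice_argnums`: negative indices are wrapped, an int
# argnums returns either a 1-tuple or the bare argument depending on
# `as_tuple`, and a tuple of argnums always returns a tuple.
#
# >>> args = ('a', 'b', 'c')
# >>> _slice_argnums(args, 1)                   # ('b',)
# >>> _slice_argnums(args, 1, as_tuple=False)   # 'b'
# >>> _slice_argnums(args, (-1, 0))             # ('c', 'a')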


JVP_NESTING = 0


@contextlib.contextmanager
def noop():
    yield


def assert_flat_tuple_of_tensors(elts: Any, api: str, argname: str) -> None:
    if not isinstance(elts, tuple):
        raise RuntimeError(
            f'{api}: Expected {argname} to be a tuple of Tensors, got {type(elts)}')
    for elt in elts:
        if isinstance(elt, torch.Tensor):
            continue
        raise RuntimeError(
            f'{api}: Expected {argname} to be a tuple of Tensors, got '
            f'a tuple with an element of type {type(elt)}')
    if len(elts) == 0:
        raise RuntimeError(
            f'{api}: Expected {argname} to be a non-empty tuple of Tensors.')


def assert_non_empty_tensor_output(output: List[Any], api: str) -> None:
    if output == [None] or len(output) < 1:
        raise RuntimeError(
            f'{api}: Expected f to be a function that has non-empty output (got output = {output})'
        )
    for o in output:
        if not isinstance(o, torch.Tensor):
            raise RuntimeError(
                f'{api}: expected f(*primals) to return only tensors'
                f', got unsupported type {type(o)}'
            )


def assert_output_is_tensor_or_tensors(output: Any, api: str) -> None:
    if isinstance(output, torch.Tensor):
        return
    if not isinstance(output, tuple):
        raise RuntimeError(
            f'{api}: Expected output of f to be a Tensor or Tensors, got '
            f'{type(output)}')
    if len(output) == 0:
        raise RuntimeError(
            f'{api}: Expected output of f to be a non-empty tuple of Tensors.')
    for out in output:
        if isinstance(out, torch.Tensor):
            continue
        raise RuntimeError(
            f'{api}: Expected output of f to be a Tensor or Tensors, got '
            f'{type(out)} as an output')


def assert_non_empty_list_of_tensors(output: List[torch.Tensor], api: str, argname: str) -> None:
    if len(output) == 0:
        raise RuntimeError(
            f'{api}: Expected {argname} to contain at least one Tensor.')
    for out in output:
        if isinstance(out, torch.Tensor):
            continue
        raise RuntimeError(
            f'{api}: Expected {argname} to only contain Tensors, got '
            f'{type(out)}')


jvp_str = 'jvp(f, primals, tangents)'


def safe_unpack_dual(dual, strict):
    if not isinstance(dual, torch.Tensor):
        raise RuntimeError(
            f'{jvp_str}: expected f(*args) to return only tensors'
            f', got unsupported type {type(dual)}'
        )

    primal, tangent = fwAD.unpack_dual(dual)
    if tangent is None:
        if strict:
            raise RuntimeError(
                'jvp(f, primals, tangents, strict=True): '
                'The output of f is independent of '
                'the inputs. This is not allowed with strict=True.')
        tangent = torch.zeros_like(primal)
    return primal, tangent


@exposed_in("torch.func")
def jvp(func: Callable, primals: Any, tangents: Any, *, strict: bool = False, has_aux: bool = False):
    """
    Standing for the Jacobian-vector product, returns a tuple containing
    the output of `func(*primals)` and the "Jacobian of ``func`` evaluated at
    ``primals``" times ``tangents``. This is also known as forward-mode autodiff.

    Args:
        func (function): A Python function that takes one or more arguments,
            one of which must be a Tensor, and returns one or more Tensors
        primals (Tensors): Positional arguments to ``func`` that must all be
            Tensors. The returned function will also be computing the
            derivative with respect to these arguments
        tangents (Tensors): The "vector" for which Jacobian-vector-product is
            computed. Must be the same structure and sizes as the inputs to
            ``func``.
        has_aux (bool): Flag indicating that ``func`` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            other auxiliary objects that will not be differentiated.
            Default: False.

    Returns:
        Returns a ``(output, jvp_out)`` tuple containing the output of ``func``
        evaluated at ``primals`` and the Jacobian-vector product.
        If ``has_aux is True``, then instead returns a ``(output, jvp_out, aux)`` tuple.

    .. note::
        You may see this API error out with "forward-mode AD not implemented
        for operator X". If so, please file a bug report and we will prioritize it.

    jvp is useful when you wish to compute gradients of a function R^1 -> R^N

        >>> from torch.func import jvp
        >>> x = torch.randn([])
        >>> f = lambda x: x * torch.tensor([1., 2., 3])
        >>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
        >>> assert torch.allclose(value, f(x))
        >>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))

    :func:`jvp` can support functions with multiple inputs by passing in the
    tangents for each of the inputs

        >>> from torch.func import jvp
        >>> x = torch.randn(5)
        >>> y = torch.randn(5)
        >>> f = lambda x, y: (x * y)
        >>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
        >>> assert torch.allclose(output, x + y)
    """
    return _jvp_with_argnums(func, primals, tangents, argnums=None, strict=strict, has_aux=has_aux)


@doesnt_support_saved_tensors_hooks
def _jvp_with_argnums(func: Callable, primals: Any, tangents: Any, argnums: Optional[argnums_t], *,
                      strict: bool = False, has_aux: bool):
    # This is the same function as jvp but also accepts an argnums argument
    # Most args are the same as jvp except for the added argument
    # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to.
    #   If None, computes the gradients with respect to all inputs (used for jvp). Default: None
    # Because of this, tangents must be of length argnums and matches up to the corresponding primal whose index is
    # given by argnums
    #
    # WARN: Users should NOT call this function directly and should just be calling jvp.
    # It is only separated so that inputs passed to jacfwd but not differentiated get the correct wrappers.
    #
    # NOTE: All error messages are produced as if jvp was being called, even if this was called by jacfwd
    #
    # Returns the same two elements as :func:`jvp` but the returned tuple, ``jvp_out``, only has JVPs with respect to
    # the primals given by argnums
    if not isinstance(primals, tuple):
        raise RuntimeError(
            f'{jvp_str}: Expected primals to be a tuple. '
            f'E.g. it should be valid to call f(*primals).')
    diff_args = primals if argnums is None else _slice_argnums(primals, argnums)
    flat_primals, primals_spec = tree_flatten(diff_args)
    flat_tangents, tangents_spec = tree_flatten(tangents)
    if primals_spec != tangents_spec:
        raise RuntimeError(
            f'{jvp_str}: Expected primals and tangents to have the same python '
            f'structure. For example, if primals is a tuple of 3 tensors, '
            f'tangents also must be. Got primals with structure {primals_spec} '
            f'and tangents with structure {tangents_spec}')
    assert_non_empty_list_of_tensors(flat_primals, jvp_str, 'primals')
    assert_non_empty_list_of_tensors(flat_tangents, jvp_str, 'tangents')

    level = _jvp_increment_nesting()
    try:
        global JVP_NESTING
        JVP_NESTING += 1
        with fwAD._set_fwd_grad_enabled(True):
            ctx = fwAD.dual_level if JVP_NESTING == 1 else noop
            with ctx():
                flat_duals = tuple(fwAD.make_dual(p, t)
                                   for p, t in zip(flat_primals, flat_tangents))
                duals = tree_unflatten(flat_duals, primals_spec)
                if argnums is not None:
                    primals = _wrap_all_tensors(primals, level)
                    duals = _replace_args(primals, duals, argnums)
                result_duals = func(*duals)
                if has_aux:
                    if not (isinstance(result_duals, tuple) and len(result_duals) == 2):
                        raise RuntimeError(
                            f"{jvp_str}: output of function f should be a tuple: (output, aux) "
                            "if has_aux is True"
                        )
                    result_duals, aux = result_duals
                    aux = _undo_create_differentiable(aux, level)

                result_duals, spec = tree_flatten(result_duals)
                assert_non_empty_tensor_output(result_duals, jvp_str)

                primals_out, tangents_out = \
                    zip(*[safe_unpack_dual(dual, strict) for dual in result_duals])
                primals_out = tree_map(
                    partial(_undo_create_differentiable, level=level), primals_out)
                tangents_out = tree_map(
                    partial(_undo_create_differentiable, level=level), tangents_out)

                primals_out_unflatten = tree_unflatten(primals_out, spec)
                tangents_out_unflatten = tree_unflatten(tangents_out, spec)
                if has_aux:
                    return primals_out_unflatten, tangents_out_unflatten, aux

                return primals_out_unflatten, tangents_out_unflatten

    finally:
        _jvp_decrement_nesting()
        JVP_NESTING -= 1


def safe_unflatten(tensor, dim, shape):
    if len(shape) == 0:
        assert tensor.shape[dim] == 1
        return tensor.squeeze(dim)
    return tensor.unflatten(dim, shape)
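
# A hedged sketch of `safe_unflatten`: for an empty (scalar) target shape the
# size-1 dim is squeezed away instead of calling unflatten with an empty
# shape; otherwise this is just Tensor.unflatten.
#
# >>> t = torch.randn(6, 1)
# >>> safe_unflatten(t, 1, ()).shape        # torch.Size([6])
# >>> safe_unflatten(t, 0, (2, 3)).shape    # torch.Size([2, 3, 1])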


@exposed_in("torch.func")
def jacfwd(func: Callable, argnums: argnums_t = 0, has_aux: bool = False, *, randomness: str = "error"):
    """
    Computes the Jacobian of ``func`` with respect to the arg(s) at index
    ``argnum`` using forward-mode autodiff

    Args:
        func (function): A Python function that takes one or more arguments,
            one of which must be a Tensor, and returns one or more Tensors
        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
            saying which arguments to get the Jacobian with respect to.
            Default: 0.
        has_aux (bool): Flag indicating that ``func`` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            auxiliary objects that will not be differentiated.
            Default: False.
        randomness(str): Flag indicating what type of randomness to use.
            See :func:`vmap` for more detail. Allowed: "different", "same", "error".
            Default: "error"

    Returns:
        Returns a function that takes in the same inputs as ``func`` and
        returns the Jacobian of ``func`` with respect to the arg(s) at
        ``argnums``. If ``has_aux is True``, then the returned function
        instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
        is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.

    .. note::
        You may see this API error out with "forward-mode AD not implemented
        for operator X". If so, please file a bug report and we will prioritize it.
        An alternative is to use :func:`jacrev`, which has better operator coverage.

    A basic usage with a pointwise, unary operation will give a diagonal array
    as the Jacobian

        >>> from torch.func import jacfwd
        >>> x = torch.randn(5)
        >>> jacobian = jacfwd(torch.sin)(x)
        >>> expected = torch.diag(torch.cos(x))
        >>> assert torch.allclose(jacobian, expected)

    :func:`jacfwd` can be composed with vmap to produce batched
    Jacobians:

        >>> from torch.func import jacfwd, vmap
        >>> x = torch.randn(64, 5)
        >>> jacobian = vmap(jacfwd(torch.sin))(x)
        >>> assert jacobian.shape == (64, 5, 5)

    If you would like to compute the output of the function as well as the
    jacobian of the function, use the ``has_aux`` flag to return the output
    as an auxiliary object:

        >>> from torch.func import jacfwd
        >>> x = torch.randn(5)
        >>>
        >>> def f(x):
        >>>     return x.sin()
        >>>
        >>> def g(x):
        >>>     result = f(x)
        >>>     return result, result
        >>>
        >>> jacobian_f, f_x = jacfwd(g, has_aux=True)(x)
        >>> assert torch.allclose(f_x, f(x))

    Additionally, :func:`jacfwd` can be composed with itself or :func:`jacrev`
    to produce Hessians

        >>> from torch.func import jacfwd, jacrev
        >>> def f(x):
        >>>     return x.sin().sum()
        >>>
        >>> x = torch.randn(5)
        >>> hessian = jacfwd(jacrev(f))(x)
        >>> assert torch.allclose(hessian, torch.diag(-x.sin()))

    By default, :func:`jacfwd` computes the Jacobian with respect to the first
    input. However, it can compute the Jacobian with respect to a different
    argument by using ``argnums``:

        >>> from torch.func import jacfwd
        >>> def f(x, y):
        >>>     return x + y ** 2
        >>>
        >>> x, y = torch.randn(5), torch.randn(5)
        >>> jacobian = jacfwd(f, argnums=1)(x, y)
        >>> expected = torch.diag(2 * y)
        >>> assert torch.allclose(jacobian, expected)

    Additionally, passing a tuple to ``argnums`` will compute the Jacobian
    with respect to multiple arguments

        >>> from torch.func import jacfwd
        >>> def f(x, y):
        >>>     return x + y ** 2
        >>>
        >>> x, y = torch.randn(5), torch.randn(5)
        >>> jacobian = jacfwd(f, argnums=(0, 1))(x, y)
        >>> expectedX = torch.diag(torch.ones_like(x))
        >>> expectedY = torch.diag(2 * y)
        >>> assert torch.allclose(jacobian[0], expectedX)
        >>> assert torch.allclose(jacobian[1], expectedY)
    """
    @wraps(func)
    def wrapper_fn(*args):
        error_if_complex("jacfwd", args, is_input=True)
        primals = args if argnums is None else _slice_argnums(args, argnums)
        flat_primals, primals_spec = tree_flatten(primals)
        flat_primals_numels = tuple(p.numel() for p in flat_primals)
        flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels)
        basis = tree_unflatten(flat_basis, primals_spec)

        def push_jvp(basis):
            output = _jvp_with_argnums(func, args, basis, argnums=argnums, has_aux=has_aux)
            # output[0] is the output of `func(*args)`
            error_if_complex("jacfwd", output[0], is_input=False)
            if has_aux:
                _, jvp_out, aux = output
                return jvp_out, aux
            _, jvp_out = output
            return jvp_out

        results = vmap(push_jvp, randomness=randomness)(basis)
        if has_aux:
            results, aux = results
            # aux is in the standard basis format, e.g. NxN matrix
            # We need to fetch the first element as original `func` output
            flat_aux, aux_spec = tree_flatten(aux)
            flat_aux = [value[0] for value in flat_aux]
            aux = tree_unflatten(flat_aux, aux_spec)

        jac_outs, spec = tree_flatten(results)
        # Most probably below output check can never raise an error
        # as jvp should test the output before
        # assert_non_empty_output(jac_outs, 'jacfwd(f, ...)(*args)')

        jac_outs_ins = tuple(
            tuple(
                safe_unflatten(jac_out_in, -1, primal.shape)
                for primal, jac_out_in in
                zip(flat_primals, jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1))
            )
            for jac_out in jac_outs
        )
        jac_outs_ins = tuple(tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins)

        if isinstance(argnums, int):
            jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins)
        if has_aux:
            return tree_unflatten(jac_outs_ins, spec), aux
        return tree_unflatten(jac_outs_ins, spec)
    return wrapper_fn


@exposed_in("torch.func")
def hessian(func, argnums=0):
    """
    Computes the Hessian of ``func`` with respect to the arg(s) at index
    ``argnum`` via a forward-over-reverse strategy.

    The forward-over-reverse strategy (composing ``jacfwd(jacrev(func))``) is
    a good default for good performance. It is possible to compute Hessians
    through other compositions of :func:`jacfwd` and :func:`jacrev` like
    ``jacfwd(jacfwd(func))`` or ``jacrev(jacrev(func))``.

    Args:
        func (function): A Python function that takes one or more arguments,
            one of which must be a Tensor, and returns one or more Tensors
        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
            saying which arguments to get the Hessian with respect to.
            Default: 0.

    Returns:
        Returns a function that takes in the same inputs as ``func`` and
        returns the Hessian of ``func`` with respect to the arg(s) at
        ``argnums``.

    .. note::
        You may see this API error out with "forward-mode AD not implemented
        for operator X". If so, please file a bug report and we will prioritize it.
        An alternative is to use ``jacrev(jacrev(func))``, which has better
        operator coverage.

    A basic usage with a R^N -> R^1 function gives a N x N Hessian:

        >>> from torch.func import hessian
        >>> def f(x):
        >>>     return x.sin().sum()
        >>>
        >>> x = torch.randn(5)
        >>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
        >>> assert torch.allclose(hess, torch.diag(-x.sin()))
    """
    return jacfwd(jacrev(func, argnums), argnums)


@exposed_in("torch.func")
def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
    """
    Returns a function to compute a tuple of the gradient and primal, or
    forward, computation.

    Args:
        func (Callable): A Python function that takes one or more arguments.
            Must return a single-element Tensor. If specified ``has_aux``
            equals ``True``, function can return a tuple of single-element
            Tensor and other auxiliary objects: ``(output, aux)``.
        argnums (int or Tuple[int]): Specifies arguments to compute gradients
            with respect to. ``argnums`` can be single integer or tuple of
            integers. Default: 0.
        has_aux (bool): Flag indicating that ``func`` returns a tensor and
            other auxiliary objects: ``(output, aux)``. Default: False.

    Returns:
        Function to compute a tuple of gradients with respect to its inputs
        and the forward computation. By default, the output of the function is
        a tuple of the gradient tensor(s) with respect to the first argument
        and the primal computation. If specified ``has_aux`` equals
        ``True``, tuple of gradients and tuple of the forward computation with
        output auxiliary objects is returned. If ``argnums`` is a tuple of
        integers, a tuple of a tuple of the output gradients with respect to
        each ``argnums`` value and the forward computation is returned.

    See :func:`grad` for examples
    """
    @doesnt_support_saved_tensors_hooks
    @wraps(func)
    def wrapper(*args, **kwargs):
        level = _grad_increment_nesting()
        try:
            output, aux, grad_input = None, None, None
            # See NOTE [grad and vjp interaction with no_grad]
            with torch.enable_grad():
                args = _wrap_all_tensors(args, level)
                kwargs = _wrap_all_tensors(kwargs, level)
                diff_args = _slice_argnums(args, argnums, as_tuple=False)
                tree_map_(partial(_create_differentiable, level=level), diff_args)

                output = func(*args, **kwargs)
                if has_aux:
                    if not (isinstance(output, tuple) and len(output) == 2):
                        raise RuntimeError(
                            "grad_and_value(f)(*args): output of function f should be a tuple: (output, aux) "
                            "if has_aux is True"
                        )
                    output, aux = output

                if not isinstance(output, torch.Tensor):
                    raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
                                       f'to return a Tensor, got {type(output)}')
                if output.dim() != 0:
                    raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
                                       'to return a scalar Tensor, got tensor with '
                                       f'{output.dim()} dims. Maybe you wanted to '
                                       'use the vjp or jacrev APIs instead?')

                flat_diff_args, spec = tree_flatten(diff_args)

                # NB: need create_graph so that backward pass isn't run in no_grad mode
                flat_outputs = _as_tuple(output)
                flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
                grad_input = tree_unflatten(flat_grad_input, spec)

                grad_input = _undo_create_differentiable(grad_input, level)
                output = _undo_create_differentiable(output, level)
                if aux is not None:
                    aux = _undo_create_differentiable(aux, level)

            if has_aux:
                return grad_input, (output, aux)
            return grad_input, output
        finally:
            _grad_decrement_nesting()
    return wrapper


@exposed_in("torch.func")
def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
    """``grad`` computes the gradients of ``func`` with respect to the
    input(s) specified by ``argnums``. This operator can be nested to
    compute higher-order gradients.

    Args:
        func (Callable): A Python function that takes one or more arguments.
            Must return a single-element Tensor. If ``has_aux`` is ``True``,
            the function can return a tuple of a single-element Tensor and other
            auxiliary objects: ``(output, aux)``.
        argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to.
            ``argnums`` can be a single integer or a tuple of integers. Default: 0.
        has_aux (bool): Flag indicating that ``func`` returns a tensor and other
            auxiliary objects: ``(output, aux)``. Default: False.

    Returns:
        Function to compute gradients with respect to its inputs. By default, the output of
        the function is the gradient tensor(s) with respect to the first argument.
        If ``has_aux`` is ``True``, a tuple of gradients and output auxiliary objects
        is returned. If ``argnums`` is a tuple of integers, a tuple of output gradients with
        respect to each ``argnums`` value is returned.

    Example of using ``grad``:

        >>> # xdoctest: +SKIP
        >>> from torch.func import grad
        >>> x = torch.randn([])
        >>> cos_x = grad(lambda x: torch.sin(x))(x)
        >>> assert torch.allclose(cos_x, x.cos())
        >>>
        >>> # Second-order gradients
        >>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
        >>> assert torch.allclose(neg_sin_x, -x.sin())

    When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients:

        >>> # xdoctest: +SKIP
        >>> from torch.func import grad, vmap
        >>> batch_size, feature_size = 3, 5
        >>>
        >>> def model(weights, feature_vec):
        >>>     # Very simple linear model with activation
        >>>     assert feature_vec.dim() == 1
        >>>     return feature_vec.dot(weights).relu()
        >>>
        >>> def compute_loss(weights, example, target):
        >>>     y = model(weights, example)
        >>>     return ((y - target) ** 2).mean()  # MSELoss
        >>>
        >>> weights = torch.randn(feature_size, requires_grad=True)
        >>> examples = torch.randn(batch_size, feature_size)
        >>> targets = torch.randn(batch_size)
        >>> inputs = (weights, examples, targets)
        >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)

    Example of using ``grad`` with ``has_aux`` and ``argnums``:

        >>> # xdoctest: +SKIP
        >>> from torch.func import grad
        >>> def my_loss_func(y, y_pred):
        >>>     loss_per_sample = (0.5 * y_pred - y) ** 2
        >>>     loss = loss_per_sample.mean()
        >>>     return loss, (y_pred, loss_per_sample)
        >>>
        >>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
        >>> y_true = torch.rand(4)
        >>> y_preds = torch.rand(4, requires_grad=True)
        >>> out = fn(y_true, y_preds)
        >>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))

    .. note::
        Using PyTorch ``torch.no_grad`` together with ``grad``.

        Case 1: Using ``torch.no_grad`` inside a function:

            >>> # xdoctest: +SKIP
            >>> def f(x):
            >>>     with torch.no_grad():
            >>>         c = x ** 2
            >>>     return x - c

        In this case, ``grad(f)(x)`` will respect the inner ``torch.no_grad``.

        Case 2: Using ``grad`` inside a ``torch.no_grad`` context manager:

            >>> # xdoctest: +SKIP
            >>> with torch.no_grad():
            >>>     grad(f)(x)

        In this case, ``grad`` will respect the inner ``torch.no_grad``, but not the
        outer one. This is because ``grad`` is a "function transform": its result
        should not depend on the result of a context manager outside of ``f``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
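        # grad() is a thin wrapper around grad_and_value(): both the gradient and
        # the primal output are computed, then the primal is discarded (keeping any aux).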
        results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
        if has_aux:
            grad, (_, aux) = results
            return grad, aux
        grad, _ = results
        return grad
    return wrapper


def _maybe_wrap_functional_tensor(maybe_tensor, level):
    if not isinstance(maybe_tensor, torch.Tensor):
        return maybe_tensor
    wrapped = _wrap_functional_tensor(maybe_tensor, level)
    _assert_wrapped_functional(maybe_tensor, wrapped)
    return wrapped


def _wrap_all_tensors_to_functional(tensor_pytree, level):
    return tree_map(partial(_maybe_wrap_functional_tensor, level=level), tensor_pytree)


def _maybe_unwrap_functional_tensor(maybe_tensor, *, reapply_views: bool):
    if not isinstance(maybe_tensor, torch.Tensor):
        return maybe_tensor
    if not torch._is_functional_tensor(maybe_tensor):
        # If it's not a functional tensor, just return it.
        # This can happen if we functionalize a fn that returns a global,
        # which was never wrapped properly.
        return maybe_tensor
    # Sync any pending updates on the output tensor
    torch._sync(maybe_tensor)
    return _unwrap_functional_tensor(maybe_tensor, reapply_views)


def _unwrap_all_tensors_from_functional(tensor_pytree, *, reapply_views: bool):
    return tree_map(lambda t: _maybe_unwrap_functional_tensor(t, reapply_views=reapply_views), tensor_pytree)


@exposed_in("torch.func")
def functionalize(func: Callable, *, remove: str = 'mutations') -> Callable:
    """
    functionalize is a transform that can be used to remove (intermediate)
    mutations and aliasing from a function, while preserving the function's
    semantics.

    ``functionalize(func)`` returns a new function with the same semantics
    as ``func``, but with all intermediate mutations removed.
    Every inplace operation performed on an intermediate tensor:
    ``intermediate.foo_()``
    gets replaced by its out-of-place equivalent:
    ``intermediate_updated = intermediate.foo()``.

    functionalize is useful for shipping a PyTorch program off to
    backends or compilers that aren't able to easily represent
    mutations or aliasing operators.

    Args:
        func (Callable): A Python function that takes one or more arguments.
        remove (str): An optional string argument, that takes on either
            the value 'mutations' or 'mutations_and_views'.
            If 'mutations' is passed in then all mutating operators
            will be replaced with their non-mutating equivalents.
            If 'mutations_and_views' is passed in, then additionally, all aliasing
            operators will be replaced with their non-aliasing equivalents.
            Default: 'mutations'.

    Returns:
        Returns a new "functionalized" function. It takes the same inputs as
        ``func``, and has the same behavior, but any mutations
        (and optionally aliasing) performed on intermediate tensors
        in the function will be removed.

    functionalize will also remove mutations (and views) that were performed on function inputs.
    However, to preserve semantics, functionalize will "fix up" the mutations after
    the transform has finished running, by detecting if any tensor inputs "should have"
    been mutated, and copying the new data back to the inputs if necessary.

    Example::

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> from torch.fx.experimental.proxy_tensor import make_fx
        >>> from torch.func import functionalize
        >>>
        >>> # A function that uses mutations and views, but only on intermediate tensors.
        >>> def f(a):
        ...     b = a + 1
        ...     c = b.view(-1)
        ...     c.add_(1)
        ...     return b
        ...
        >>> inpt = torch.randn(2)
        >>>
        >>> out1 = f(inpt)
        >>> out2 = functionalize(f)(inpt)
        >>>
        >>> # semantics are the same (outputs are equivalent)
        >>> print(torch.allclose(out1, out2))
        True
        >>>
        >>> f_traced = make_fx(f)(inpt)
        >>> f_no_mutations_traced = make_fx(functionalize(f))(inpt)
        >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
        >>>
        >>> print(f_traced.code)

        def forward(self, a_1):
            add = torch.ops.aten.add(a_1, 1); a_1 = None
            view = torch.ops.aten.view(add, [-1])
            add_ = torch.ops.aten.add_(view, 1); view = None
            return add

        >>> print(f_no_mutations_traced.code)

        def forward(self, a_1):
            add = torch.ops.aten.add(a_1, 1); a_1 = None
            view = torch.ops.aten.view(add, [-1]); add = None
            add_1 = torch.ops.aten.add(view, 1); view = None
            view_1 = torch.ops.aten.view(add_1, [2]); add_1 = None
            return view_1

        >>> print(f_no_mutations_and_views_traced.code)

        def forward(self, a_1):
            add = torch.ops.aten.add(a_1, 1); a_1 = None
            view_copy = torch.ops.aten.view_copy(add, [-1]); add = None
            add_1 = torch.ops.aten.add(view_copy, 1); view_copy = None
            view_copy_1 = torch.ops.aten.view_copy(add_1, [2]); add_1 = None
            return view_copy_1

        >>> # A function that mutates its input tensor
        >>> def f(a):
        ...     b = a.view(-1)
        ...     b.add_(1)
        ...     return a
        ...
        >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
        >>> #
        >>> # All mutations and views have been removed,
        >>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input
        >>> # after the function has completed.
        >>> print(f_no_mutations_and_views_traced.code)

        def forward(self, a_1):
            view_copy = torch.ops.aten.view_copy(a_1, [-1])
            add = torch.ops.aten.add(view_copy, 1); view_copy = None
            view_copy_1 = torch.ops.aten.view_copy(add, [2]); add = None
            copy_ = torch.ops.aten.copy_(a_1, view_copy_1); a_1 = None
            return view_copy_1

    There are a few "failure modes" for functionalize that are worth calling out:
      (1) Like other torch.func transforms, `functionalize()` doesn't work with functions
          that directly use `.backward()`. The same is true for torch.autograd.grad.
          If you want to use autograd, you can compute gradients directly
          with `functionalize(grad(f))` (see the sketch after this list).
      (2) Like other torch.func transforms, `functionalize()` doesn't work with global state.
          If you call `functionalize(f)` on a function that takes views / mutations of
          non-local state, functionalization will simply no-op and pass the view/mutation
          calls directly to the backend.
          One way to work around this is to ensure that any non-local state creation
          is wrapped into a larger function, which you then call functionalize on.
      (3) `resize_()` has some limitations: functionalize will only work on programs
          that use `resize_()` as long as the tensor being resized is not a view.
      (4) `as_strided()` has some limitations: functionalize will not work on
          `as_strided()` calls that result in tensors with overlapping memory.
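
    As a minimal sketch of point (1) (the function ``g`` below is illustrative),
    ``functionalize`` can be composed with :func:`torch.func.grad` instead of
    calling ``.backward()`` inside the functionalized function:

        >>> # xdoctest: +SKIP
        >>> from torch.func import functionalize, grad
        >>> def g(x):
        ...     y = x.clone()
        ...     y.mul_(2)
        ...     return y.sum()
        ...
        >>> grads = functionalize(grad(g))(torch.randn(3))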

    Finally, a helpful mental model for understanding functionalization is that
    most user PyTorch programs are written with the public torch API.
    When executed, torch operators are generally decomposed into
    our internal C++ "ATen" API.
    The logic for functionalization happens entirely at the level of ATen.
    Functionalization knows how to take every aliasing operator in ATen,
    and map it to its non-aliasing equivalent
    (e.g. ``tensor.view({-1})`` -> ``at::view_copy(tensor, {-1})``),
    and how to take every mutating operator in ATen,
    and map it to its non-mutating equivalent
    (e.g. ``tensor.add_(1)`` -> ``at::add(tensor, 1)``),
    while tracking aliases and mutations out-of-line to know when to fix things up.

    Information about which ATen operators are aliasing or mutating all comes from
    https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml.
    """
    if remove == 'mutations':
        reapply_views = True
    elif remove == 'mutations_and_views':
        reapply_views = False
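        # With reapply_views=False, aliasing ops ({view}) are replaced by their
        # non-aliasing {view}_copy counterparts when the functional tensors are unwrapped.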
    else:
        raise RuntimeError(
            f"functionalize(f, remove='mutations'): received invalid argument for remove={remove}."
            " Valid options are:\n"
            " remove='mutations': all inplace and out= operators will be removed from the program, and replaced"
            " with their out-of-place equivalents.\n"
            " remove='mutations_and_views': In addition to the above, all aliasing operators {view} will be"
            " replaced with their non-aliasing counterparts, {view}_copy.\n"
        )

    @doesnt_support_saved_tensors_hooks
    @wraps(func)
    def wrapped(*args, **kwargs):
        try:
            func_level = _func_increment_nesting(reapply_views)
            func_args = _wrap_all_tensors_to_functional(args, func_level)
            func_kwargs = _wrap_all_tensors_to_functional(kwargs, func_level)

            flattened_unwrapped_args, _ = tree_flatten(args)
            flattened_wrapped_args, _ = tree_flatten(func_args)
            flattened_unwrapped_kwargs, _ = tree_flatten(kwargs)
            flattened_wrapped_kwargs, _ = tree_flatten(func_kwargs)

            func_outputs = func(*func_args, **func_kwargs)
            outputs = _unwrap_all_tensors_from_functional(func_outputs, reapply_views=reapply_views)
            flat_outputs, func_out_spec = tree_flatten(outputs)

            for a in flattened_wrapped_args + flattened_wrapped_kwargs:
                if isinstance(a, torch.Tensor):
                    # Call sync_() on the inputs, to ensure that any pending mutations have been applied.
                    torch._sync(a)

            # And if any mutations were applied to the inputs, we need to propagate them back to the user.
            for unwrapped, wrapped in zip(flattened_unwrapped_args, flattened_wrapped_args):
                if isinstance(unwrapped, torch.Tensor) and isinstance(wrapped, torch.Tensor):
                    _propagate_functional_input_mutation(unwrapped, wrapped)
            for unwrapped, wrapped in zip(flattened_unwrapped_kwargs, flattened_wrapped_kwargs):
                if isinstance(unwrapped, torch.Tensor) and isinstance(wrapped, torch.Tensor):
                    _propagate_functional_input_mutation(unwrapped, wrapped)

            return outputs
        finally:
            _func_decrement_nesting()
    return wrapped


@exposed_in("torch.func")
def linearize(func: Callable, *primals) -> Tuple[Any, Callable]:
    '''
    Returns the value of ``func`` at ``primals`` and linear approximation
    at ``primals``.

    Args:
        func (Callable): A Python function that takes one or more arguments.
        primals (Tensors): Positional arguments to ``func`` that must all be
            Tensors. These are the values at which the function is linearly approximated.

    Returns:
        Returns a ``(output, jvp_fn)`` tuple containing the output of ``func``
        applied to ``primals`` and a function that computes the jvp of
        ``func`` evaluated at ``primals``.

    linearize is useful if jvp is to be computed multiple times at ``primals``. However,
    to achieve this, linearize saves the intermediate computation and has higher memory requirements
    than directly applying ``jvp``. So, if all the ``tangents`` are known, it may be more efficient
    to compute ``vmap(jvp)`` instead of using linearize.

    .. note::
        linearize evaluates ``func`` twice. Please file an issue for an implementation
        with a single evaluation.

    Example::
        >>> import torch
        >>> from torch.func import linearize
        >>> def fn(x):
        ...     return x.sin()
        ...
        >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
        >>> jvp_fn(torch.ones(3, 3))
        tensor([[1., 1., 1.],
                [1., 1., 1.],
                [1., 1., 1.]])
        >>>
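        >>> # Sketch: jvp_fn can be reused with different tangents at the same
        >>> # ``primals`` without re-tracing ``fn`` (the tangent values below are illustrative).
        >>> jvp_fn(2 * torch.ones(3, 3))
        tensor([[2., 2., 2.],
                [2., 2., 2.],
                [2., 2., 2.]])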
    '''
    # Note: We evaluate `func` twice:
    # once to return the output and once while
    # tracing the graph.
    # If this becomes a bottleneck, we should update
    # make_fx such that it also returns the output.

    output = func(*primals)
    _, output_spec = tree_flatten(output)

    flat_primals, primals_argspec = tree_flatten(primals)

    # tangents for tracing
    flat_tangents = tuple(p.new_empty(()).expand_as(p) for p in flat_primals)

    # function to trace
    def trace_fn(flat_tangents):
        with fwAD.dual_level():
            flat_duals = tuple(fwAD.make_dual(p, t) for p, t in zip(flat_primals, flat_tangents))
            duals = tree_unflatten(flat_duals, primals_argspec)
            output = func(*duals)
            tangents = tree_map_only(torch.Tensor, lambda t: fwAD.unpack_dual(t)[1], output)

        return tangents

    jvp_graph = make_fx(trace_fn)(flat_tangents)
    const_folded_jvp_graph = const_fold.split_const_subgraphs(jvp_graph)

    # Hold only the meta-data regarding the primals.
    flat_primals_shape = tuple(p.shape for p in flat_primals)
    flat_primals_device = tuple(p.device for p in flat_primals)
    flat_primals_dtype = tuple(p.dtype for p in flat_primals)

    def forward_ad_checks(flat_tangents):
        for idx, t in enumerate(flat_tangents):
            if t.shape != flat_primals_shape[idx]:
                msg = (f"tangent:{idx} with shape {t.shape} in flattened "
                       f"pytree doesn't match the shape {flat_primals_shape[idx]} "
                       "of the corresponding primal.")
                raise RuntimeError(msg)

            if t.device != flat_primals_device[idx]:
                msg = (f"tangent:{idx} with device {t.device} in flattened "
                       f"pytree doesn't match the device {flat_primals_device[idx]} "
                       "of the corresponding primal.")
                raise RuntimeError(msg)

            if t.dtype != flat_primals_dtype[idx]:
                msg = (f"tangent:{idx} with dtype {t.dtype} in flattened "
                       f"pytree doesn't match the dtype {flat_primals_dtype[idx]} "
                       "of the corresponding primal.")
                raise RuntimeError(msg)

    # jvp_fn : callable to return
    #   It takes care of checking the argspec of tangents,
    #   calling the folded fx graph and unflattening fx graph output
    def jvp_fn(*tangents):
        flat_tangents, tangent_argspec = tree_flatten(tangents)
        if tangent_argspec != primals_argspec:
            raise RuntimeError(f"Expected the tangents {tangent_argspec} to have "
                               f"the same argspec as the primals {primals_argspec}")

        forward_ad_checks(flat_tangents)

        flat_output = const_folded_jvp_graph(*flat_tangents)
        # const folded graph can return flat output,
        # so transform output.
        return tree_unflatten(flat_output, output_spec)

    return output, jvp_fn