- # Copyright (c) Facebook, Inc. and its affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the BSD-style license found in the
- # LICENSE file in the root directory of this source tree.
- from typing import Callable, Union, Tuple, List, Any, Optional
- import torch
- from functools import partial, wraps
- import contextlib
- from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map, tree_map_only
- from torch.fx.experimental import const_fold
- from torch.fx.experimental.proxy_tensor import make_fx
- from .pytree_hacks import tree_map_, treespec_pprint
- import torch.autograd.forward_ad as fwAD
- from .vmap import vmap, doesnt_support_saved_tensors_hooks, get_chunk_sizes
- from torch._C._functorch import (
- _wrap_for_grad,
- _unwrap_for_grad,
- _grad_increment_nesting,
- _grad_decrement_nesting,
- _jvp_increment_nesting,
- _jvp_decrement_nesting,
- _wrap_functional_tensor,
- _unwrap_functional_tensor,
- _func_decrement_nesting,
- _func_increment_nesting,
- _assert_wrapped_functional,
- _propagate_functional_input_mutation,
- set_inplace_requires_grad_allowed,
- get_inplace_requires_grad_allowed
- )
- from torch._functorch.utils import exposed_in
- argnums_t = Union[int, Tuple[int, ...]]
- @contextlib.contextmanager
- def enable_inplace_requires_grad(enabled=True):
- prev_state = get_inplace_requires_grad_allowed()
- set_inplace_requires_grad_allowed(enabled)
- try:
- yield
- finally:
- set_inplace_requires_grad_allowed(prev_state)
- def _create_differentiable(inps, level=None):
- def create_differentiable(x):
- if isinstance(x, torch.Tensor):
- with enable_inplace_requires_grad():
- return x.requires_grad_()
- raise ValueError(f'Thing passed to transform API must be Tensor, '
- f'got {type(x)}')
- return tree_map(create_differentiable, inps)
- def _undo_create_differentiable(inps, level=None):
- def unwrap_tensors(x):
- if isinstance(x, torch.Tensor):
- return _unwrap_for_grad(x, level)
- # TODO: Remove the following hack for namedtuples
- if isinstance(x, tuple):
- return tree_map(unwrap_tensors, tuple(x))
- raise RuntimeError(f"Expected tensors, got unsupported type {type(x)}")
- return tree_map(unwrap_tensors, inps)
- def _is_differentiable(maybe_tensor):
- if not isinstance(maybe_tensor, torch.Tensor):
- return False
- return maybe_tensor.requires_grad
- def _any_differentiable(tensor_or_tuple_of_tensors):
- flat_args, _ = tree_flatten(tensor_or_tuple_of_tensors)
- return any(tuple(map(_is_differentiable, flat_args)))
- def _wrap_tensor_for_grad(maybe_tensor, level):
- if not isinstance(maybe_tensor, torch.Tensor):
- return maybe_tensor
- return _wrap_for_grad(maybe_tensor, level)
- def _wrap_all_tensors(tensor_pytree, level):
- return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_pytree)
- def _as_tuple(val):
- if isinstance(val, tuple):
- return val
- return (val,)
- # Version of autograd.grad that handles outputs that don't depend on inputs
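- # For example (an illustrative sketch; here the output does not depend on
- # `x`, so plain torch.autograd.grad would return None for it):
- #
- # >>> x = torch.randn(3, requires_grad=True)
- # >>> y = torch.randn(3, requires_grad=True)
- # >>> gx, gy = _autograd_grad((y.sin(),), (x, y), (torch.ones(3),))
- # >>> assert torch.equal(gx, torch.zeros(3))  # zeros instead of None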
- def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True):
- if grad_outputs is None:
- diff_outputs = tuple(out for out in outputs if out.requires_grad)
- else:
- result = tuple((out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad)
- if len(result) == 0:
- diff_outputs, grad_outputs = (), ()
- else:
- diff_outputs, grad_outputs = zip(*result)
- if len(diff_outputs) == 0:
- return tuple(torch.zeros_like(inp) for inp in inputs)
- grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
- retain_graph=retain_graph,
- create_graph=create_graph,
- allow_unused=True)
- grad_inputs = tuple(torch.zeros_like(inp) if gi is None else gi
- for gi, inp in zip(grad_inputs, inputs))
- return grad_inputs
- # NOTE [grad and vjp interaction with no_grad]
- #
- # def f(x):
- # with torch.no_grad():
- # c = x ** 2
- # return x - c
- #
- # The thing to consider is whether enable_grad is on or off before grad gets called.
- #
- # Case 1: enable_grad is on.
- # grad(f)(x)
- # In this case, `grad` should respect the inner torch.no_grad.
- #
- # Case 2: enable_grad is off
- # with torch.no_grad():
- # grad(f)(x)
- # In this case, `grad` should respect the inner torch.no_grad, but not the
- # outer one. This is because `grad` is a "function transform": its result
- # should not depend on the result of a context manager outside of `f`.
- #
- # This gives us the following desired behavior:
- # - (nested) grad transforms must obey torch.no_grad inside them
- # - (nested) grad transforms should not obey torch.no_grad outside them
- #
- # To achieve this behavior, upon entering grad/vjp:
- # - we save the current ("previous") is_grad_enabled (*)
- # - we unconditionally enable grad.
- #
- # Inside DynamicLayerBackFallback, when we're temporarily popping `grad` layer
- # off the stack:
- # - if grad_mode is disabled, then we do nothing. (there is a torch.no_grad
- # active, all subsequent grad transforms must obey it).
- # - if grad_mode is enabled, and the previous is_grad_enabled (*) is False,
- # then we temporarily restore the previous `is_grad_enabled`. This is
- # because we're crossing the boundary from a `grad` outside the
- # no_grad to a `grad` inside the no_grad.
- #
- # NB: vjp has some interesting behavior because the vjp's callable can be called
- # under a different grad_mode than the forward computation...
- #
- # NB: forward-mode AD: forward-mode AD doesn't respect torch.no_grad, but
- # it respects c10::AutoFwGradMode. We've implemented the same logic for
- # our jvp transform (it will have special handling if FwGradMode is disabled).
- # How do we increment and decrement the nesting? I don't think we can.
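- #
- # Illustrative sketch of the resulting behavior (f as defined above):
- #
- # >>> x = torch.tensor(2.)
- # >>> # Case 1: c = x ** 2 is treated as a constant, so d/dx (x - c) = 1
- # >>> assert torch.allclose(grad(f)(x), torch.tensor(1.))
- # >>> # Case 2: the outer no_grad does not change the transform's result
- # >>> with torch.no_grad():
- # ...     assert torch.allclose(grad(f)(x), torch.tensor(1.))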
- @exposed_in("torch.func")
- def vjp(func: Callable, *primals, has_aux: bool = False):
- """
- Standing for the vector-Jacobian product, returns a tuple containing the
- results of ``func`` applied to ``primals`` and a function that, when
- given ``cotangents``, computes the reverse-mode Jacobian of ``func`` with
- respect to ``primals`` times ``cotangents``.
- Args:
- func (Callable): A Python function that takes one or more arguments. Must
- return one or more Tensors.
- primals (Tensors): Positional arguments to ``func`` that must all be
- Tensors. The returned function will also compute the
- derivative with respect to these arguments.
- has_aux (bool): Flag indicating that ``func`` returns a
- ``(output, aux)`` tuple where the first element is the output of
- the function to be differentiated and the second element is
- other auxiliary objects that will not be differentiated.
- Default: False.
- Returns:
- Returns a ``(output, vjp_fn)`` tuple containing the output of ``func``
- applied to ``primals`` and a function that computes the vjp of
- ``func`` with respect to all ``primals`` using the cotangents passed
- to the returned function. If ``has_aux is True``, then instead returns a
- ``(output, vjp_fn, aux)`` tuple.
- The returned ``vjp_fn`` function will return a tuple of each VJP.
- When used in simple cases, :func:`vjp` behaves the same as :func:`grad`
- >>> x = torch.randn([5])
- >>> f = lambda x: x.sin().sum()
- >>> (_, vjpfunc) = torch.func.vjp(f, x)
- >>> grad = vjpfunc(torch.tensor(1.))[0]
- >>> assert torch.allclose(grad, torch.func.grad(f)(x))
- However, :func:`vjp` can support functions with multiple outputs by
- passing in the cotangents for each of the outputs
- >>> x = torch.randn([5])
- >>> f = lambda x: (x.sin(), x.cos())
- >>> (_, vjpfunc) = torch.func.vjp(f, x)
- >>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
- >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
- :func:`vjp` can even support outputs being Python structs
- >>> x = torch.randn([5])
- >>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
- >>> (_, vjpfunc) = torch.func.vjp(f, x)
- >>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
- >>> vjps = vjpfunc(cotangents)
- >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
- The function returned by :func:`vjp` will compute the partials with
- respect to each of the ``primals``
- >>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
- >>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
- >>> cotangents = torch.randn([5, 5])
- >>> vjps = vjpfunc(cotangents)
- >>> assert len(vjps) == 2
- >>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
- >>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))
- ``primals`` are the positional arguments for ``f``. All kwargs use their
- default value
- >>> x = torch.randn([5])
- >>> def f(x, scale=4.):
- >>> return x * scale
- >>>
- >>> (_, vjpfunc) = torch.func.vjp(f, x)
- >>> vjps = vjpfunc(torch.ones_like(x))
- >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))
- .. note::
- Using PyTorch ``torch.no_grad`` together with ``vjp``.
- Case 1: Using ``torch.no_grad`` inside a function:
- >>> def f(x):
- >>> with torch.no_grad():
- >>> c = x ** 2
- >>> return x - c
- In this case, ``vjp(f)(x)`` will respect the inner ``torch.no_grad``.
- Case 2: Using ``vjp`` inside ``torch.no_grad`` context manager:
- >>> # xdoctest: +SKIP(failing)
- >>> with torch.no_grad():
- >>> vjp(f)(x)
- In this case, ``vjp`` will respect the inner ``torch.no_grad``, but not the
- outer one. This is because ``vjp`` is a "function transform": its result
- should not depend on the result of a context manager outside of ``f``.
- """
- return _vjp_with_argnums(func, *primals, has_aux=has_aux)
- @doesnt_support_saved_tensors_hooks
- def _vjp_with_argnums(func: Callable, *primals, argnums: Optional[argnums_t] = None, has_aux: bool = False):
- # This is the same function as vjp but also accepts an argnums argument
- # All args are the same as vjp except for the added argument
- # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to.
- # If None, computes the gradients with respect to all inputs (used for vjp). Default: None
- #
- # WARN: Users should NOT call this function directly and should just be calling vjp.
- # It is only separated so that inputs passed to jacrev but not differentiated get the correct wrappers.
- #
- # NOTE: All error messages are produced as if vjp was being called, even if this was called by jacrev
- #
- # Returns the same two elements as :func:`vjp` but the function returned, vjp_fn, returns a tuple of VJPs
- # for only the primal elements given by argnums.
- level = _grad_increment_nesting()
- try:
- # See NOTE [grad and vjp interaction with no_grad]
- with torch.enable_grad():
- primals = _wrap_all_tensors(primals, level)
- if argnums is None:
- diff_primals = _create_differentiable(primals, level)
- else:
- diff_primals = _slice_argnums(primals, argnums, as_tuple=False)
- tree_map_(partial(_create_differentiable, level=level), diff_primals)
- primals_out = func(*primals)
- if has_aux:
- if not (isinstance(primals_out, tuple) and len(primals_out) == 2):
- raise RuntimeError(
- "vjp(f, *primals): output of function f should be a tuple: (output, aux) "
- "if has_aux is True"
- )
- primals_out, aux = primals_out
- aux = _undo_create_differentiable(aux, level)
- flat_primals_out, primals_out_spec = tree_flatten(primals_out)
- assert_non_empty_tensor_output(flat_primals_out, 'vjp(f, *primals)')
- flat_diff_primals, primals_spec = tree_flatten(diff_primals)
- results = _undo_create_differentiable(primals_out, level)
- for primal_out in flat_primals_out:
- assert isinstance(primal_out, torch.Tensor)
- if primal_out.is_floating_point() or primal_out.is_complex():
- continue
- raise RuntimeError("vjp(f, ...): All outputs of f must be "
- "floating-point or complex Tensors, got Tensor "
- f"with dtype {primal_out.dtype}")
- def wrapper(cotangents, retain_graph=True, create_graph=None):
- if create_graph is None:
- create_graph = torch.is_grad_enabled()
- flat_cotangents, cotangents_spec = tree_flatten(cotangents)
- if primals_out_spec != cotangents_spec:
- raise RuntimeError(
- f'Expected pytree structure of cotangents to be the same '
- f'as pytree structure of outputs to the function. '
- f'cotangents: {treespec_pprint(cotangents_spec)}, '
- f'primal output: {treespec_pprint(primals_out_spec)}')
- result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
- retain_graph=retain_graph, create_graph=create_graph)
- return tree_unflatten(result, primals_spec)
- finally:
- _grad_decrement_nesting()
- if has_aux:
- return results, wrapper, aux
- else:
- return results, wrapper
- def _safe_zero_index(x):
- assert len(x) == 1
- return x[0]
- # jacrev and jacfwd don't support complex functions
- # Helper function to throw appropriate error.
- def error_if_complex(func_name, args, is_input):
- flat_args, _ = tree_flatten(args)
- for idx, arg in enumerate(flat_args):
- if arg.dtype.is_complex:
- input_or_output = ("inputs" if is_input else "outputs")
- err_msg = (f"{func_name}: Expected all {input_or_output} "
- f"to be real but received complex tensor at flattened input idx: {idx}")
- raise RuntimeError(err_msg)
- @exposed_in("torch.func")
- def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
- chunk_size: Optional[int] = None,
- _preallocate_and_copy=False):
- """
- Computes the Jacobian of ``func`` with respect to the arg(s) at index
- ``argnums`` using reverse-mode autodiff.
- .. note::
- Using :attr:`chunk_size=1` is equivalent to computing the jacobian
- row-by-row with a for-loop i.e. the constraints of :func:`vmap` are
- not applicable.
- Args:
- func (function): A Python function that takes one or more arguments,
- one of which must be a Tensor, and returns one or more Tensors
- argnums (int or Tuple[int]): Optional, integer or tuple of integers,
- saying which arguments to get the Jacobian with respect to.
- Default: 0.
- has_aux (bool): Flag indicating that ``func`` returns a
- ``(output, aux)`` tuple where the first element is the output of
- the function to be differentiated and the second element is
- auxiliary objects that will not be differentiated.
- Default: False.
- chunk_size (None or int): If None (default), use the maximum chunk size
- (equivalent to doing a single vmap over vjp to compute the jacobian).
- If 1, then compute the jacobian row-by-row with a for-loop.
- If not None, then compute the jacobian :attr:`chunk_size` rows at a time
- (equivalent to doing multiple vmap over vjp). If you run into memory issues computing
- the jacobian, please try to specify a non-None chunk_size (see the chunked example below).
- Returns:
- Returns a function that takes in the same inputs as ``func`` and
- returns the Jacobian of ``func`` with respect to the arg(s) at
- ``argnums``. If ``has_aux is True``, then the returned function
- instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
- is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.
- A basic usage with a pointwise, unary operation will give a diagonal array
- as the Jacobian
- >>> from torch.func import jacrev
- >>> x = torch.randn(5)
- >>> jacobian = jacrev(torch.sin)(x)
- >>> expected = torch.diag(torch.cos(x))
- >>> assert torch.allclose(jacobian, expected)
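- If materializing the full standard basis at once uses too much memory, the
- Jacobian can be computed in chunks via ``chunk_size`` (the result is
- identical; the chunk size below is an arbitrary illustrative choice):
- >>> jacobian = jacrev(torch.sin, chunk_size=2)(x)
- >>> assert torch.allclose(jacobian, expected)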
- If you would like to compute the output of the function as well as the
- jacobian of the function, use the ``has_aux`` flag to return the output
- as an auxiliary object:
- >>> from torch.func import jacrev
- >>> x = torch.randn(5)
- >>>
- >>> def f(x):
- >>> return x.sin()
- >>>
- >>> def g(x):
- >>> result = f(x)
- >>> return result, result
- >>>
- >>> jacobian_f, f_x = jacrev(g, has_aux=True)(x)
- >>> assert torch.allclose(f_x, f(x))
- :func:`jacrev` can be composed with vmap to produce batched
- Jacobians:
- >>> from torch.func import jacrev, vmap
- >>> x = torch.randn(64, 5)
- >>> jacobian = vmap(jacrev(torch.sin))(x)
- >>> assert jacobian.shape == (64, 5, 5)
- Additionally, :func:`jacrev` can be composed with itself to produce
- Hessians
- >>> from torch.func import jacrev
- >>> def f(x):
- >>> return x.sin().sum()
- >>>
- >>> x = torch.randn(5)
- >>> hessian = jacrev(jacrev(f))(x)
- >>> assert torch.allclose(hessian, torch.diag(-x.sin()))
- By default, :func:`jacrev` computes the Jacobian with respect to the first
- input. However, it can compute the Jacobian with respect to a different
- argument by using ``argnums``:
- >>> from torch.func import jacrev
- >>> def f(x, y):
- >>> return x + y ** 2
- >>>
- >>> x, y = torch.randn(5), torch.randn(5)
- >>> jacobian = jacrev(f, argnums=1)(x, y)
- >>> expected = torch.diag(2 * y)
- >>> assert torch.allclose(jacobian, expected)
- Additionally, passing a tuple to ``argnums`` will compute the Jacobian
- with respect to multiple arguments
- >>> from torch.func import jacrev
- >>> def f(x, y):
- >>> return x + y ** 2
- >>>
- >>> x, y = torch.randn(5), torch.randn(5)
- >>> jacobian = jacrev(f, argnums=(0, 1))(x, y)
- >>> expectedX = torch.diag(torch.ones_like(x))
- >>> expectedY = torch.diag(2 * y)
- >>> assert torch.allclose(jacobian[0], expectedX)
- >>> assert torch.allclose(jacobian[1], expectedY)
- .. note::
- Using PyTorch ``torch.no_grad`` together with ``jacrev``.
- Case 1: Using ``torch.no_grad`` inside a function:
- >>> def f(x):
- >>> with torch.no_grad():
- >>> c = x ** 2
- >>> return x - c
- In this case, ``jacrev(f)(x)`` will respect the inner ``torch.no_grad``.
- Case 2: Using ``jacrev`` inside ``torch.no_grad`` context manager:
- >>> with torch.no_grad():
- >>> jacrev(f)(x)
- In this case, ``jacrev`` will respect the inner ``torch.no_grad``, but not the
- outer one. This is because ``jacrev`` is a "function transform": its result
- should not depend on the result of a context manager outside of ``f``.
- """
- if not (chunk_size is None or chunk_size > 0):
- raise ValueError("jacrev: `chunk_size` should be greater than 0.")
- @wraps(func)
- def wrapper_fn(*args):
- error_if_complex("jacrev", args, is_input=True)
- vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
- if has_aux:
- output, vjp_fn, aux = vjp_out
- else:
- output, vjp_fn = vjp_out
- # See NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
- flat_output, output_spec = tree_flatten(output)
- error_if_complex("jacrev", flat_output, is_input=False)
- # NB: vjp already checks that all outputs are tensors
- # Step 1: Construct grad_outputs by splitting the standard basis
- flat_output_numels = tuple(out.numel() for out in flat_output)
- primals = _slice_argnums(args, argnums)
- flat_primals, primals_spec = tree_flatten(primals)
- def compute_jacobian_stacked():
- # Helper function to compute chunked Jacobian
- # The intermediate chunked calculations are only
- # scoped at this function level.
- chunked_results = []
- for flat_basis_chunk in _chunked_standard_basis_for_(flat_output,
- flat_output_numels,
- chunk_size=chunk_size):
- if chunk_size == 1:
- # sanity check.
- for t in flat_basis_chunk:
- assert t.size(0) == 1
- flat_basis_chunk = tree_map(lambda t: torch.squeeze(t, 0), flat_basis_chunk)
- basis = tree_unflatten(flat_basis_chunk, output_spec)
- if chunk_size == 1:
- # Behaviour with `chunk_size=1` is same as `for-loop`
- # i.e. user shouldn't deal with the limitations of vmap.
- chunked_result = vjp_fn(basis)
- else: # chunk_size is None or chunk_size != 1
- chunked_result = vmap(vjp_fn)(basis)
- flat_results, _ = tree_flatten(chunked_result)
- if chunk_size == 1:
- flat_results = tree_map(lambda t: torch.unsqueeze(t, 0), flat_results)
- chunked_results.append(flat_results)
- if len(chunked_results) == 1:
- # Short-circuit if we used a single chunk
- return chunked_results[0]
- # Concatenate chunks.
- flat_results = []
- # Iterate and concat the jacobians of different
- # inputs.
- for idx in range(len(flat_primals)):
- r = tuple(map(lambda r_: r_[idx], chunked_results))
- flat_results.append(torch.cat(r, 0))
- return flat_results
- def compute_jacobian_preallocate_and_copy():
- # Helper function to compute chunked Jacobian
- # The intermediate chunked calculations are only
- # scoped at this function level.
- out_vec_size = sum(flat_output_numels)
- # Don't pre-allocate if we have a single chunk.
- if not (chunk_size is None or chunk_size >= out_vec_size):
- stacked_results = [primal.new_zeros(out_vec_size, *primal.shape) for primal in flat_primals]
- for idx, flat_basis_chunk in enumerate(_chunked_standard_basis_for_(flat_output,
- flat_output_numels,
- chunk_size=chunk_size)):
- if chunk_size == 1:
- # sanity check.
- for t in flat_basis_chunk:
- assert t.size(0) == 1
- flat_basis_chunk = list(map(lambda t: torch.squeeze(t, 0), flat_basis_chunk))
- basis = tree_unflatten(flat_basis_chunk, output_spec)
- if chunk_size == 1:
- # Behaviour with `chunk_size=1` is same as `for-loop`
- # i.e. user shouldn't deal with the limitations of vmap.
- chunked_result = vjp_fn(basis)
- else: # chunk_size is None or chunk_size != 1
- chunked_result = vmap(vjp_fn)(basis)
- flat_results, _ = tree_flatten(chunked_result)
- # Short-circuit if we have a single chunk.
- if chunk_size is None or chunk_size >= out_vec_size:
- if chunk_size == 1: # and out_vec_size == 1
- # Since we squeezed the output dim
- flat_results = tree_map(lambda t: torch.unsqueeze(t, 0), flat_results)
- return flat_results
- for r, sr in zip(flat_results, stacked_results):
- sr[idx * chunk_size: (idx + 1) * chunk_size].copy_(r)
- return stacked_results
- if _preallocate_and_copy:
- flat_jacobians_per_input = compute_jacobian_preallocate_and_copy()
- else:
- flat_jacobians_per_input = compute_jacobian_stacked()
- # Step 2: The returned jacobian is one big tensor per input. In this step,
- # we split each Tensor by output.
- flat_jacobians_per_input = [result.split(flat_output_numels, dim=0) for result in flat_jacobians_per_input]
- flat_input_flat_output = [
- tuple(split.view(out.shape + primal.shape)
- for split, out in zip(splits, flat_output))
- for splits, primal in zip(flat_jacobians_per_input, flat_primals)
- ]
- # Step 3: Right now, `jacobian` is a List[List[Tensor]].
- # The outer List corresponds to the number of primals,
- # the inner List corresponds to the number of outputs.
- # We need to:
- # a. Exchange the order of the outer List and inner List
- # b. tree_unflatten the inner Lists (which correspond to the primals)
- # c. handle the argnums=int case
- # d. tree_unflatten the outer List (which corresponds to the outputs)
- flat_output_flat_input = tuple(zip(*flat_input_flat_output))
- flat_output_input = tuple(tree_unflatten(flat_input, primals_spec)
- for flat_input in flat_output_flat_input)
- if isinstance(argnums, int):
- flat_output_input = tuple(_safe_zero_index(flat_input)
- for flat_input in flat_output_input)
- output_input = tree_unflatten(flat_output_input, output_spec)
- if has_aux:
- return output_input, aux
- return output_input
- return wrapper_fn
- # NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
- #
- # Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3).
- # It turns out we can compute the jacobian of this function with a single
- # call to autograd.grad by using vmap over the correct grad_outputs.
- #
- # Firstly, one way to compute the jacobian is to stack x**2 and x.sum()
- # into a 4D vector. E.g., use g(x) = torch.stack([x**2, x.sum()])
- #
- # To get the first row of the jacobian, we call
- # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0]))
- # To get the 2nd row of the jacobian, we call
- # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0]))
- # and so on.
- #
- # Using vmap, we can vectorize all 4 of these computations into one by
- # passing the standard basis for R^4 as the grad_output.
- # vmap(partial(autograd.grad, g(x), x))(torch.eye(4)).
- #
- # Now, how do we compute the jacobian *without stacking the output*?
- # We can just split the standard basis across the outputs. So to
- # compute the jacobian of f(x), we'd use
- # >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))
- # The grad_outputs looks like the following:
- # ( torch.tensor([[1, 0, 0],
- # [0, 1, 0],
- # [0, 0, 1],
- # [0, 0, 0]]),
- # torch.tensor([[0],
- # [0],
- # [0],
- # [1]]) )
- #
- # But we're not done yet!
- # >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...)))
- # returns a Tensor of shape [4, 3]. We have to remember to split the
- # jacobian of shape [4, 3] into two:
- # - one of shape [3, 3] for the first output
- # - one of shape [ 3] for the second output
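- #
- # As an illustrative sketch of the shapes involved for f(x) above:
- #
- # >>> x = torch.randn(3, requires_grad=True)
- # >>> outs = (x ** 2, x.sum())
- # >>> basis = _construct_standard_basis_for(outs, tuple(o.numel() for o in outs))
- # >>> [b.shape for b in basis]
- # [torch.Size([4, 3]), torch.Size([4])]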
- def _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None):
- # This function:
- # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
- # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
- # - Each chunk corresponds to one tensor. The chunk has the same dtype and
- # device as the tensor
- #
- # For example, with tensor_numels = [1, 2, 1], this function returns:
- # ( tensor([[1], tensor([[0, 0], tensor([[0],
- # [0], [1, 0], [0],
- # [0], [0, 1], [0],
- # [0]]) , [0, 0]]) , [1]]) )
- #
- # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
- # Precondition: tensors always has at least one element.
- #
- # See NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
- # for context behind this function.
- # NOTE: Argument `chunk_size` is used to generate chunked basis instead of
- # one huge basis matrix. `chunk_size` dictates the maximum size of the
- # basis matrix along dim=0.
- assert len(tensors) == len(tensor_numels)
- assert len(tensors) > 0
- assert chunk_size is None or chunk_size > 0
- total_numel = sum(tensor_numels)
- if chunk_size and chunk_size < total_numel:
- chunk_numels = get_chunk_sizes(total_numel, chunk_size)
- else: # chunk_size is None or chunk_size >= total_numel
- chunk_size = total_numel
- chunk_numels = [total_numel]
- diag_start_indices = (0, *torch.tensor(tensor_numels).cumsum(dim=0)[:-1].neg().unbind())
- for chunk_idx, total_numel in enumerate(chunk_numels):
- chunks = tuple(tensor.new_zeros(total_numel, tensor_numel)
- for tensor, tensor_numel in zip(tensors, tensor_numels))
- for chunk, diag_start_idx in zip(chunks, diag_start_indices):
- chunk.diagonal(diag_start_idx + chunk_idx * chunk_size).fill_(1)
- chunks = tuple(chunk.view(total_numel, *tensor.shape)
- for chunk, tensor in zip(chunks, tensors))
- yield chunks
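- # Illustrative usage (values chosen arbitrarily): with tensor_numels=(1, 2, 1)
- # and chunk_size=2, the 4-row basis above is yielded in two 2-row chunks,
- # each split into per-tensor pieces of width 1, 2 and 1:
- #
- # >>> ts = (torch.zeros(1), torch.zeros(2), torch.zeros(1))
- # >>> for chunks in _chunked_standard_basis_for_(ts, (1, 2, 1), chunk_size=2):
- # ...     print([c.shape for c in chunks])
- # [torch.Size([2, 1]), torch.Size([2, 2]), torch.Size([2, 1])]
- # [torch.Size([2, 1]), torch.Size([2, 2]), torch.Size([2, 1])]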
- def _construct_standard_basis_for(tensors, tensor_numels):
- for basis in _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None):
- return basis
- def _validate_and_wrap_argnum(argnum, num_args):
- if not isinstance(argnum, int):
- raise RuntimeError(f'argnum must be int, got: {type(argnum)}')
- if argnum >= 0 and argnum < num_args:
- return argnum
- if argnum < 0 and argnum >= -num_args:
- return argnum + num_args
- raise RuntimeError(f'Got argnum={argnum}, but only {num_args} positional inputs')
- def _check_unique_non_empty(argnums):
- if isinstance(argnums, tuple):
- if len(argnums) == 0:
- raise RuntimeError("argnums must be non-empty")
- if len(set(argnums)) != len(argnums):
- raise RuntimeError(f"argnums elements must be unique, got {argnums}")
- def _replace_args(old_args, new_args, argnums):
- if isinstance(argnums, int):
- if len(new_args) != 1:
- raise RuntimeError(f'new_args should be of size 1, was of size {len(new_args)}')
- return tuple(new_args[0] if i == argnums else old_args[i] for i in range(len(old_args)))
- if isinstance(argnums, tuple):
- if len(new_args) != len(argnums):
- raise RuntimeError(
- "new_args should have the same size as argnums. "
- f"Argnums size {len(argnums)}, new_args size {len(new_args)}")
- def get_right_elem(i):
- return new_args[argnums.index(i)] if i in argnums else old_args[i]
- return tuple(get_right_elem(i) for i in range(len(old_args)))
- raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}')
- def _validate_and_wrap_argnums(argnums, num_args):
- if isinstance(argnums, int):
- return _validate_and_wrap_argnum(argnums, num_args)
- if isinstance(argnums, tuple):
- return tuple(_validate_and_wrap_argnum(argnum, num_args) for argnum in argnums)
- raise AssertionError("Should never get here")
- def _slice_argnums(args, argnums, as_tuple=True):
- if not isinstance(argnums, int) and not isinstance(argnums, tuple):
- raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}')
- argnums = _validate_and_wrap_argnums(argnums, len(args))
- _check_unique_non_empty(argnums)
- if isinstance(argnums, int):
- if as_tuple:
- return (args[argnums],)
- else:
- return args[argnums]
- return tuple(args[i] for i in argnums)
- JVP_NESTING = 0
- @contextlib.contextmanager
- def noop():
- yield
- def assert_flat_tuple_of_tensors(elts: Any, api: str, argname: str) -> None:
- if not isinstance(elts, tuple):
- raise RuntimeError(
- f'{api}: Expected {argname} to be a tuple of Tensors, got {type(elts)}')
- for elt in elts:
- if isinstance(elt, torch.Tensor):
- continue
- raise RuntimeError(
- f'{api}: Expected {argname} to be a tuple of Tensors, got '
- f'a tuple with an element of type {type(elt)}')
- if len(elts) == 0:
- raise RuntimeError(
- f'{api}: Expected {argname} to be a non-empty tuple of Tensors.')
- def assert_non_empty_tensor_output(output: List[Any], api: str) -> None:
- if output == [None] or len(output) < 1:
- raise RuntimeError(
- f'{api}: Expected f to be a function that has non-empty output (got output = {output})'
- )
- for o in output:
- if not isinstance(o, torch.Tensor):
- raise RuntimeError(
- f'{api}: expected f(*primals) to return only tensors'
- f', got unsupported type {type(o)}'
- )
- def assert_output_is_tensor_or_tensors(output: Any, api: str) -> None:
- if isinstance(output, torch.Tensor):
- return
- if not isinstance(output, tuple):
- raise RuntimeError(
- f'{api}: Expected output of f to be a Tensor or Tensors, got '
- f'{type(output)}')
- if len(output) == 0:
- raise RuntimeError(
- f'{api}: Expected output of f to be a non-empty tuple of Tensors.')
- for out in output:
- if isinstance(out, torch.Tensor):
- continue
- raise RuntimeError(
- f'{api}: Expected output of f to be a Tensor or Tensors, got '
- f'{type(out)} as an output')
- def assert_non_empty_list_of_tensors(output: List[torch.Tensor], api: str, argname: str) -> None:
- if len(output) == 0:
- raise RuntimeError(
- f'{api}: Expected {argname} to contain at least one Tensor.')
- for out in output:
- if isinstance(out, torch.Tensor):
- continue
- raise RuntimeError(
- f'{api}: Expected {argname} to only contain Tensors, got '
- f'{type(out)}')
- jvp_str = 'jvp(f, primals, tangents)'
- def safe_unpack_dual(dual, strict):
- if not isinstance(dual, torch.Tensor):
- raise RuntimeError(
- f'{jvp_str}: expected f(*args) to return only tensors'
- f', got unsupported type {type(dual)}'
- )
- primal, tangent = fwAD.unpack_dual(dual)
- if tangent is None:
- if strict:
- raise RuntimeError(
- 'jvp(f, primals, tangents, strict=True): '
- 'The output of f is independent of '
- 'the inputs. This is not allowed with strict=True.')
- tangent = torch.zeros_like(primal)
- return primal, tangent
- @exposed_in("torch.func")
- def jvp(func: Callable, primals: Any, tangents: Any, *, strict: bool = False, has_aux: bool = False):
- """
- Standing for the Jacobian-vector product, returns a tuple containing
- the output of `func(*primals)` and the "Jacobian of ``func`` evaluated at
- ``primals``" times ``tangents``. This is also known as forward-mode autodiff.
- Args:
- func (function): A Python function that takes one or more arguments,
- one of which must be a Tensor, and returns one or more Tensors
- primals (Tensors): Positional arguments to ``func`` that must all be
- Tensors. The returned function will also compute the
- derivative with respect to these arguments.
- tangents (Tensors): The "vector" for which Jacobian-vector-product is
- computed. Must be the same structure and sizes as the inputs to
- ``func``.
- has_aux (bool): Flag indicating that ``func`` returns a
- ``(output, aux)`` tuple where the first element is the output of
- the function to be differentiated and the second element is
- other auxiliary objects that will not be differentiated.
- Default: False.
- Returns:
- Returns a ``(output, jvp_out)`` tuple containing the output of ``func``
- evaluated at ``primals`` and the Jacobian-vector product.
- If ``has_aux is True``, then instead returns a ``(output, jvp_out, aux)`` tuple.
- .. note::
- You may see this API error out with "forward-mode AD not implemented
- for operator X". If so, please file a bug report and we will prioritize it.
- jvp is useful when you wish to compute gradients of a function R^1 -> R^N
- >>> from torch.func import jvp
- >>> x = torch.randn([])
- >>> f = lambda x: x * torch.tensor([1., 2., 3])
- >>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
- >>> assert torch.allclose(value, f(x))
- >>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))
- :func:`jvp` can support functions with multiple inputs by passing in the
- tangents for each of the inputs
- >>> from torch.func import jvp
- >>> x = torch.randn(5)
- >>> y = torch.randn(5)
- >>> f = lambda x, y: (x * y)
- >>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
- >>> assert torch.allclose(output, x + y)
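- :func:`jvp` can also return auxiliary, non-differentiated outputs via
- ``has_aux`` (a minimal sketch):
- >>> from torch.func import jvp
- >>> x = torch.randn(5)
- >>> f = lambda x: (x.sin(), x.cos())
- >>> _, jvp_out, aux = jvp(f, (x,), (torch.ones(5),), has_aux=True)
- >>> assert torch.allclose(aux, x.cos())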
- """
- return _jvp_with_argnums(func, primals, tangents, argnums=None, strict=strict, has_aux=has_aux)
- @doesnt_support_saved_tensors_hooks
- def _jvp_with_argnums(func: Callable, primals: Any, tangents: Any, argnums: Optional[argnums_t], *,
- strict: bool = False, has_aux: bool):
- # This is the same function as jvp but also accepts an argnums argument
- # Most args are the same as jvp except for the added argument
- # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to.
- # If None, computes the gradients with respect to all inputs (used for jvp). Default: None
- # Because of this, tangents must have the same length as argnums, with each tangent
- # matching the corresponding primal whose index is given by argnums
- #
- # WARN: Users should NOT call this function directly and should just be calling jvp.
- # It is only separated so that inputs passed to jacfwd but not differentiated get the correct wrappers.
- #
- # NOTE: All error messages are produced as if jvp was being called, even if this was called by jacfwd
- #
- # Returns the same two elements as :func:`jvp` but the returned tuple, ``jvp_out``, only has JVPs with respect to
- # the primals given by argnums
- if not isinstance(primals, tuple):
- raise RuntimeError(
- f'{jvp_str}: Expected primals to be a tuple. '
- f'E.g. it should be valid to call f(*primals).')
- diff_args = primals if argnums is None else _slice_argnums(primals, argnums)
- flat_primals, primals_spec = tree_flatten(diff_args)
- flat_tangents, tangents_spec = tree_flatten(tangents)
- if primals_spec != tangents_spec:
- raise RuntimeError(
- f'{jvp_str}: Expected primals and tangents to have the same python '
- f'structure. For example, if primals is a tuple of 3 tensors, '
- f'tangents also must be. Got primals with structure {primals_spec} '
- f'and tangents with structure {tangents_spec}')
- assert_non_empty_list_of_tensors(flat_primals, jvp_str, 'primals')
- assert_non_empty_list_of_tensors(flat_tangents, jvp_str, 'tangents')
- level = _jvp_increment_nesting()
- try:
- global JVP_NESTING
- JVP_NESTING += 1
- with fwAD._set_fwd_grad_enabled(True):
- ctx = fwAD.dual_level if JVP_NESTING == 1 else noop
- with ctx():
- flat_duals = tuple(fwAD.make_dual(p, t)
- for p, t in zip(flat_primals, flat_tangents))
- duals = tree_unflatten(flat_duals, primals_spec)
- if argnums is not None:
- primals = _wrap_all_tensors(primals, level)
- duals = _replace_args(primals, duals, argnums)
- result_duals = func(*duals)
- if has_aux:
- if not (isinstance(result_duals, tuple) and len(result_duals) == 2):
- raise RuntimeError(
- f"{jvp_str}: output of function f should be a tuple: (output, aux) "
- "if has_aux is True"
- )
- result_duals, aux = result_duals
- aux = _undo_create_differentiable(aux, level)
- result_duals, spec = tree_flatten(result_duals)
- assert_non_empty_tensor_output(result_duals, jvp_str)
- primals_out, tangents_out = \
- zip(*[safe_unpack_dual(dual, strict) for dual in result_duals])
- primals_out = tree_map(
- partial(_undo_create_differentiable, level=level), primals_out)
- tangents_out = tree_map(
- partial(_undo_create_differentiable, level=level), tangents_out)
- primals_out_unflatten = tree_unflatten(primals_out, spec)
- tangents_out_unflatten = tree_unflatten(tangents_out, spec)
- if has_aux:
- return primals_out_unflatten, tangents_out_unflatten, aux
- return primals_out_unflatten, tangents_out_unflatten
- finally:
- _jvp_decrement_nesting()
- JVP_NESTING -= 1
- def safe_unflatten(tensor, dim, shape):
- if len(shape) == 0:
- assert tensor.shape[dim] == 1
- return tensor.squeeze(dim)
- return tensor.unflatten(dim, shape)
- @exposed_in("torch.func")
- def jacfwd(func: Callable, argnums: argnums_t = 0, has_aux: bool = False, *, randomness: str = "error"):
- """
- Computes the Jacobian of ``func`` with respect to the arg(s) at index
- ``argnums`` using forward-mode autodiff.
- Args:
- func (function): A Python function that takes one or more arguments,
- one of which must be a Tensor, and returns one or more Tensors
- argnums (int or Tuple[int]): Optional, integer or tuple of integers,
- saying which arguments to get the Jacobian with respect to.
- Default: 0.
- has_aux (bool): Flag indicating that ``func`` returns a
- ``(output, aux)`` tuple where the first element is the output of
- the function to be differentiated and the second element is
- auxiliary objects that will not be differentiated.
- Default: False.
- randomness(str): Flag indicating what type of randomness to use.
- See :func:`vmap` for more detail. Allowed: "different", "same", "error".
- Default: "error"
- Returns:
- Returns a function that takes in the same inputs as ``func`` and
- returns the Jacobian of ``func`` with respect to the arg(s) at
- ``argnums``. If ``has_aux is True``, then the returned function
- instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
- is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.
- .. note::
- You may see this API error out with "forward-mode AD not implemented
- for operator X". If so, please file a bug report and we will prioritize it.
- An alternative is to use :func:`jacrev`, which has better operator coverage.
- A basic usage with a pointwise, unary operation will give a diagonal array
- as the Jacobian
- >>> from torch.func import jacfwd
- >>> x = torch.randn(5)
- >>> jacobian = jacfwd(torch.sin)(x)
- >>> expected = torch.diag(torch.cos(x))
- >>> assert torch.allclose(jacobian, expected)
- :func:`jacfwd` can be composed with vmap to produce batched
- Jacobians:
- >>> from torch.func import jacfwd, vmap
- >>> x = torch.randn(64, 5)
- >>> jacobian = vmap(jacfwd(torch.sin))(x)
- >>> assert jacobian.shape == (64, 5, 5)
- If you would like to compute the output of the function as well as the
- jacobian of the function, use the ``has_aux`` flag to return the output
- as an auxiliary object:
- >>> from torch.func import jacfwd
- >>> x = torch.randn(5)
- >>>
- >>> def f(x):
- >>> return x.sin()
- >>>
- >>> def g(x):
- >>> result = f(x)
- >>> return result, result
- >>>
- >>> jacobian_f, f_x = jacfwd(g, has_aux=True)(x)
- >>> assert torch.allclose(f_x, f(x))
- Additionally, :func:`jacfwd` can be composed with itself or with :func:`jacrev`
- to produce Hessians
- >>> from torch.func import jacfwd, jacrev
- >>> def f(x):
- >>> return x.sin().sum()
- >>>
- >>> x = torch.randn(5)
- >>> hessian = jacfwd(jacrev(f))(x)
- >>> assert torch.allclose(hessian, torch.diag(-x.sin()))
- By default, :func:`jacfwd` computes the Jacobian with respect to the first
- input. However, it can compute the Jacobian with respect to a different
- argument by using ``argnums``:
- >>> from torch.func import jacfwd
- >>> def f(x, y):
- >>> return x + y ** 2
- >>>
- >>> x, y = torch.randn(5), torch.randn(5)
- >>> jacobian = jacfwd(f, argnums=1)(x, y)
- >>> expected = torch.diag(2 * y)
- >>> assert torch.allclose(jacobian, expected)
- Additionally, passing a tuple to ``argnums`` will compute the Jacobian
- with respect to multiple arguments
- >>> from torch.func import jacfwd
- >>> def f(x, y):
- >>> return x + y ** 2
- >>>
- >>> x, y = torch.randn(5), torch.randn(5)
- >>> jacobian = jacfwd(f, argnums=(0, 1))(x, y)
- >>> expectedX = torch.diag(torch.ones_like(x))
- >>> expectedY = torch.diag(2 * y)
- >>> assert torch.allclose(jacobian[0], expectedX)
- >>> assert torch.allclose(jacobian[1], expectedY)
- """
- @wraps(func)
- def wrapper_fn(*args):
- error_if_complex("jacfwd", args, is_input=True)
- primals = args if argnums is None else _slice_argnums(args, argnums)
- flat_primals, primals_spec = tree_flatten(primals)
- flat_primals_numels = tuple(p.numel() for p in flat_primals)
- flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels)
- basis = tree_unflatten(flat_basis, primals_spec)
- def push_jvp(basis):
- output = _jvp_with_argnums(func, args, basis, argnums=argnums, has_aux=has_aux)
- # output[0] is the output of `func(*args)`
- error_if_complex("jacfwd", output[0], is_input=False)
- if has_aux:
- _, jvp_out, aux = output
- return jvp_out, aux
- _, jvp_out = output
- return jvp_out
- results = vmap(push_jvp, randomness=randomness)(basis)
- if has_aux:
- results, aux = results
- # aux is batched over the standard-basis dimension (e.g. an NxN matrix);
- # we only need the first element, which is the original `func` output
- flat_aux, aux_spec = tree_flatten(aux)
- flat_aux = [value[0] for value in flat_aux]
- aux = tree_unflatten(flat_aux, aux_spec)
- jac_outs, spec = tree_flatten(results)
- # The output check below most likely can never raise an error,
- # as jvp should have already validated the output
- # assert_non_empty_output(jac_outs, 'jacfwd(f, ...)(*args)')
- jac_outs_ins = tuple(
- tuple(
- safe_unflatten(jac_out_in, -1, primal.shape)
- for primal, jac_out_in in
- zip(flat_primals, jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1))
- )
- for jac_out in jac_outs
- )
- jac_outs_ins = tuple(tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins)
- if isinstance(argnums, int):
- jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins)
- if has_aux:
- return tree_unflatten(jac_outs_ins, spec), aux
- return tree_unflatten(jac_outs_ins, spec)
- return wrapper_fn
- @exposed_in("torch.func")
- def hessian(func, argnums=0):
- """
- Computes the Hessian of ``func`` with respect to the arg(s) at index
- ``argnums`` via a forward-over-reverse strategy.
- The forward-over-reverse strategy (composing ``jacfwd(jacrev(func))``) is
- a good default for good performance. It is possible to compute Hessians
- through other compositions of :func:`jacfwd` and :func:`jacrev` like
- ``jacfwd(jacfwd(func))`` or ``jacrev(jacrev(func))``.
- Args:
- func (function): A Python function that takes one or more arguments,
- one of which must be a Tensor, and returns one or more Tensors
- argnums (int or Tuple[int]): Optional, integer or tuple of integers,
- saying which arguments to get the Hessian with respect to.
- Default: 0.
- Returns:
- Returns a function that takes in the same inputs as ``func`` and
- returns the Hessian of ``func`` with respect to the arg(s) at
- ``argnums``.
- .. note::
- You may see this API error out with "forward-mode AD not implemented
- for operator X". If so, please file a bug report and we will prioritize it.
- An alternative is to use ``jacrev(jacrev(func))``, which has better
- operator coverage.
- A basic usage with an R^N -> R^1 function gives an N x N Hessian:
- >>> from torch.func import hessian
- >>> def f(x):
- >>> return x.sin().sum()
- >>>
- >>> x = torch.randn(5)
- >>> hess = hessian(f)(x) # equivalent to jacfwd(jacrev(f))(x)
- >>> assert torch.allclose(hess, torch.diag(-x.sin()))
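- The same Hessian can also be computed via reverse-over-reverse, which
- trades performance for broader operator coverage (see the note above):
- >>> from torch.func import jacrev
- >>> hess_rr = jacrev(jacrev(f))(x)
- >>> assert torch.allclose(hess, hess_rr)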
- """
- return jacfwd(jacrev(func, argnums), argnums)
- @exposed_in("torch.func")
- def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
- """
- Returns a function to compute a tuple of the gradient and the primal
- (i.e., forward) computation.
- Args:
- func (Callable): A Python function that takes one or more arguments.
- Must return a single-element Tensor. If ``has_aux`` is
- ``True``, the function can return a tuple of a single-element
- Tensor and other auxiliary objects: ``(output, aux)``.
- argnums (int or Tuple[int]): Specifies arguments to compute gradients
- with respect to. ``argnums`` can be single integer or tuple of
- integers. Default: 0.
- has_aux (bool): Flag indicating that ``func`` returns a tensor and
- other auxiliary objects: ``(output, aux)``. Default: False.
- Returns:
- Function to compute a tuple of gradients with respect to its inputs
- and the forward computation. By default, the output of the function is
- a tuple of the gradient tensor(s) with respect to the first argument
- and the primal computation. If ``has_aux`` is ``True``, a tuple of
- gradients and a tuple of the forward computation with the output
- auxiliary objects, ``(output, aux)``, is returned. If ``argnums`` is a tuple of
- integers, a tuple of a tuple of the output gradients with respect to
- each ``argnums`` value and the forward computation is returned.
- See :func:`grad` for examples
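- A minimal illustrative example:
- >>> from torch.func import grad_and_value
- >>> x = torch.randn([])
- >>> grad_x, value = grad_and_value(torch.sin)(x)
- >>> assert torch.allclose(grad_x, x.cos())
- >>> assert torch.allclose(value, x.sin())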
- """
- @doesnt_support_saved_tensors_hooks
- @wraps(func)
- def wrapper(*args, **kwargs):
- level = _grad_increment_nesting()
- try:
- output, aux, grad_input = None, None, None
- # See NOTE [grad and vjp interaction with no_grad]
- with torch.enable_grad():
- args = _wrap_all_tensors(args, level)
- kwargs = _wrap_all_tensors(kwargs, level)
- diff_args = _slice_argnums(args, argnums, as_tuple=False)
- tree_map_(partial(_create_differentiable, level=level), diff_args)
- output = func(*args, **kwargs)
- if has_aux:
- if not (isinstance(output, tuple) and len(output) == 2):
- raise RuntimeError(
- "grad_and_value(f)(*args): output of function f should be a tuple: (output, aux) "
- "if has_aux is True"
- )
- output, aux = output
- if not isinstance(output, torch.Tensor):
- raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
- f'to return a Tensor, got {type(output)}')
- if output.dim() != 0:
- raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
- 'to return a scalar Tensor, got tensor with '
- f'{output.dim()} dims. Maybe you wanted to '
- 'use the vjp or jacrev APIs instead?')
- flat_diff_args, spec = tree_flatten(diff_args)
- # NB: need create_graph so that backward pass isn't run in no_grad mode
- flat_outputs = _as_tuple(output)
- flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
- grad_input = tree_unflatten(flat_grad_input, spec)
- grad_input = _undo_create_differentiable(grad_input, level)
- output = _undo_create_differentiable(output, level)
- if aux is not None:
- aux = _undo_create_differentiable(aux, level)
- if has_aux:
- return grad_input, (output, aux)
- return grad_input, output
- finally:
- _grad_decrement_nesting()
- return wrapper
- @exposed_in("torch.func")
- def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
- """``grad`` operator helps computing gradients of ``func`` with respect to the
- input(s) specified by ``argnums``. This operator can be nested to
- compute higher-order gradients.
- Args:
- func (Callable): A Python function that takes one or more arguments.
- Must return a single-element Tensor. If ``has_aux`` is ``True``,
- the function can return a tuple of a single-element Tensor and other auxiliary objects:
- ``(output, aux)``.
- argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to.
- ``argnums`` can be single integer or tuple of integers. Default: 0.
- has_aux (bool): Flag indicating that ``func`` returns a tensor and other
- auxiliary objects: ``(output, aux)``. Default: False.
- Returns:
- Function to compute gradients with respect to its inputs. By default, the output of
- the function is the gradient tensor(s) with respect to the first argument.
- If ``has_aux`` is ``True``, a tuple of gradients and output auxiliary objects
- is returned. If ``argnums`` is a tuple of integers, a tuple of output gradients with
- respect to each ``argnums`` value is returned.
- Example of using ``grad``:
- >>> # xdoctest: +SKIP
- >>> from torch.func import grad
- >>> x = torch.randn([])
- >>> cos_x = grad(lambda x: torch.sin(x))(x)
- >>> assert torch.allclose(cos_x, x.cos())
- >>>
- >>> # Second-order gradients
- >>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
- >>> assert torch.allclose(neg_sin_x, -x.sin())
- When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients:
- >>> # xdoctest: +SKIP
- >>> from torch.func import grad, vmap
- >>> batch_size, feature_size = 3, 5
- >>>
- >>> def model(weights, feature_vec):
- >>> # Very simple linear model with activation
- >>> assert feature_vec.dim() == 1
- >>> return feature_vec.dot(weights).relu()
- >>>
- >>> def compute_loss(weights, example, target):
- >>> y = model(weights, example)
- >>> return ((y - target) ** 2).mean() # MSELoss
- >>>
- >>> weights = torch.randn(feature_size, requires_grad=True)
- >>> examples = torch.randn(batch_size, feature_size)
- >>> targets = torch.randn(batch_size)
- >>> inputs = (weights, examples, targets)
- >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
- Example of using ``grad`` with ``has_aux`` and ``argnums``:
- >>> # xdoctest: +SKIP
- >>> from torch.func import grad
- >>> def my_loss_func(y, y_pred):
- >>> loss_per_sample = (0.5 * y_pred - y) ** 2
- >>> loss = loss_per_sample.mean()
- >>> return loss, (y_pred, loss_per_sample)
- >>>
- >>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
- >>> y_true = torch.rand(4)
- >>> y_preds = torch.rand(4, requires_grad=True)
- >>> out = fn(y_true, y_preds)
- >>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))
- .. note::
- Using PyTorch ``torch.no_grad`` together with ``grad``.
- Case 1: Using ``torch.no_grad`` inside a function:
- >>> # xdoctest: +SKIP
- >>> def f(x):
- >>> with torch.no_grad():
- >>> c = x ** 2
- >>> return x - c
- In this case, ``grad(f)(x)`` will respect the inner ``torch.no_grad``.
- Case 2: Using ``grad`` inside ``torch.no_grad`` context manager:
- >>> # xdoctest: +SKIP
- >>> with torch.no_grad():
- >>> grad(f)(x)
- In this case, ``grad`` will respect the inner ``torch.no_grad``, but not the
- outer one. This is because ``grad`` is a "function transform": its result
- should not depend on the result of a context manager outside of ``f``.
- """
- @wraps(func)
- def wrapper(*args, **kwargs):
- results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
- if has_aux:
- grad, (_, aux) = results
- return grad, aux
- grad, _ = results
- return grad
- return wrapper
- def _maybe_wrap_functional_tensor(maybe_tensor, level):
- if not isinstance(maybe_tensor, torch.Tensor):
- return maybe_tensor
- wrapped = _wrap_functional_tensor(maybe_tensor, level)
- _assert_wrapped_functional(maybe_tensor, wrapped)
- return wrapped
- def _wrap_all_tensors_to_functional(tensor_pytree, level):
- return tree_map(partial(_maybe_wrap_functional_tensor, level=level), tensor_pytree)
- def _maybe_unwrap_functional_tensor(maybe_tensor, *, reapply_views: bool):
- if not isinstance(maybe_tensor, torch.Tensor):
- return maybe_tensor
- if not torch._is_functional_tensor(maybe_tensor):
- # If it's not a functional tensor, just return it.
- # This can happen if we functionalize a fn that returns a global,
- # which was never wrapped properly.
- return maybe_tensor
- # Sync any pending updates on the output tensor
- torch._sync(maybe_tensor)
- return _unwrap_functional_tensor(maybe_tensor, reapply_views)
- def _unwrap_all_tensors_from_functional(tensor_pytree, *, reapply_views: bool):
- return tree_map(lambda t: _maybe_unwrap_functional_tensor(t, reapply_views=reapply_views), tensor_pytree)
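- # Editorial sketch, not part of this module: the helpers above follow a
- # common pytree pattern -- tree_map applies a per-leaf function to every leaf
- # of a nested structure, and the _maybe_* functions make it a no-op for
- # non-tensor leaves. A self-contained illustration of the same pattern:
- #
- #     import torch
- #     from functools import partial
- #     from torch.utils._pytree import tree_map
- #
- #     def _maybe_scale(leaf, factor):
- #         if not isinstance(leaf, torch.Tensor):
- #             return leaf  # non-tensor leaves pass through unchanged
- #         return leaf * factor
- #
- #     doubled = tree_map(partial(_maybe_scale, factor=2.0),
- #                        {"w": torch.ones(2), "tag": "x"})
- #     # doubled == {"w": tensor([2., 2.]), "tag": "x"}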
- @exposed_in("torch.func")
- def functionalize(func: Callable, *, remove: str = 'mutations') -> Callable:
- """
- functionalize is a transform that can be used to remove (intermediate)
- mutations and aliasing from a function, while preserving the function's
- semantics.
- ``functionalize(func)`` returns a new function with the same semantics
- as ``func``, but with all intermediate mutations removed.
- Every inplace operation performed on an intermediate tensor:
- ``intermediate.foo_()``
- gets replaced by its out-of-place equivalent:
- ``intermediate_updated = intermediate.foo()``.
- functionalize is useful for shipping a pytorch program off to
- backends or compilers that aren't able to easily represent
- mutations or aliasing operators.
- Args:
- func (Callable): A Python function that takes one or more arguments.
- remove (str): An optional string argument that takes either
- the value 'mutations' or 'mutations_and_views'.
- If 'mutations' is passed in then all mutating operators
- will be replaced with their non-mutating equivalents.
- If 'mutations_and_views' is passed in, then additionally, all aliasing
- operators will be replaced with their non-aliasing equivalents.
- Default: 'mutations'.
- Returns:
- Returns a new "functionalized" function. It takes the same inputs as
- ``func``, and has the same behavior, but any mutations
- (and optionally aliasing) performed on intermediate tensors
- in the function will be removed.
- functionalize will also remove mutations (and views) that were performed on function inputs.
- However, to preserve semantics, functionalize will "fix up" the mutations after
- the transform has finished running, by detecting if any tensor inputs "should have"
- been mutated, and copying the new data back to the inputs if necessary.
- Example::
- >>> # xdoctest: +SKIP
- >>> import torch
- >>> from torch.fx.experimental.proxy_tensor import make_fx
- >>> from torch.func import functionalize
- >>>
- >>> # A function that uses mutations and views, but only on intermediate tensors.
- >>> def f(a):
- ... b = a + 1
- ... c = b.view(-1)
- ... c.add_(1)
- ... return b
- ...
- >>> inpt = torch.randn(2)
- >>>
- >>> out1 = f(inpt)
- >>> out2 = functionalize(f)(inpt)
- >>>
- >>> # semantics are the same (outputs are equivalent)
- >>> print(torch.allclose(out1, out2))
- True
- >>>
- >>> f_traced = make_fx(f)(inpt)
- >>> f_no_mutations_traced = make_fx(functionalize(f))(inpt)
- >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
- >>>
- >>> print(f_traced.code)
- def forward(self, a_1):
- add = torch.ops.aten.add(a_1, 1); a_1 = None
- view = torch.ops.aten.view(add, [-1])
- add_ = torch.ops.aten.add_(view, 1); view = None
- return add
- >>> print(f_no_mutations_traced.code)
- def forward(self, a_1):
- add = torch.ops.aten.add(a_1, 1); a_1 = None
- view = torch.ops.aten.view(add, [-1]); add = None
- add_1 = torch.ops.aten.add(view, 1); view = None
- view_1 = torch.ops.aten.view(add_1, [2]); add_1 = None
- return view_1
- >>> print(f_no_mutations_and_views_traced.code)
- def forward(self, a_1):
- add = torch.ops.aten.add(a_1, 1); a_1 = None
- view_copy = torch.ops.aten.view_copy(add, [-1]); add = None
- add_1 = torch.ops.aten.add(view_copy, 1); view_copy = None
- view_copy_1 = torch.ops.aten.view_copy(add_1, [2]); add_1 = None
- return view_copy_1
- >>> # A function that mutates its input tensor
- >>> def f(a):
- ... b = a.view(-1)
- ... b.add_(1)
- ... return a
- ...
- >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
- >>> #
- >>> # All mutations and views have been removed,
- >>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input
- >>> # after the function has completed.
- >>> print(f_no_mutations_and_views_traced.code)
- def forward(self, a_1):
- view_copy = torch.ops.aten.view_copy(a_1, [-1])
- add = torch.ops.aten.add(view_copy, 1); view_copy = None
- view_copy_1 = torch.ops.aten.view_copy(add, [2]); add = None
- copy_ = torch.ops.aten.copy_(a_1, view_copy_1); a_1 = None
- return view_copy_1
- There are a few "failure modes" for functionalize that are worth calling out:
- (1) Like other torch.func transforms, `functionalize()` doesn't work with functions
- that directly use `.backward()`. The same is true for torch.autograd.grad.
- If you want to use autograd, you can compute gradients directly
- with `functionalize(grad(f))`.
- (2) Like other torch.func transforms, `functionalize()` doesn't work with global state.
- If you call `functionalize(f)` on a function that takes views / mutations of
- non-local state, functionalization will simply no-op and pass the view/mutation
- calls directly to the backend.
- One way to work around this is to ensure that any non-local state creation
- is wrapped into a larger function, which you then call functionalize on.
- (3) `resize_()` has some limitations: functionalize will only work on programs
- that use `resize_()` as long as the tensor being resized is not a view.
- (4) `as_strided()` has some limitations: functionalize will not work on
- `as_strided()` calls that result in tensors with overlapping memory.
- Finally, a helpful mental model for understanding functionalization is that
- most user PyTorch programs are written against the public torch API.
- When executed, torch operators are generally decomposed into
- our internal C++ "ATen" API.
- The logic for functionalization happens entirely at the level of ATen.
- Functionalization knows how to take every aliasing operator in ATen,
- and map it to its non-aliasing equivalent
- (e.g. ``tensor.view({-1})`` -> ``at::view_copy(tensor, {-1})``),
- and how to take every mutating operator in ATen,
- and map it to its non-mutating equivalent
- (e.g. ``tensor.add_(1)`` -> ``at::add(tensor, 1)``),
- while tracking aliases and mutations out-of-line to know when to fix things up.
- Information about which ATen operators are aliasing or mutating all comes from
- https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml.
- """
- if remove == 'mutations':
- reapply_views = True
- elif remove == 'mutations_and_views':
- reapply_views = False
- else:
- raise RuntimeError(
- f"functionalize(f, remove='mutations'): received invalid argument for remove={remove}."
- " Valid options are:\n"
- " remove='mutations': all inplace and out= operators will be removed from the program, and replaced"
- " with their out-of-place equivalents.\n"
- " remove='mutations_and_views': In addition to the above, all aliasing operators {view} will be"
- " replaced with their non-aliasing counterparts, {view}_copy.\n"
- )
- @doesnt_support_saved_tensors_hooks
- @wraps(func)
- def wrapped(*args, **kwargs):
- try:
- func_level = _func_increment_nesting(reapply_views)
- func_args = _wrap_all_tensors_to_functional(args, func_level)
- func_kwargs = _wrap_all_tensors_to_functional(kwargs, func_level)
- flattened_unwrapped_args, _ = tree_flatten(args)
- flattened_wrapped_args, _ = tree_flatten(func_args)
- flattened_unwrapped_kwargs, _ = tree_flatten(kwargs)
- flattened_wrapped_kwargs, _ = tree_flatten(func_kwargs)
- func_outputs = func(*func_args, **func_kwargs)
- outputs = _unwrap_all_tensors_from_functional(func_outputs, reapply_views=reapply_views)
- flat_outputs, func_out_spec = tree_flatten(outputs)
- for a in flattened_wrapped_args + flattened_wrapped_kwargs:
- if isinstance(a, torch.Tensor):
- # Call torch._sync() on the inputs, to ensure that any pending mutations have been applied.
- torch._sync(a)
- # And if any mutations were applied to the inputs, we need to propagate them back to the user.
- for unwrapped, wrapped in zip(flattened_unwrapped_args, flattened_wrapped_args):
- if isinstance(unwrapped, torch.Tensor) and isinstance(wrapped, torch.Tensor):
- _propagate_functional_input_mutation(unwrapped, wrapped)
- for unwrapped, wrapped in zip(flattened_unwrapped_kwargs, flattened_wrapped_kwargs):
- if isinstance(unwrapped, torch.Tensor) and isinstance(wrapped, torch.Tensor):
- _propagate_functional_input_mutation(unwrapped, wrapped)
- return outputs
- finally:
- _func_decrement_nesting()
- return wrapped
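- # Editorial sketch, not part of this module: the input-mutation fix-up that
- # `wrapped` performs above, as observed from the caller's side (assuming only
- # the public torch.func API):
- #
- #     import torch
- #     from torch.func import functionalize
- #
- #     def f(a):
- #         a.add_(1)      # in-place mutation of the input
- #         return a * 2
- #
- #     x = torch.zeros(3)
- #     y = functionalize(f)(x)
- #     # The in-place op is removed from the program that executes, but the
- #     # pending mutation is detected and copied back into the caller's
- #     # tensor: x is now tensor([1., 1., 1.]) and y is tensor([2., 2., 2.]).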
- @exposed_in("torch.func")
- def linearize(func: Callable, *primals) -> Tuple[Any, Callable]:
- '''
- Returns the value of ``func`` at ``primals`` and a linear approximation
- of ``func`` at ``primals``.
- Args:
- func (Callable): A Python function that takes one or more arguments.
- primals (Tensors): Positional arguments to ``func`` that must all be
- Tensors. These are the values at which the function is linearly approximated.
- Returns:
- Returns a ``(output, jvp_fn)`` tuple containing the output of ``func``
- applied to ``primals`` and a function that computes the jvp of
- ``func`` evaluated at ``primals``.
- linearize is useful if jvp is to be computed multiple times at ``primals``. However,
- to achieve this, linearize saves intermediate computation and has higher memory requirements
- than directly applying ``jvp``. So, if all the ``tangents`` are known, it may be more efficient
- to compute ``vmap(jvp)`` instead of using ``linearize``.
- .. note::
- linearize evaluates ``func`` twice. If this is a problem for your use case,
- please file an issue requesting an implementation with a single evaluation.
- Example::
- >>> import torch
- >>> from torch.func import linearize
- >>> def fn(x):
- ... return x.sin()
- ...
- >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
- >>> jvp_fn(torch.ones(3, 3))
- tensor([[1., 1., 1.],
- [1., 1., 1.],
- [1., 1., 1.]])
- >>>
- '''
- # Note: We evaluate `func` twice: once to return
- # the output, and once more while tracing the jvp graph.
- # If this becomes a bottleneck, we should update
- # make_fx such that it also returns the output.
- output = func(*primals)
- _, output_spec = tree_flatten(output)
- flat_primals, primals_argspec = tree_flatten(primals)
- # Placeholder tangents, used only to trace the jvp graph. new_empty(()).expand_as(p)
- # matches each primal's shape/dtype/device without allocating full storage.
- flat_tangents = tuple(p.new_empty(()).expand_as(p) for p in flat_primals)
- # function to trace
- def trace_fn(flat_tangents):
- with fwAD.dual_level():
- flat_duals = tuple(fwAD.make_dual(p, t) for p, t in zip(flat_primals, flat_tangents))
- duals = tree_unflatten(flat_duals, primals_argspec)
- output = func(*duals)
- tangents = tree_map_only(torch.Tensor, lambda t: fwAD.unpack_dual(t)[1], output)
- return tangents
- jvp_graph = make_fx(trace_fn)(flat_tangents)
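- # Const-fold the subgraphs that depend only on the captured primals (the
- # tangents are the graph's only placeholders), so that repeated jvp_fn calls
- # reuse that work. This is the saved intermediate computation mentioned in
- # the docstring above.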
- const_folded_jvp_graph = const_fold.split_const_subgraphs(jvp_graph)
- # Hold only the metadata of the primals, for the shape/device/dtype checks below.
- flat_primals_shape = tuple(p.shape for p in flat_primals)
- flat_primals_device = tuple(p.device for p in flat_primals)
- flat_primals_dtype = tuple(p.dtype for p in flat_primals)
- def forward_ad_checks(flat_tangents):
- for idx, t in enumerate(flat_tangents):
- if t.shape != flat_primals_shape[idx]:
- msg = (f"tangent:{idx} with shape {t.shape} in flattened "
- f"pytree doesn't match the shape {flat_primals_shape[idx]} "
- "of the corresponding primal.")
- raise RuntimeError(msg)
- if t.device != flat_primals_device[idx]:
- msg = (f"tangent:{idx} with device {t.device} in flattened "
- f"pytree doesn't match the device {flat_primals_device[idx]} "
- "of the corresponding primal.")
- raise RuntimeError(msg)
- if t.dtype != flat_primals_dtype[idx]:
- msg = (f"tangent:{idx} with dtype {t.dtype} in flattened "
- f"pytree doesn't match the dtype {flat_primals_dtype[idx]} "
- "of the corresponding primal.")
- raise RuntimeError(msg)
- # jvp_fn: the callable we return. It validates the tangents' argspec,
- # calls the const-folded fx graph, and unflattens the graph's output.
- def jvp_fn(*tangents):
- flat_tangents, tangent_argspec = tree_flatten(tangents)
- if tangent_argspec != primals_argspec:
- raise RuntimeError(f"Expected the tangents {tangent_argspec} to have "
- f"the same argspec as the primals {primals_argspec}")
- forward_ad_checks(flat_tangents)
- flat_output = const_folded_jvp_graph(*flat_tangents)
- # The const-folded graph returns a flat output,
- # so unflatten it into the structure of func's output.
- return tree_unflatten(flat_output, output_spec)
- return output, jvp_fn
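- # Editorial sketch, not part of this module: the intended usage pattern is to
- # linearize once and then apply jvp_fn to many tangents, whereas torch.func.jvp
- # would re-linearize `fn` on every call. A minimal consistency check:
- #
- #     import torch
- #     from torch.func import jvp, linearize
- #
- #     def fn(x):
- #         return x.sin()
- #
- #     primal = torch.randn(3)
- #     _, jvp_fn = linearize(fn, primal)
- #     for _ in range(5):
- #         tangent = torch.randn(3)
- #         expected = jvp(fn, (primal,), (tangent,))[1]
- #         assert torch.allclose(jvp_fn(tangent), expected)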