import torch
import torch._C as _C
from torch._C import _functions
import torch._functorch as _functorch
import torch.utils.hooks as hooks
import functools
import warnings
from collections import OrderedDict
from typing import Any, List, Optional, Tuple
from torch._functorch.autograd_function import custom_function_call

__all__ = ["FunctionCtx", "BackwardCFunction", "FunctionMeta", "Function", "once_differentiable", "traceable",
           "InplaceFunction", "NestedIOFunction"]


# Formerly known as: _ContextMethodMixin
class FunctionCtx:

    def save_for_backward(self, *tensors: torch.Tensor):
        r"""Saves given tensors for a future call to :func:`~Function.backward`.

        ``save_for_backward`` should be called at most once, only from inside the
        :func:`forward` method, and only with tensors.

        All tensors intended to be used in the backward pass should be saved
        with ``save_for_backward`` (as opposed to directly on ``ctx``) to prevent
        incorrect gradients and memory leaks, and enable the application of saved
        tensor hooks. See :class:`torch.autograd.graph.saved_tensors_hooks`.

        Note that if intermediary tensors, tensors that are neither inputs
        nor outputs of :func:`forward`, are saved for backward, your custom Function
        may not support double backward.
        Custom Functions that do not support double backward should decorate their
        :func:`backward` method with ``@once_differentiable`` so that performing
        double backward raises an error. If you'd like to support double backward,
        you can either recompute intermediaries based on the inputs during backward
        or return the intermediaries as the outputs of the custom Function. See the
        `double backward tutorial <https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html>`_
        for more details.

        In :func:`backward`, saved tensors can be accessed through the :attr:`saved_tensors`
        attribute. Before returning them to the user, a check is made to ensure
        they weren't used in any in-place operation that modified their content.

        Arguments can also be ``None``. This is a no-op.

        See :ref:`extending-autograd` for more details on how to use this method.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
            >>> class Func(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
            >>>         w = x * z
            >>>         out = x * y + y * z + w * y
            >>>         ctx.save_for_backward(x, y, w, out)
            >>>         ctx.z = z  # z is not a tensor
            >>>         return out
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, grad_out):
            >>>         x, y, w, out = ctx.saved_tensors
            >>>         z = ctx.z
            >>>         gx = grad_out * (y + y * z)
            >>>         gy = grad_out * (x + z + w)
            >>>         gz = None
            >>>         return gx, gy, gz
            >>>
            >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
            >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
            >>> c = 4
            >>> d = Func.apply(a, b, c)
        """
        self.to_save = tensors

    def save_for_forward(self, *tensors: torch.Tensor):
        r"""Saves given tensors for a future call to :func:`~Function.jvp`.

        ``save_for_forward`` should be called at most once, only from inside the
        :func:`forward` method, and only with tensors.

        In :func:`jvp`, saved objects can be accessed through the :attr:`saved_tensors`
        attribute.

        Arguments can also be ``None``. This is a no-op.

        See :ref:`extending-autograd` for more details on how to use this method.

        Example::
            >>> # xdoctest: +SKIP
            >>> class Func(torch.autograd.Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
            >>>         ctx.save_for_backward(x, y)
            >>>         ctx.save_for_forward(x, y)
            >>>         ctx.z = z
            >>>         return x * y * z
            >>>
            >>>     @staticmethod
            >>>     def jvp(ctx, x_t, y_t, _):
            >>>         x, y = ctx.saved_tensors
            >>>         z = ctx.z
            >>>         return z * (y * x_t + x * y_t)
            >>>
            >>>     @staticmethod
            >>>     def vjp(ctx, grad_out):
            >>>         x, y = ctx.saved_tensors
            >>>         z = ctx.z
            >>>         return z * grad_out * y, z * grad_out * x, None
            >>>
            >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
            >>> t = torch.tensor(1., dtype=torch.double)
            >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
            >>> c = 4
            >>>
            >>> with fwAD.dual_level():
            >>>     a_dual = fwAD.make_dual(a, t)
            >>>     d = Func.apply(a_dual, b, c)
        """
        for tensor in tensors:
            assert isinstance(tensor, torch.Tensor) or tensor is None, (
                "save_for_forward expects all arguments to be tensors; you should "
                "save non-tensors as attributes on ctx.")

        self.saved_for_forward = tensors

    def mark_dirty(self, *args: torch.Tensor):
        r"""Marks given tensors as modified in an in-place operation.

        **This should be called at most once, only from inside the**
        :func:`forward` **method, and all arguments should be inputs.**

        Every tensor that's been modified in-place in a call to :func:`forward`
        should be given to this function, to ensure correctness of our checks.
        It doesn't matter whether the function is called before or after
        modification.

        Examples::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
            >>> class Inplace(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x):
            >>>         x_npy = x.numpy()  # x_npy shares storage with x
            >>>         x_npy += 1
            >>>         ctx.mark_dirty(x)
            >>>         return x
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, grad_output):
            >>>         return grad_output
            >>>
            >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
            >>> b = a * a
            >>> Inplace.apply(a)  # This would lead to wrong gradients!
            >>>                   # but the engine would not know unless we mark_dirty
            >>> # xdoctest: +SKIP
            >>> b.backward()  # RuntimeError: one of the variables needed for gradient
            >>>               # computation has been modified by an inplace operation
        """
        self.dirty_tensors = args

    def mark_shared_storage(self, *pairs):
        warnings.warn(
            'mark_shared_storage is deprecated. '
            'Tensors with shared storages are automatically tracked. Note '
            'that calls to `set_()` are not tracked')

    def mark_non_differentiable(self, *args: torch.Tensor):
        r"""Marks outputs as non-differentiable.

        **This should be called at most once, only from inside the**
        :func:`forward` **method, and all arguments should be tensor outputs.**

        This will mark outputs as not requiring gradients, increasing the
        efficiency of backward computation. You still need to accept a gradient
        for each output in :meth:`~Function.backward`, but it's always going to
        be a zero tensor with the same shape as the corresponding output.

        This is used e.g. for indices returned from a sort. See example::
            >>> class Func(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x):
            >>>         sorted, idx = x.sort()
            >>>         ctx.mark_non_differentiable(idx)
            >>>         ctx.save_for_backward(x, idx)
            >>>         return sorted, idx
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, g1, g2):  # still need to accept g2
            >>>         x, idx = ctx.saved_tensors
            >>>         grad_input = torch.zeros_like(x)
            >>>         grad_input.index_add_(0, idx, g1)
            >>>         return grad_input
        """
        self.non_differentiable = args

    def set_materialize_grads(self, value: bool):
        r"""Sets whether to materialize grad tensors. Default is ``True``.

        **This should be called only from inside the** :func:`forward` **method.**

        If ``True``, undefined grad tensors will be expanded to tensors full of zeros
        prior to calling the :func:`backward` and :func:`jvp` methods.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
            >>> class SimpleFunc(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x):
            >>>         return x.clone(), x.clone()
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, g1, g2):
            >>>         return g1 + g2  # No check for None necessary
            >>>
            >>> # We modify SimpleFunc to handle non-materialized grad outputs
            >>> class Func(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x):
            >>>         ctx.set_materialize_grads(False)
            >>>         ctx.save_for_backward(x)
            >>>         return x.clone(), x.clone()
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, g1, g2):
            >>>         x, = ctx.saved_tensors
            >>>         grad_input = torch.zeros_like(x)
            >>>         if g1 is not None:  # We must check for None now
            >>>             grad_input += g1
            >>>         if g2 is not None:
            >>>             grad_input += g2
            >>>         return grad_input
            >>>
            >>> a = torch.tensor(1., requires_grad=True)
            >>> b, _ = Func.apply(a)  # induces g2 to be undefined
        """
        self.materialize_grads = value


# DO NOT USE: This is only defined to be able to load old serialized models
_ContextMethodMixin = FunctionCtx


class _HookMixin:

    @staticmethod
    def _register_hook(backward_hooks, hook):
        if backward_hooks is None:
            backward_hooks = OrderedDict()
        handle = hooks.RemovableHandle(backward_hooks)
        backward_hooks[handle.id] = hook
        return backward_hooks, handle


class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
    def apply(self, *args):
        # _forward_cls is defined by derived class
        # The user should define either backward or vjp but never both.
        backward_fn = self._forward_cls.backward  # type: ignore[attr-defined]
        vjp_fn = self._forward_cls.vjp  # type: ignore[attr-defined]
        if backward_fn is not Function.backward and vjp_fn is not Function.vjp:
            raise RuntimeError("Implementing both 'backward' and 'vjp' for a custom "
                               "Function is not allowed. You should only implement one "
                               "of them.")
        user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
        return user_fn(self, *args)

    def apply_jvp(self, *args):
        # _forward_cls is defined by derived class
        return self._forward_cls.jvp(self, *args)  # type: ignore[attr-defined]


class FunctionMeta(type):
    """Function metaclass.

    This metaclass sets up the following properties:
        _backward_cls: The Function class corresponding to the differentiated
            version of this function (which is generated on the fly by this
            metaclass).
    """
    def __init__(cls, name, bases, attrs):
        backward_fn = type(name + 'Backward', (BackwardCFunction,), {'_forward_cls': cls})
        cls._backward_cls = backward_fn

        super(FunctionMeta, cls).__init__(name, bases, attrs)


class _SingleLevelFunction(_C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta):
    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        r"""
        This function is to be overridden by all subclasses. There are two ways
        to define forward:

        Usage 1 (Combined forward and ctx)::

            @staticmethod
            def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
                pass

        - It must accept a context ctx as the first argument, followed by any
          number of arguments (tensors or other types).
        - See :ref:`combining-forward-context` for more details

        Usage 2 (Separate forward and ctx)::

            @staticmethod
            def forward(*args: Any, **kwargs: Any) -> Any:
                pass

            @staticmethod
            def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
                pass

        - The forward no longer accepts a ctx argument.
        - Instead, you must also override the :meth:`torch.autograd.Function.setup_context`
          staticmethod to handle setting up the ``ctx`` object.
          ``output`` is the output of the forward, ``inputs`` are a Tuple of inputs
          to the forward.
        - See :ref:`extending-autograd` for more details

        The context can be used to store arbitrary data that can be then
        retrieved during the backward pass. Tensors should not be stored
        directly on `ctx` (though this is not currently enforced for
        backward compatibility). Instead, tensors should be saved either with
        :func:`ctx.save_for_backward` if they are intended to be used in
        ``backward`` (equivalently, ``vjp``) or :func:`ctx.save_for_forward`
        if they are intended to be used in ``jvp``.
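
        As a minimal illustrative sketch of Usage 2 (``Mul`` is a made-up
        example, not a PyTorch API), the separate ``forward``/``setup_context``
        style looks like::

            >>> # xdoctest: +SKIP
            >>> class Mul(torch.autograd.Function):
            >>>     @staticmethod
            >>>     def forward(x, y):
            >>>         return x * y
            >>>
            >>>     @staticmethod
            >>>     def setup_context(ctx, inputs, output):
            >>>         x, y = inputs
            >>>         ctx.save_for_backward(x, y)
            >>>
            >>>     @staticmethod
            >>>     def backward(ctx, grad_out):
            >>>         x, y = ctx.saved_tensors
            >>>         return grad_out * y, grad_out * x
            >>>
            >>> a = torch.tensor(2., requires_grad=True)
            >>> b = torch.tensor(3., requires_grad=True)
            >>> Mul.apply(a, b).backward()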
- """
- raise NotImplementedError("You must implement the forward function for custom"
- " autograd.Function.")
- @staticmethod
- def setup_context(ctx: Any, inputs: Tuple[Any], output: Any) -> Any:
- r"""There are two ways to define the forward pass of an autograd.Function.
- Either:
- 1. Override forward with the signature forward(ctx, *args, **kwargs).
- ``setup_context`` is not overridden. Setting up the ctx for backward
- happens inside the ``forward``.
- 2. Override forward with the signature forward(*args, **kwargs) and
- override ``setup_context``. Setting up the ctx for backward happens
- inside ``setup_context`` (as opposed to inside the ``forward``)
- See :meth:`torch.autograd.Function.forward` and :ref:`extending-autograd` for more details.
- """
- raise NotImplementedError("setup_context is not implemented.")

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        r"""Defines a formula for differentiating the operation with backward mode
        automatic differentiation (alias to the vjp function).

        This function is to be overridden by all subclasses.

        It must accept a context :attr:`ctx` as the first argument, followed by
        as many outputs as :func:`forward` returned (``None`` will be passed in
        for non-tensor outputs of the forward function),
        and it should return as many tensors as there were inputs to
        :func:`forward`. Each argument is the gradient w.r.t. the given output,
        and each returned value should be the gradient w.r.t. the
        corresponding input. If an input is not a Tensor or is a Tensor not
        requiring grads, you can just pass None as a gradient for that input.

        The context can be used to retrieve tensors saved during the forward
        pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple
        of booleans representing whether each input needs gradient. E.g.,
        :func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the
        first input to :func:`forward` needs gradient computed w.r.t. the
        output.
        """
        raise NotImplementedError("You must implement either the backward or vjp method for "
                                  "your custom autograd.Function to use it with backward "
                                  "mode AD.")

    # vjp and backward are aliases of each other
    vjp = backward

    @staticmethod
    def jvp(ctx: Any, *grad_inputs: Any) -> Any:
        r"""Defines a formula for differentiating the operation with forward mode
        automatic differentiation.

        This function is to be overridden by all subclasses.

        It must accept a context :attr:`ctx` as the first argument, followed by
        as many inputs as :func:`forward` received (``None`` will be passed in
        for non-tensor inputs of the forward function),
        and it should return as many tensors as there were outputs to
        :func:`forward`. Each argument is the gradient w.r.t. the given input,
        and each returned value should be the gradient w.r.t. the
        corresponding output. If an output is not a Tensor or the function is not
        differentiable with respect to that output, you can just pass None as a
        gradient for that output.

        You can use the :attr:`ctx` object to pass any value from the forward to this
        function.
        """
        raise NotImplementedError("You must implement the jvp function for custom "
                                  "autograd.Function to use it with forward mode AD.")


class Function(_SingleLevelFunction):
    r"""Base class to create custom `autograd.Function`

    To create a custom `autograd.Function`, subclass this class and implement
    the :meth:`forward` and :meth:`backward` static methods. Then, to use your custom
    op in the forward pass, call the class method ``apply``. Do not call
    :meth:`forward` directly.

    To ensure correctness and best performance, make sure you are calling the
    correct methods on ``ctx`` and validating your backward function using
    :func:`torch.autograd.gradcheck`.

    See :ref:`extending-autograd` for more details on how to use this class.

    Examples::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> class Exp(Function):
        >>>     @staticmethod
        >>>     def forward(ctx, i):
        >>>         result = i.exp()
        >>>         ctx.save_for_backward(result)
        >>>         return result
        >>>
        >>>     @staticmethod
        >>>     def backward(ctx, grad_output):
        >>>         result, = ctx.saved_tensors
        >>>         return grad_output * result
        >>>
        >>> # Use it by calling the apply method:
        >>> # xdoctest: +SKIP
        >>> output = Exp.apply(input)
    """
    def __init__(self, *args, **kwargs):
        cls = self.__class__
        warnings.warn(f"{cls} should not be instantiated. Methods on autograd functions "
                      "are all static, so you should invoke them on the class itself. "
                      "Instantiating an autograd function will raise an "
                      "error in a future version of PyTorch.", DeprecationWarning)

    def __call__(self, *args, **kwargs):
        raise RuntimeError(
            "Legacy autograd function with non-static forward method is deprecated. "
            "Please use new-style autograd function with static forward method. "
            "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)")

    # for the tracer
    is_traceable = False

    """
    Bool that specifies if PyTorch should attempt to autogenerate
    :func:`torch.vmap` support for this autograd.Function. You may set this to
    True only if this autograd.Function's forward, backward, and jvp (if they
    exist) are written using PyTorch operations; otherwise, please override
    :meth:`torch.autograd.Function.vmap` to add support for :func:`torch.vmap`.

    Please see :ref:`func-autograd-function` for more details.
    """
    generate_vmap_rule = False

    @staticmethod
    def vmap(info, in_dims, *args):
        r"""Defines a rule for the behavior of this autograd.Function underneath
        :func:`torch.vmap`. For a :func:`torch.autograd.Function` to support
        :func:`torch.vmap`, you must either override this staticmethod, or set
        ``generate_vmap_rule`` to ``True`` (you may not do both).

        If you choose to override this staticmethod: it must accept

        - an ``info`` object as the first argument. ``info.batch_size``
          specifies the size of the dimension being vmapped over,
          while ``info.randomness`` is the randomness option passed to
          :func:`torch.vmap`.
        - an ``in_dims`` tuple as the second argument.
          For each arg in ``args``, ``in_dims`` has a corresponding
          ``Optional[int]``. It is ``None`` if the arg is not a Tensor or if
          the arg is not being vmapped over, otherwise, it is an integer
          specifying what dimension of the Tensor is being vmapped over.
        - ``*args``, which is the same as the args to :meth:`~Function.forward`.

        The return of the vmap staticmethod is a tuple of ``(output, out_dims)``.
        Similar to ``in_dims``, ``out_dims`` should be of the same structure as
        ``output`` and contain one ``out_dim`` per output that specifies whether
        the output has the vmapped dimension and at what index it is.

        Please see :ref:`func-autograd-function` for more details.
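
        As a minimal illustrative sketch (``MyExp`` is a made-up example, not a
        PyTorch API), an elementwise op can reuse its own ``apply`` as the vmap
        rule and propagate the batch dimension unchanged::

            >>> # xdoctest: +SKIP
            >>> class MyExp(torch.autograd.Function):
            >>>     @staticmethod
            >>>     def forward(x):
            >>>         return x.exp()
            >>>
            >>>     @staticmethod
            >>>     def setup_context(ctx, inputs, output):
            >>>         ctx.save_for_backward(output)
            >>>
            >>>     @staticmethod
            >>>     def backward(ctx, grad_out):
            >>>         result, = ctx.saved_tensors
            >>>         return grad_out * result
            >>>
            >>>     @staticmethod
            >>>     def vmap(info, in_dims, x):
            >>>         x_bdim, = in_dims
            >>>         # exp is elementwise, so the batched input can be passed
            >>>         # through as-is and the batch dim stays where it was.
            >>>         return MyExp.apply(x), x_bdim
            >>>
            >>> x = torch.randn(3, 4)
            >>> y = torch.vmap(MyExp.apply)(x)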
- """
- raise NotImplementedError(
- "To use autograd.Function with vmap, you must either override the "
- "vmap staticmethod or set generate_vmap_rule=True.")
- @classmethod
- def apply(cls, *args, **kwargs):
- if not torch._C._are_functorch_transforms_active():
- # See NOTE: [functorch vjp and autograd interaction]
- args = _functorch.utils.unwrap_dead_wrappers(args)
- return super().apply(*args, **kwargs) # type: ignore[misc]
- if cls.setup_context == _SingleLevelFunction.setup_context:
- raise RuntimeError(
- 'In order to use an autograd.Function with functorch transforms '
- '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
- 'staticmethod. For more details, please see '
- 'https://pytorch.org/docs/master/notes/extending.func.html')
- return custom_function_call(cls, *args, **kwargs)


def once_differentiable(fn):

    @functools.wraps(fn)
    def wrapper(ctx, *args):
        with torch.no_grad():
            outputs = fn(ctx, *args)

        if not torch.is_grad_enabled():
            return outputs

        # If any of the inputs have requires_grad=True, we force the outputs
        # to have requires_grad=True but point to a grad_fn which throws an
        # error message during (double) back-propagation.
        # XXX: this is only an approximation of requires_grad - there's no way
        # to figure out if fn didn't use ctx.saved_tensors and as a result
        # some Tensors might require grad, even if no args do.
        # Unfortunately, this leads to unexpected error messages ("no nodes
        # require computing gradients"), but I don't have a better idea.
        # These functions would raise an error in backward anyway.
        requires_grad = any(isinstance(arg, torch.Tensor) and arg.requires_grad
                            for arg in args)
        if not requires_grad:
            return outputs

        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        err_fn = _functions.DelayedError(
            b"trying to differentiate twice a function that was marked "
            b"with @once_differentiable", len(outputs))

        # Create aliases of each output that has requires_grad=True. We need
        # at least one of the inputs to err_fn to require grad so that the
        # output will have a grad_fn.
        def fake_requires_grad(var):
            if var is not None:
                var = var.detach()
                var.requires_grad = True
            return var

        return err_fn(*[fake_requires_grad(v) for v in outputs])
    return wrapper
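

# Illustrative usage sketch (``Square`` is a made-up example, not part of this
# module): decorating ``backward`` with @once_differentiable runs it under
# no_grad and makes a second differentiation raise instead of silently
# producing wrong gradients.
#
#     class Square(Function):
#         @staticmethod
#         def forward(ctx, x):
#             ctx.save_for_backward(x)
#             return x ** 2
#
#         @staticmethod
#         @once_differentiable
#         def backward(ctx, grad_out):
#             x, = ctx.saved_tensors
#             return 2 * x * grad_out
#
#     x = torch.tensor(1., requires_grad=True)
#     g, = torch.autograd.grad(Square.apply(x), x, create_graph=True)
#     # torch.autograd.grad(g, x)  # raises: marked with @once_differentiable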


def traceable(fn_cls):
    r"""Marks Function as traceable for the JIT.

    Traceable functions have additional restrictions - they can't pass any
    data-dependent values to backward (e.g. Prod passes the output, which makes
    it non-traceable), and their backward should be implemented entirely in terms
    of operations on autograd Tensors in all cases.

    DON'T USE THIS DECORATOR. IT IS FOR INTERNAL USE ONLY AND SHOULD BE HANDLED WITH
    CARE (or can give incorrect results otherwise).
    """
    fn_cls.is_traceable = True
    return fn_cls


class InplaceFunction(Function):

    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace


def _nested_map(condition, fn, condition_msg=None):
    def _map(obj):
        if condition(obj):
            return fn(obj)
        elif obj is None:
            return None
        elif isinstance(obj, (list, tuple)):
            mapped = (_map(x) for x in obj)
            if hasattr(obj, '_fields'):
                # obj is namedtuple
                return type(obj)(*mapped)
            return type(obj)(mapped)
        elif isinstance(obj, dict):
            return {x: _map(obj[x]) for x in obj}
        else:
            raise ValueError("Auto nesting doesn't know how to process "
                             "an input object of type " + torch.typename(obj) +
                             (". Accepted types: " + condition_msg +
                              ", or lists/tuples of them"
                              if condition_msg else ""))

    return _map


def _jit_unwrap_structured(obj):
    if hasattr(obj, "_jit_unwrap"):
        return obj._jit_unwrap()
    return obj


def _iter_filter(condition, allow_unknown=False, condition_msg=None,
                 conversion=None):
    def _iter(obj):
        if conversion is not None:
            obj = conversion(obj)
        if condition(obj):
            yield obj
        elif obj is None:
            return
        elif isinstance(obj, (list, tuple)):
            for o in obj:
                yield from _iter(o)
        elif isinstance(obj, dict):
            # We only accept primitive key types, so we needn't inspect them
            for o in obj.values():
                yield from _iter(o)
        elif allow_unknown:
            yield obj
        else:
            raise ValueError("Auto nesting doesn't know how to process "
                             "an input object of type " + torch.typename(obj) +
                             (". Accepted types: " + condition_msg +
                              ", or lists/tuples of them"
                              if condition_msg else ""))

    return _iter


def _unflatten(input, proto):
    # unflatten a list or tuple input into a nested list/tuple structure
    # specified by proto
    def unflatten_helper(input, proto):
        res: List[Optional[torch.Tensor]] = []
        if hasattr(proto, "_jit_wrap"):
            return proto._jit_wrap(input)
        if not isinstance(proto, (list, tuple)):
            return input[0], input[1:]
        for e in proto:
            if e is None:
                res.append(e)
            else:
                res_e, input = unflatten_helper(input, e)
                res.append(res_e)
        return type(proto)(res), input

    return unflatten_helper(input, proto)[0]
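

# Illustrative sketch of what _unflatten does (t1, t2, g1, g2 are hypothetical
# tensors, not names from this module): given flat values and a prototype
# structure, it rebuilds the nesting, consuming one flat value per non-None,
# non-sequence leaf of the prototype, e.g.
#     _unflatten((g1, g2), (t1, [t2, None]))  ->  (g1, [g2, None])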


_iter_jit_values = _iter_filter(lambda o: o is None or isinstance(o, torch._C.Value),
                                condition_msg="jit's Values or None")
_iter_tensors = _iter_filter(lambda x: isinstance(x, torch.Tensor), condition_msg="Tensors",
                             conversion=_jit_unwrap_structured)
_iter_tensors_permissive = _iter_filter(lambda x: isinstance(x, torch.Tensor),
                                        allow_unknown=True,
                                        condition_msg="Tensors (permissive)")
_iter_None_tensors = _iter_filter(lambda o: o is None or isinstance(o, torch.Tensor),
                                  condition_msg="Tensors or None")
_map_tensor_data = _nested_map(lambda x: isinstance(x, torch.Tensor), lambda o: o.data,
                               condition_msg="Tensors")


class NestedIOFunction(Function):
    # The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
    # superclass (Function) but are instance methods here, which mypy reports as incompatible.

    def _do_forward(self, *input):
        self._nested_input = input
        flat_input = tuple(_iter_tensors(input))
        flat_output = super()._do_forward(*flat_input)  # type: ignore[misc]
        nested_output = self._nested_output
        nested_tensors = _unflatten(flat_output, self._nested_output)
        return nested_tensors

    def _do_backward(self, gradients, retain_variables):
        self.retain_variables = retain_variables
        result = super()._do_backward(gradients, retain_variables)  # type: ignore[misc]
        if not retain_variables:
            del self._nested_output
            del self._to_save_nested
        return result

    def backward(self, *gradients: Any) -> Any:  # type: ignore[override]
        nested_gradients = _unflatten(gradients, self._nested_output)
        result = self.backward_extended(*nested_gradients)  # type: ignore[func-returns-value]
        return tuple(_iter_None_tensors(result))

    __call__ = _do_forward

    def forward(self, *args: Any) -> Any:  # type: ignore[override]
        nested_tensors = _map_tensor_data(self._nested_input)
        result = self.forward_extended(*nested_tensors)  # type: ignore[func-returns-value]
        del self._nested_input
        self._nested_output = result
        return tuple(_iter_tensors(result))

    def save_for_backward(self, *args: Any) -> None:
        self.to_save = tuple(_iter_tensors(args))
        self._to_save_nested = args

    @property
    def saved_tensors(self):
        flat_tensors = super().saved_tensors  # type: ignore[misc]
        return _unflatten(flat_tensors, self._to_save_nested)

    def mark_dirty(self, *args: Any, **kwargs: Any) -> None:
        self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))

    def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None:
        self.non_differentiable = tuple(_iter_tensors((args, kwargs)))

    def forward_extended(self, *input: Any) -> None:
        raise NotImplementedError

    def backward_extended(self, *grad_output: Any) -> None:
        raise NotImplementedError