import torch
import torch._C as _C
from torch._C import _functions
import torch._functorch as _functorch
import torch.utils.hooks as hooks
import functools
import warnings
from collections import OrderedDict
from typing import Any, List, Optional, Tuple
from torch._functorch.autograd_function import custom_function_call

__all__ = ["FunctionCtx", "BackwardCFunction", "FunctionMeta", "Function", "once_differentiable", "traceable",
           "InplaceFunction", "NestedIOFunction"]

# Formerly known as: _ContextMethodMixin
class FunctionCtx:

    def save_for_backward(self, *tensors: torch.Tensor):
        r"""Saves given tensors for a future call to :func:`~Function.backward`.

        ``save_for_backward`` should be called at most once, only from inside the
        :func:`forward` method, and only with tensors.

        All tensors intended to be used in the backward pass should be saved
        with ``save_for_backward`` (as opposed to directly on ``ctx``) to prevent
        incorrect gradients and memory leaks, and enable the application of saved
        tensor hooks. See :class:`torch.autograd.graph.saved_tensors_hooks`.

        Note that if intermediary tensors, tensors that are neither inputs
        nor outputs of :func:`forward`, are saved for backward, your custom Function
        may not support double backward.
        Custom Functions that do not support double backward should decorate their
        :func:`backward` method with ``@once_differentiable`` so that performing
        double backward raises an error. If you'd like to support double backward,
        you can either recompute intermediaries based on the inputs during backward
        or return the intermediaries as the outputs of the custom Function. See the
        `double backward tutorial <https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html>`_
        for more details.

        In :func:`backward`, saved tensors can be accessed through the :attr:`saved_tensors`
        attribute. Before returning them to the user, a check is made to ensure
        they weren't used in any in-place operation that modified their content.

        Arguments can also be ``None``. This is a no-op.

        See :ref:`extending-autograd` for more details on how to use this method.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
            >>> class Func(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
            >>>         w = x * z
            >>>         out = x * y + y * z + w * y
            >>>         ctx.save_for_backward(x, y, w, out)
            >>>         ctx.z = z  # z is not a tensor
            >>>         return out
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, grad_out):
            >>>         x, y, w, out = ctx.saved_tensors
            >>>         z = ctx.z
            >>>         gx = grad_out * (y + y * z)
            >>>         gy = grad_out * (x + z + w)
            >>>         gz = None
            >>>         return gx, gy, gz
            >>>
            >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
            >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
            >>> c = 4
            >>> d = Func.apply(a, b, c)
        """
        self.to_save = tensors

    def save_for_forward(self, *tensors: torch.Tensor):
        r"""Saves given tensors for a future call to :func:`~Function.jvp`.

        ``save_for_forward`` should be called at most once, only from inside the
        :func:`forward` method, and only with tensors.

        In :func:`jvp`, saved objects can be accessed through the :attr:`saved_tensors`
        attribute.

        Arguments can also be ``None``. This is a no-op.

        See :ref:`extending-autograd` for more details on how to use this method.

        Example::
            >>> # xdoctest: +SKIP
            >>> class Func(torch.autograd.Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
            >>>         ctx.save_for_backward(x, y)
            >>>         ctx.save_for_forward(x, y)
            >>>         ctx.z = z
            >>>         return x * y * z
            >>>
            >>>     @staticmethod
            >>>     def jvp(ctx, x_t, y_t, _):
            >>>         x, y = ctx.saved_tensors
            >>>         z = ctx.z
            >>>         return z * (y * x_t + x * y_t)
            >>>
            >>>     @staticmethod
            >>>     def vjp(ctx, grad_out):
            >>>         x, y = ctx.saved_tensors
            >>>         z = ctx.z
            >>>         return z * grad_out * y, z * grad_out * x, None
            >>>
            >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
            >>> t = torch.tensor(1., dtype=torch.double)
            >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
            >>> c = 4
            >>>
            >>> with fwAD.dual_level():
            >>>     a_dual = fwAD.make_dual(a, t)
            >>>     d = Func.apply(a_dual, b, c)
        """
        for tensor in tensors:
            assert isinstance(tensor, torch.Tensor) or tensor is None, (
                "save_for_forward expects all arguments to be tensors; you should "
                "save non-tensors as attributes on ctx.")

        self.saved_for_forward = tensors
    def mark_dirty(self, *args: torch.Tensor):
        r"""Marks given tensors as modified in an in-place operation.

        **This should be called at most once, only from inside the**
        :func:`forward` **method, and all arguments should be inputs.**

        Every tensor that's been modified in-place in a call to :func:`forward`
        should be given to this function, to ensure correctness of our checks.
        It doesn't matter whether the function is called before or after
        modification.

        Examples::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
            >>> class Inplace(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x):
            >>>         x_npy = x.numpy()  # x_npy shares storage with x
            >>>         x_npy += 1
            >>>         ctx.mark_dirty(x)
            >>>         return x
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, grad_output):
            >>>         return grad_output
            >>>
            >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
            >>> b = a * a
            >>> Inplace.apply(a)  # This would lead to wrong gradients!
            >>> # but the engine would not know unless we mark_dirty
            >>> # xdoctest: +SKIP
            >>> b.backward()  # RuntimeError: one of the variables needed for gradient
            >>> # computation has been modified by an inplace operation
        """
        self.dirty_tensors = args

    def mark_shared_storage(self, *pairs):
        warnings.warn(
            'mark_shared_storage is deprecated. '
            'Tensors with shared storages are automatically tracked. Note '
            'that calls to `set_()` are not tracked')

    def mark_non_differentiable(self, *args: torch.Tensor):
        r"""Marks outputs as non-differentiable.

        **This should be called at most once, only from inside the**
        :func:`forward` **method, and all arguments should be tensor outputs.**

        This will mark outputs as not requiring gradients, increasing the
        efficiency of backward computation. You still need to accept a gradient
        for each output in :meth:`~Function.backward`, but it's always going to
        be a zero tensor with the same shape as the corresponding output.

        This is used e.g. for indices returned from a sort. See example::
            >>> class Func(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x):
            >>>         sorted, idx = x.sort()
            >>>         ctx.mark_non_differentiable(idx)
            >>>         ctx.save_for_backward(x, idx)
            >>>         return sorted, idx
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, g1, g2):  # still need to accept g2
            >>>         x, idx = ctx.saved_tensors
            >>>         grad_input = torch.zeros_like(x)
            >>>         grad_input.index_add_(0, idx, g1)
            >>>         return grad_input
        """
        self.non_differentiable = args

    def set_materialize_grads(self, value: bool):
        r"""Sets whether to materialize grad tensors. Default is ``True``.

        **This should be called only from inside the** :func:`forward` **method.**

        If ``True``, undefined grad tensors will be expanded to tensors full of zeros
        prior to calling the :func:`backward` and :func:`jvp` methods.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
            >>> class SimpleFunc(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x):
            >>>         return x.clone(), x.clone()
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, g1, g2):
            >>>         return g1 + g2  # No check for None necessary
            >>>
            >>> # We modify SimpleFunc to handle non-materialized grad outputs
            >>> class Func(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x):
            >>>         ctx.set_materialize_grads(False)
            >>>         ctx.save_for_backward(x)
            >>>         return x.clone(), x.clone()
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, g1, g2):
            >>>         x, = ctx.saved_tensors
            >>>         grad_input = torch.zeros_like(x)
            >>>         if g1 is not None:  # We must check for None now
            >>>             grad_input += g1
            >>>         if g2 is not None:
            >>>             grad_input += g2
            >>>         return grad_input
            >>>
            >>> a = torch.tensor(1., requires_grad=True)
            >>> b, _ = Func.apply(a)  # induces g2 to be undefined
        """
        self.materialize_grads = value


# DO NOT USE: This is only defined to be able to load old serialized models
_ContextMethodMixin = FunctionCtx


class _HookMixin:

    @staticmethod
    def _register_hook(backward_hooks, hook):
        if backward_hooks is None:
            backward_hooks = OrderedDict()
        handle = hooks.RemovableHandle(backward_hooks)
        backward_hooks[handle.id] = hook
        return backward_hooks, handle
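

# A minimal usage sketch (illustrative only, not part of the public API):
# _register_hook lazily creates the backing OrderedDict and returns it together
# with a RemovableHandle that can deregister the hook again:
#
#     hooks_dict, handle = _HookMixin._register_hook(None, lambda *grads: None)
#     assert handle.id in hooks_dict
#     handle.remove()  # removes the hook from hooks_dict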


class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
    def apply(self, *args):
        # _forward_cls is defined by derived class
        # The user should define either backward or vjp but never both.
        backward_fn = self._forward_cls.backward  # type: ignore[attr-defined]
        vjp_fn = self._forward_cls.vjp  # type: ignore[attr-defined]
        if backward_fn is not Function.backward and vjp_fn is not Function.vjp:
            raise RuntimeError("Implementing both 'backward' and 'vjp' for a custom "
                               "Function is not allowed. You should only implement one "
                               "of them.")
        user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
        return user_fn(self, *args)

    def apply_jvp(self, *args):
        # _forward_cls is defined by derived class
        return self._forward_cls.jvp(self, *args)  # type: ignore[attr-defined]


class FunctionMeta(type):
    """Function metaclass.

    This metaclass sets up the following properties:
        _backward_cls: The Function class corresponding to the differentiated
            version of this function (which is generated on the fly by this
            metaclass).
    """
    def __init__(cls, name, bases, attrs):
        backward_fn = type(name + 'Backward', (BackwardCFunction,), {'_forward_cls': cls})
        cls._backward_cls = backward_fn
        super(FunctionMeta, cls).__init__(name, bases, attrs)
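

# A small illustration (the name MyMul is hypothetical, not part of this file):
# because of the metaclass above, defining a Function subclass also creates its
# backward node class on the fly:
#
#     class MyMul(Function):  # uses FunctionMeta via Function
#         ...
#
#     MyMul._backward_cls.__name__ == "MyMulBackward"        # True
#     issubclass(MyMul._backward_cls, BackwardCFunction)     # True
#     MyMul._backward_cls._forward_cls is MyMul              # True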


class _SingleLevelFunction(_C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta):
    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        r"""
        This function is to be overridden by all subclasses. There are two ways
        to define forward:

        Usage 1 (Combined forward and ctx)::

            @staticmethod
            def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
                pass

        - It must accept a context ctx as the first argument, followed by any
          number of arguments (tensors or other types).
        - See :ref:`combining-forward-context` for more details

        Usage 2 (Separate forward and ctx)::

            @staticmethod
            def forward(*args: Any, **kwargs: Any) -> Any:
                pass

            @staticmethod
            def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
                pass

        - The forward no longer accepts a ctx argument.
        - Instead, you must also override the :meth:`torch.autograd.Function.setup_context`
          staticmethod to handle setting up the ``ctx`` object.
          ``output`` is the output of the forward, ``inputs`` are a Tuple of inputs
          to the forward.
        - See :ref:`extending-autograd` for more details

        The context can be used to store arbitrary data that can be then
        retrieved during the backward pass. Tensors should not be stored
        directly on `ctx` (though this is not currently enforced for
        backward compatibility). Instead, tensors should be saved either with
        :func:`ctx.save_for_backward` if they are intended to be used in
        ``backward`` (equivalently, ``vjp``) or :func:`ctx.save_for_forward`
        if they are intended to be used in ``jvp``.
        """
        raise NotImplementedError("You must implement the forward function for custom"
                                  " autograd.Function.")

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any], output: Any) -> Any:
        r"""There are two ways to define the forward pass of an autograd.Function.

        Either:

        1. Override forward with the signature forward(ctx, *args, **kwargs).
           ``setup_context`` is not overridden. Setting up the ctx for backward
           happens inside the ``forward``.
        2. Override forward with the signature forward(*args, **kwargs) and
           override ``setup_context``. Setting up the ctx for backward happens
           inside ``setup_context`` (as opposed to inside the ``forward``)

        See :meth:`torch.autograd.Function.forward` and :ref:`extending-autograd` for more details.
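
        Example (a minimal sketch of the second style; the names ``Mul``, ``x``
        and ``y`` are illustrative only, not part of this module)::

            >>> class Mul(Function):
            >>>     @staticmethod
            >>>     def forward(x, y):  # note: no ctx argument in this style
            >>>         return x * y
            >>>
            >>>     @staticmethod
            >>>     def setup_context(ctx, inputs, output):
            >>>         x, y = inputs
            >>>         ctx.save_for_backward(x, y)
            >>>
            >>>     @staticmethod
            >>>     def backward(ctx, grad_out):
            >>>         x, y = ctx.saved_tensors
            >>>         return grad_out * y, grad_out * x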
        """
        raise NotImplementedError("setup_context is not implemented.")

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        r"""Defines a formula for differentiating the operation with backward mode
        automatic differentiation (alias to the vjp function).

        This function is to be overridden by all subclasses.

        It must accept a context :attr:`ctx` as the first argument, followed by
        as many outputs as the :func:`forward` returned (None will be passed in
        for non-tensor outputs of the forward function),
        and it should return as many tensors as there were inputs to
        :func:`forward`. Each argument is the gradient w.r.t the given output,
        and each returned value should be the gradient w.r.t. the
        corresponding input. If an input is not a Tensor or is a Tensor not
        requiring grads, you can just pass None as a gradient for that input.

        The context can be used to retrieve tensors saved during the forward
        pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple
        of booleans representing whether each input needs gradient. E.g.,
        :func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the
        first input to :func:`forward` needs gradient computed w.r.t. the
        output.
        """
        raise NotImplementedError("You must implement either the backward or vjp method for "
                                  "your custom autograd.Function to use it with backward "
                                  "mode AD.")

    # vjp and backward are aliases of each other
    vjp = backward

    @staticmethod
    def jvp(ctx: Any, *grad_inputs: Any) -> Any:
        r"""Defines a formula for differentiating the operation with forward mode
        automatic differentiation.

        This function is to be overridden by all subclasses.

        It must accept a context :attr:`ctx` as the first argument, followed by
        as many inputs as the :func:`forward` got (None will be passed in
        for non-tensor inputs of the forward function),
        and it should return as many tensors as there were outputs to
        :func:`forward`. Each argument is the gradient w.r.t the given input,
        and each returned value should be the gradient w.r.t. the
        corresponding output. If an output is not a Tensor or the function is not
        differentiable with respect to that output, you can just pass None as a
        gradient for that output.

        You can use the :attr:`ctx` object to pass any value from the forward to this
        function.
        """
        raise NotImplementedError("You must implement the jvp function for custom "
                                  "autograd.Function to use it with forward mode AD.")


class Function(_SingleLevelFunction):
    r"""Base class to create custom `autograd.Function`

    To create a custom `autograd.Function`, subclass this class and implement
    the :meth:`forward` and :meth:`backward` static methods. Then, to use your custom
    op in the forward pass, call the class method ``apply``. Do not call
    :meth:`forward` directly.

    To ensure correctness and best performance, make sure you are calling the
    correct methods on ``ctx`` and validating your backward function using
    :func:`torch.autograd.gradcheck`.

    See :ref:`extending-autograd` for more details on how to use this class.

    Examples::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> class Exp(Function):
        >>>     @staticmethod
        >>>     def forward(ctx, i):
        >>>         result = i.exp()
        >>>         ctx.save_for_backward(result)
        >>>         return result
        >>>
        >>>     @staticmethod
        >>>     def backward(ctx, grad_output):
        >>>         result, = ctx.saved_tensors
        >>>         return grad_output * result
        >>>
        >>> # Use it by calling the apply method:
        >>> # xdoctest: +SKIP
        >>> output = Exp.apply(input)
    """
    def __init__(self, *args, **kwargs):
        cls = self.__class__
        warnings.warn(f"{cls} should not be instantiated. Methods on autograd functions "
                      "are all static, so you should invoke them on the class itself. "
                      "Instantiating an autograd function will raise an "
                      "error in a future version of PyTorch.", DeprecationWarning)

    def __call__(self, *args, **kwargs):
        raise RuntimeError(
            "Legacy autograd function with non-static forward method is deprecated. "
            "Please use new-style autograd function with static forward method. "
            "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)")

    # for the tracer
    is_traceable = False

    """
    Bool that specifies if PyTorch should attempt to autogenerate
    :func:`torch.vmap` support for this autograd.Function. You may set this to
    True only if this autograd.Function's forward, backward, and jvp (if they
    exist) are written using PyTorch operations; otherwise, please override
    :meth:`torch.autograd.Function.vmap` to add support for :func:`torch.vmap`.

    Please see :ref:`func-autograd-function` for more details.
    """
    generate_vmap_rule = False

    @staticmethod
    def vmap(info, in_dims, *args):
        r"""Defines a rule for the behavior of this autograd.Function underneath
        :func:`torch.vmap`. For a :func:`torch.autograd.Function` to support
        :func:`torch.vmap`, you must either override this staticmethod, or set
        ``generate_vmap_rule`` to ``True`` (you may not do both).

        If you choose to override this staticmethod: it must accept

        - an ``info`` object as the first argument. ``info.batch_size``
          specifies the size of the dimension being vmapped over,
          while ``info.randomness`` is the randomness option passed to
          :func:`torch.vmap`.
        - an ``in_dims`` tuple as the second argument.
          For each arg in ``args``, ``in_dims`` has a corresponding
          ``Optional[int]``. It is ``None`` if the arg is not a Tensor or if
          the arg is not being vmapped over, otherwise, it is an integer
          specifying what dimension of the Tensor is being vmapped over.
        - ``*args``, which is the same as the args to :meth:`~Function.forward`.

        The return of the vmap staticmethod is a tuple of ``(output, out_dims)``.
        Similar to ``in_dims``, ``out_dims`` should be of the same structure as
        ``output`` and contain one ``out_dim`` per output that specifies if the
        output has the vmapped dimension and what index it is in.

        Please see :ref:`func-autograd-function` for more details.
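
        Example (a minimal sketch for a hypothetical element-wise Function,
        assuming its single input ``x`` is being vmapped over; the matching
        ``forward``/``backward`` are not shown)::

            >>> @staticmethod
            >>> def vmap(info, in_dims, x):
            >>>     x_bdim, = in_dims
            >>>     # Move the vmapped dim to the front, apply the op, and
            >>>     # report that the output is batched along dim 0.
            >>>     x = x.movedim(x_bdim, 0)
            >>>     return x.exp(), 0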
        """
        raise NotImplementedError(
            "To use autograd.Function with vmap, you must either override the "
            "vmap staticmethod or set generate_vmap_rule=True.")

    @classmethod
    def apply(cls, *args, **kwargs):
        if not torch._C._are_functorch_transforms_active():
            # See NOTE: [functorch vjp and autograd interaction]
            args = _functorch.utils.unwrap_dead_wrappers(args)
            return super().apply(*args, **kwargs)  # type: ignore[misc]

        if cls.setup_context == _SingleLevelFunction.setup_context:
            raise RuntimeError(
                'In order to use an autograd.Function with functorch transforms '
                '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
                'staticmethod. For more details, please see '
                'https://pytorch.org/docs/master/notes/extending.func.html')

        return custom_function_call(cls, *args, **kwargs)


def once_differentiable(fn):

    @functools.wraps(fn)
    def wrapper(ctx, *args):
        with torch.no_grad():
            outputs = fn(ctx, *args)

        if not torch.is_grad_enabled():
            return outputs

        # If any of the inputs have requires_grad=True, we force the outputs
        # to have requires_grad=True but point to a grad_fn which throws an
        # error message during (double) back-propagation.
        # XXX: this is only an approximation of requires_grad - there's no way
        # to figure out if fn didn't use ctx.saved_tensors and as a result
        # some Tensors might require grad, even if no args do.
        # Unfortunately, this leads to unexpected error messages ("no nodes
        # require computing gradients"), but I don't have a better idea.
        # These functions would raise an error in backward anyway.
        requires_grad = any(isinstance(arg, torch.Tensor) and arg.requires_grad
                            for arg in args)
        if not requires_grad:
            return outputs

        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        err_fn = _functions.DelayedError(
            b"trying to differentiate twice a function that was marked "
            b"with @once_differentiable", len(outputs))

        # Create aliases of each output that has requires_grad=True. We need
        # at least one of the inputs to err_fn to require grad so that the
        # output will have a grad_fn.
        def fake_requires_grad(var):
            if var is not None:
                var = var.detach()
                var.requires_grad = True
            return var

        return err_fn(*[fake_requires_grad(v) for v in outputs])
    return wrapper
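

# A brief usage note (illustrative; see also the docstring examples above):
# ``once_differentiable`` wraps a custom Function's ``backward`` so that it runs
# under ``torch.no_grad()`` and attempting double backward raises an error
# instead of silently producing wrong gradients:
#
#     class MyFunc(Function):
#         @staticmethod
#         def forward(ctx, x):
#             return x * 2
#
#         @staticmethod
#         @once_differentiable
#         def backward(ctx, grad_out):
#             return grad_out * 2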


def traceable(fn_cls):
    r"""Marks Function as traceable for the JIT.

    Traceable functions have additional restrictions - they can't pass any
    data-dependent values to backward (e.g. Prod passes the output, which makes
    it non-traceable), and their backward should be implemented entirely in terms
    of operations on autograd Tensors in all cases.

    DON'T USE THIS DECORATOR. IT IS FOR INTERNAL USE ONLY AND SHOULD BE HANDLED WITH
    CARE (or can give incorrect results otherwise).
    """
    fn_cls.is_traceable = True
    return fn_cls


class InplaceFunction(Function):

    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace


def _nested_map(condition, fn, condition_msg=None):
    def _map(obj):
        if condition(obj):
            return fn(obj)
        elif obj is None:
            return None
        elif isinstance(obj, (list, tuple)):
            mapped = (_map(x) for x in obj)
            if hasattr(obj, '_fields'):
                # obj is namedtuple
                return type(obj)(*mapped)
            return type(obj)(mapped)
        elif isinstance(obj, dict):
            return {x: _map(obj[x]) for x in obj}
        else:
            raise ValueError("Auto nesting doesn't know how to process "
                             "an input object of type " + torch.typename(obj) +
                             (". Accepted types: " + condition_msg +
                              ", or lists/tuples of them"
                              if condition_msg else ""))

    return _map


def _jit_unwrap_structured(obj):
    if hasattr(obj, "_jit_unwrap"):
        return obj._jit_unwrap()
    return obj


def _iter_filter(condition, allow_unknown=False, condition_msg=None,
                 conversion=None):
    def _iter(obj):
        if conversion is not None:
            obj = conversion(obj)
        if condition(obj):
            yield obj
        elif obj is None:
            return
        elif isinstance(obj, (list, tuple)):
            for o in obj:
                yield from _iter(o)
        elif isinstance(obj, dict):
            # We only accept primitive key types, so we needn't inspect them
            for o in obj.values():
                yield from _iter(o)
        elif allow_unknown:
            yield obj
        else:
            raise ValueError("Auto nesting doesn't know how to process "
                             "an input object of type " + torch.typename(obj) +
                             (". Accepted types: " + condition_msg +
                              ", or lists/tuples of them"
                              if condition_msg else ""))

    return _iter


def _unflatten(input, proto):
    # unflatten a list or tuple input into a nested list/tuple structure
    # specified by proto
    def unflatten_helper(input, proto):
        res: List[Optional[torch.Tensor]] = []
        if hasattr(proto, "_jit_wrap"):
            return proto._jit_wrap(input)
        if not isinstance(proto, (list, tuple)):
            return input[0], input[1:]
        for e in proto:
            if e is None:
                res.append(e)
            else:
                res_e, input = unflatten_helper(input, e)
                res.append(res_e)
        return type(proto)(res), input

    return unflatten_helper(input, proto)[0]


_iter_jit_values = _iter_filter(lambda o: o is None or isinstance(o, torch._C.Value),
                                condition_msg="jit's Values or None")
_iter_tensors = _iter_filter(lambda x: isinstance(x, torch.Tensor), condition_msg="Tensors",
                             conversion=_jit_unwrap_structured)
_iter_tensors_permissive = _iter_filter(lambda x: isinstance(x, torch.Tensor),
                                        allow_unknown=True,
                                        condition_msg="Tensors (permissive)")
_iter_None_tensors = _iter_filter(lambda o: o is None or isinstance(o, torch.Tensor),
                                  condition_msg="Tensors or None")
_map_tensor_data = _nested_map(lambda x: isinstance(x, torch.Tensor), lambda o: o.data,
                               condition_msg="Tensors")
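

# An illustrative round trip through the helpers above (t1, t2, t3 stand for
# arbitrary torch.Tensor values; this is a sketch, not code executed at import):
#
#     nested = (t1, [t2, t3])
#     flat = tuple(_iter_tensors(nested))   # (t1, t2, t3)
#     _unflatten(flat, nested)              # (t1, [t2, t3]), same nesting as proto
#
# NestedIOFunction below relies on exactly this flatten/unflatten pairing to let
# users work with nested structures of inputs, outputs, and saved tensors.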


class NestedIOFunction(Function):
    # The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
    # superclass (Function) but are instance methods here, which mypy reports as incompatible.

    def _do_forward(self, *input):
        self._nested_input = input
        flat_input = tuple(_iter_tensors(input))
        flat_output = super()._do_forward(*flat_input)  # type: ignore[misc]
        nested_output = self._nested_output
        nested_tensors = _unflatten(flat_output, self._nested_output)
        return nested_tensors

    def _do_backward(self, gradients, retain_variables):
        self.retain_variables = retain_variables
        result = super()._do_backward(gradients, retain_variables)  # type: ignore[misc]
        if not retain_variables:
            del self._nested_output
            del self._to_save_nested
        return result

    def backward(self, *gradients: Any) -> Any:  # type: ignore[override]
        nested_gradients = _unflatten(gradients, self._nested_output)
        result = self.backward_extended(*nested_gradients)  # type: ignore[func-returns-value]
        return tuple(_iter_None_tensors(result))

    __call__ = _do_forward

    def forward(self, *args: Any) -> Any:  # type: ignore[override]
        nested_tensors = _map_tensor_data(self._nested_input)
        result = self.forward_extended(*nested_tensors)  # type: ignore[func-returns-value]
        del self._nested_input
        self._nested_output = result
        return tuple(_iter_tensors(result))

    def save_for_backward(self, *args: Any) -> None:
        self.to_save = tuple(_iter_tensors(args))
        self._to_save_nested = args

    @property
    def saved_tensors(self):
        flat_tensors = super().saved_tensors  # type: ignore[misc]
        return _unflatten(flat_tensors, self._to_save_nested)

    def mark_dirty(self, *args: Any, **kwargs: Any) -> None:
        self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))

    def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None:
        self.non_differentiable = tuple(_iter_tensors((args, kwargs)))

    def forward_extended(self, *input: Any) -> None:
        raise NotImplementedError

    def backward_extended(self, *grad_output: Any) -> None:
        raise NotImplementedError