# autograd_function.py

import torch
from torch._ops import PyOperator
from torch._C._functorch import TransformType
from torch._functorch.utils import enable_single_level_autograd_function
import torch.utils._pytree as pytree
from torch._C._functorch import (
    _wrap_for_grad,
    _unwrap_for_grad,
    current_level,
)
from torch._functorch.vmap import (
    wrap_batched,
    unwrap_batched,
    vmap,
    restore_vmap,
    _add_batch_dim,
)
from torch._functorch.vmap import _broadcast_to_and_flatten
from torch.autograd.forward_ad import _set_fwd_grad_enabled
from typing import Any, NamedTuple, Tuple

# autograd.Function technically runs before the regular PyTorch dispatcher.
# This is how features like autocast and torch_dispatch (e.g. PythonTLSSnapshot)
# work with it. One day we might decide to change this, but until then,
# we need to give the illusion that autograd.Function runs before those things.
#
# We do this by creating a custom PyOperator that only functorch
# dispatches specially.


class CustomFunctionPyOperator(PyOperator):
    def __init__(self):
        super().__init__('custom_function_call')

    def __call__(self, autograd_function, *args, **kwargs):
        # When custom_function_call is done dispatching through functorch,
        # it should just invoke the autograd.Function. This is consistent
        # with the autograd.Function behavior of being invoked before the
        # PyTorch dispatcher.
        #
        # This will lead us into trouble later down the line, but this is
        # pre-existing. There is an invariant that a function traced by
        # make_fx should have the same behavior when provided the same
        # Tensor. However, make_fx sees autograd.Function as a composite
        # (because autograd.Function happens before the Python dispatch key)
        # and only traces the forward pass.
        if torch._C._are_functorch_transforms_active():
            return super().__call__(autograd_function, *args, **kwargs)
        return autograd_function.apply(*args, **kwargs)
  46. # "custom_function_call"
  47. # This is the mechanism for an autograd.Function that works with functorch transforms.
  48. # It wraps an autograd.Function; interactions with functorch transforms are defined
  49. # via PyDispatcher and PyOperator rather than through the traditional PyTorch
  50. # dispatcher.
  51. custom_function_call = CustomFunctionPyOperator()
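

# An illustrative sketch of how this operator is reached (hypothetical example,
# not part of this module): a user-defined autograd.Function only routes through
# custom_function_call when a functorch transform is active; otherwise the
# operator falls through to plain .apply, as implemented above.
#
#     class MySin(torch.autograd.Function):
#         @staticmethod
#         def forward(x):
#             return x.sin()
#
#         @staticmethod
#         def setup_context(ctx, inputs, output):
#             ctx.save_for_backward(inputs[0])
#
#         @staticmethod
#         def backward(ctx, gy):
#             x, = ctx.saved_tensors
#             return gy * x.cos()
#
#     x = torch.randn(3)
#     MySin.apply(x)  # no transform active: runs the autograd.Function directly
#     torch.func.grad(lambda t: MySin.apply(t).sum())(x)  # dispatches via the
#                                                         # TransformType.Grad rule below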


# The grad rule for custom_function_call is to construct a new _SingleLevelFunction
# (autograd.Function that only works with a single layer (level) of functorch) that:
# - unwraps the inputs
# - redispatches to custom_function_call
# - wraps the outputs
# and whose backward pass calls the original autograd.Function's backward.
#
# Why do we need to redispatch to custom_function_call?
# -----------------------------------------------------
# This is consistent with how ATen operators work with functorch's grad transform:
# they always redispatch to the original operator.
# Consider torch.sin, and let's say we do grad0(grad1(torch.sin))(x)
#
# grad1 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin (*)
# - rewrap the outputs on the return
#
# On the redispatch in (*), grad0 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin
# - rewrap the outputs on the return
#
# To "set up the autograd graph", we generate a _SingleLevelFunction
# and apply it.
@custom_function_call.py_impl(TransformType.Grad)
@custom_function_call.py_impl(TransformType.Jvp)
def custom_function_call_grad(interpreter, autograd_function, *operands):
    Generated = generate_single_level_function(interpreter, autograd_function)
    with enable_single_level_autograd_function():
        flat_out = Generated.apply(*operands)
    return flat_out


def generate_single_level_function(interpreter, autograd_function):
    level = interpreter.level()

    def forward(*operands):
        unwrapped_operands = pytree.tree_map_only(
            torch.Tensor,
            lambda x: _unwrap_for_grad(x, level),
            operands)
        # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
        # the transform. _SingleLevelFunction will turn off both fwd and bwd
        # gradient computation and we need to turn it back on here.
        with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
            unwrapped_output = custom_function_call(autograd_function, *unwrapped_operands)

        # See NOTE [mark_dirty object identity check]
        def wrap_fn(output):
            return _wrap_for_grad(output, level)

        return wrap_outputs_maintaining_identity(
            unwrapped_output,
            unwrapped_operands,
            operands,
            wrap_fn)

    def setup_context(ctx, inputs, output):
        return autograd_function.setup_context(ctx, inputs, output)

    # backward is only used if the transform is TransformType.Grad
    def backward(ctx, *grads):
        result = autograd_function.backward(ctx, *grads)
        return result

    # jvp is only used if the transform is TransformType.Jvp
    def jvp(ctx, *tangents):
        result = autograd_function.jvp(ctx, *tangents)
        return result

    # This is the sequence of magic words to dynamically generate a Subclass with
    # a given name. A Tensor's .grad_fn field has a class name that is the original
    # autograd.Function's name + Backward, so we do this to generate some
    # meaningful name.
    name = f'{autograd_function.__name__}Generated'
    Generated = type(
        name,
        (torch.autograd.function._SingleLevelFunction,),
        {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
            'jvp': staticmethod(jvp),
            'setup_context': staticmethod(setup_context),
        },
    )
    return Generated


# wrap_outputs_maintaining_identity handles outputs from the vmap,
# backward (vjp), and jvp staticmethod. The way it distinguishes
# between the vmap case and the {backward, jvp} case is if the out_dims
# are specified or not.
#
# NB: we cannot use out_dims=None as the deciding factor. This is because
# out_dims=None can still happen in the vmap staticmethod! What the
# user is saying in that case is that their output does not have a
# dimension that is being vmapped over, which is valid.
NO_OUT_DIMS = "not specified"


# NOTE [mark_dirty object identity check]
# autograd.Function's ctx.mark_dirty expects a returned input
# to have the same object identity as the input.
# Mode-only functorch will greatly simplify this logic.
def wrap_outputs_maintaining_identity(
        outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS):
    flat_unwrapped_inputs, _ = pytree.tree_flatten(unwrapped_inputs)
    flat_orig_inputs, _ = pytree.tree_flatten(orig_inputs)

    unwrapped_input_to_orig_input = {
        id(unwrapped): orig
        for unwrapped, orig in zip(flat_unwrapped_inputs, flat_orig_inputs)
    }

    flat_outputs, spec = pytree.tree_flatten(outputs)
    result = []

    out_dims_specified = out_dims != NO_OUT_DIMS

    if out_dims_specified:
        flat_out_dims = _broadcast_to_and_flatten(out_dims, spec)
        # _broadcast_to_and_flatten returns None if it is unable to broadcast.
        # TODO: update following link from master to stable once that's out
        if flat_out_dims is None:
            raise RuntimeError(
                f"The autograd.Function's vmap staticmethod returned an "
                f"incompatible (output, out_dims) tuple. "
                f"Expected out_dims={out_dims} "
                f"to be compatible with the structure of `output`. "
                f"out_dims has structure {pytree.tree_flatten(out_dims)[1]} "
                f"but output has structure {spec}. "
                f"For more details, please see "
                f"https://pytorch.org/docs/master/notes/extending.func.html"
            )

    for i, output in enumerate(flat_outputs):
        if not isinstance(output, torch.Tensor):
            result.append(output)
            continue
        if id(output) in unwrapped_input_to_orig_input:
            result.append(unwrapped_input_to_orig_input[id(output)])
            continue
        if out_dims_specified:
            result.append(wrap_fn(output, flat_out_dims[i]))  # type: ignore[index]
        else:
            result.append(wrap_fn(output))

    return pytree.tree_unflatten(result, spec)
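

# Illustrative sketch of the identity requirement handled above (hypothetical
# example, not part of this module): an autograd.Function that uses
# ctx.mark_dirty must return the very same Tensor object it received, so the
# id()-based lookup maps an unwrapped output back to the original (wrapped)
# input instead of producing a fresh wrapper.
#
#     class MyScaleInplace(torch.autograd.Function):
#         @staticmethod
#         def forward(x):
#             return x.mul_(2)   # returns x itself
#
#         @staticmethod
#         def setup_context(ctx, inputs, output):
#             ctx.mark_dirty(inputs[0])
#
#         @staticmethod
#         def backward(ctx, gy):
#             return gy * 2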


# NOTE: [functorch vjp and autograd interaction]
# There's an edge case with the functorch vjp and autograd interaction
# that will eventually be fixed by mode-only functorch.
# The TL;DR is that there's no way to unwrap a dead GradTensorWrapper,
# so we (the framework) need to do it manually. Regular PyTorch operators
# automatically do this unwrapping, so doing it here keeps the behavior consistent.
#
# class MyExp(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return x.exp()
#
#     @staticmethod
#     def setup_context(ctx, inputs, output):
#         y = output
#         ctx.save_for_backward(y)
#
#     @staticmethod
#     def backward(ctx, gy):
#         y, = ctx.saved_tensors
#         return MyMul.apply(gy, y)
#
# x = torch.randn([], requires_grad=True)
# gy = torch.randn([], requires_grad=True)
# _, vjp_fn = vjp(MyExp.apply, x)
# result = vjp_fn(gy)
#
# MyMul is an autograd.Function that is not shown here.
# It saves a `y` for backward (since gy requires grad).
#
# In vjp_fn(gy), we get:
# > MyMul.apply(gy, GradTensorWrapper(y, level=dead))
# because the y that was saved for backward by MyExp is a GradTensorWrapper,
# but it is now dead since we are outside the vjp context.
#
# PyTorch dispatcher operations, upon seeing a dead GradTensorWrapper,
# will automatically unwrap the GradTensorWrapper when applied.
# But since autograd.Function technically sits above the regular PyTorch
# dispatcher, it doesn't get this treatment. So we manually do
# the unwrapping to be consistent with regular PyTorch dispatcher operations.


class VmapInfo(NamedTuple):
    batch_size: int
    randomness: str


def has_overriden_vmap_rule(autograd_function):
    return autograd_function.vmap is not torch.autograd.Function.vmap


def validate_vmap_returns_tuple_of_two_elements(result):
    base_error_msg = (
        "Expected the vmap staticmethod to have two returns, an output "
        "and out_dims with pytree structure compatible with the output. "
    )
    if not isinstance(result, tuple):
        raise RuntimeError(base_error_msg + f"Got a {type(result)} instead")
    if not len(result) == 2:
        raise RuntimeError(base_error_msg + f"Got {len(result)} returns instead")


@custom_function_call.py_impl(TransformType.Vmap)
def custom_function_call_vmap(interpreter, autograd_function, *operands):
    if autograd_function.generate_vmap_rule:
        if has_overriden_vmap_rule(autograd_function):
            # TODO: Update link to stable once that's out
            # https://github.com/pytorch/pytorch/issues/92029
            raise RuntimeError(
                f"You tried to vmap over {autograd_function.__name__}, but "
                f"it has both generate_vmap_rule=True and an overridden vmap "
                f"staticmethod. Please set generate_vmap_rule=False or delete "
                f"the overridden vmap staticmethod to avoid ambiguity. "
                f"For more details, please see "
                f"https://pytorch.org/docs/master/notes/extending.func.html")
        return custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands)

    if not has_overriden_vmap_rule(autograd_function):
        # TODO: Update link to stable once that's out
        # https://github.com/pytorch/pytorch/issues/92029
        raise RuntimeError(
            f"You tried to vmap over {autograd_function.__name__}, but "
            f"it does not have vmap support. Please override and implement the "
            f"vmap staticmethod or set generate_vmap_rule=True. "
            f"For more details, please see "
            f"https://pytorch.org/docs/master/notes/extending.func.html")

    current_level = interpreter.level()
    info = VmapInfo(
        batch_size=interpreter.batch_size(),
        randomness=interpreter.randomness(),
    )
    unwrapped_operands, in_dims = unwrap_batched(operands, current_level)

    # If none of the tensors are batched at the current level, then we skip the
    # current level. This saves the user from needing to handle this case in
    # their vmap staticmethod (and is consistent with our C++ batching rule API).
    if pytree.tree_all(lambda dim: dim is None, in_dims):
        with interpreter.lower():
            return custom_function_call(autograd_function, *operands)

    with interpreter.lower():
        result = autograd_function.vmap(info, in_dims, *unwrapped_operands)
    validate_vmap_returns_tuple_of_two_elements(result)
    unwrapped_output, out_dims = result

    # See NOTE [mark_dirty object identity check]
    def wrap_fn(output, out_dim):
        return output if out_dim is None else _add_batch_dim(output, out_dim, current_level)

    return wrap_outputs_maintaining_identity(
        unwrapped_output,
        unwrapped_operands,
        operands,
        wrap_fn,
        out_dims=out_dims)
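

# Illustrative sketch of the vmap staticmethod contract checked above
# (hypothetical example, not part of this module): given a VmapInfo, the
# per-input bdims, and the unwrapped operands, it must return an
# (output, out_dims) tuple.
#
#     class MyScale(torch.autograd.Function):
#         generate_vmap_rule = False
#
#         @staticmethod
#         def forward(x):
#             return x * 2
#
#         @staticmethod
#         def setup_context(ctx, inputs, output):
#             pass
#
#         @staticmethod
#         def backward(ctx, gy):
#             return gy * 2
#
#         @staticmethod
#         def vmap(info, in_dims, x):
#             # Elementwise, so the batched input can be used as-is; the
#             # output is batched along the same dim as the input.
#             return x * 2, in_dims[0]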


def custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands):
    unwrapped_operands, in_dims = unwrap_batched(operands, interpreter.level())
    vmapped_function, get_out_dims = vmapify_autograd_function(
        autograd_function, in_dims, interpreter.batch_size(), interpreter.randomness())

    with interpreter.lower():
        output = custom_function_call(vmapped_function, *unwrapped_operands)

    out_dims = get_out_dims()
    return wrap_batched(output, out_dims, interpreter.level())


@custom_function_call.py_impl(TransformType.Functionalize)
def custom_function_call_functionalize(interpreter, autograd_function, generate_vmap_rule, *operands):
    raise RuntimeError("NYI: Functionalize rule for custom_function_call")


def vmapify_autograd_function(autograd_function, in_dims, batch_size, randomness):
    # The following values are saved from the forward() and setup_context()
    # and used in backward().
    # Why do we save the values out here instead of on the ctx object?
    # - out_dims: There's no way to retrieve this from forward()
    # - input_shapes, saved_tensors_bdims: I'm a bit scared of nesting
    #   vmap(vmap(...)), but am not completely sure if it is a problem. If we
    #   assigned those fields to the ctx object, the worry is that they
    #   get overwritten.
    out_dims = "not populated"
    input_shapes: Any = "not populated"
    saved_tensors_bdims: Any = "not populated"

    def forward(*operands):
        nonlocal out_dims
        outputs, out_dims = restore_vmap(
            autograd_function.forward, in_dims, batch_size, randomness)(*operands)
        return outputs

    def setup_context(ctx, inputs, outputs):
        input_shapes_ = None
        saved_tensors_bdims_ = None

        def inner(inputs, outputs):
            # wrapped_ctx.save_for_backward will:
            # - unwrap batchedtensors into (tensor, bdim)
            # - save_for_backward(*unwrapped_tensors)
            # - assign the bdims to wrapped_ctx._pt_saved_tensors_bdims
            wrapped_ctx = CtxCustomSave(ctx, current_level())
            autograd_function.setup_context(wrapped_ctx, inputs, outputs)

            # input_shapes are used by reductify later to reduce expanded gradients
            # to the correct shape.
            # See NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
            # for more details.
            nonlocal input_shapes_
            input_shapes_ = tuple(inp.shape if isinstance(inp, torch.Tensor) else None
                                  for inp in inputs)
            nonlocal saved_tensors_bdims_
            saved_tensors_bdims_ = wrapped_ctx._pt_saved_tensors_bdims

        # See NOTE: [Why do we need to run setup_context under a vmap?]
        restore_vmap(
            inner,
            (in_dims, out_dims),
            batch_size,
            randomness,
        )(inputs, outputs)

        nonlocal input_shapes
        input_shapes = input_shapes_
        nonlocal saved_tensors_bdims
        saved_tensors_bdims = saved_tensors_bdims_

    def jvp(ctx, *tangents):
        assert out_dims != "not populated"
        assert saved_tensors_bdims != "not populated"

        def jvp_no_context(saved_tensors, tangents):
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.jvp(wrapped_ctx, *tangents)

        tangent_in_dims = get_tangents_in_dims(in_dims, tangents)
        out_tangents, out_tangents_dims = restore_vmap(
            jvp_no_context, (saved_tensors_bdims, tangent_in_dims), batch_size, randomness)(
                ctx.saved_tensors, tangents)

        result = reductify(out_tangents, out_tangents_dims, out_dims, batch_size)
        return result

    def backward(ctx, *grad_outputs):
        assert out_dims != "not populated"
        assert input_shapes != "not populated"
        assert saved_tensors_bdims != "not populated"

        def backward_no_context(inputs):
            saved_tensors, grad_outputs = inputs
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.backward(wrapped_ctx, *grad_outputs)

        grad_ins, grad_ins_dims = restore_vmap(
            backward_no_context, ((saved_tensors_bdims, out_dims),), batch_size, randomness)(
                (ctx.saved_tensors, grad_outputs))

        result = reductify(grad_ins, grad_ins_dims, in_dims, batch_size, input_shapes)
        return result

    name = f'Vmapped{autograd_function.__name__}'
    Generated = type(
        name,
        (torch.autograd.Function,),
        {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
            'jvp': staticmethod(jvp),
            'setup_context': staticmethod(setup_context),
            'generate_vmap_rule': True
        }
    )

    def get_out_dims():
        assert out_dims != "not populated"
        return out_dims

    return Generated, get_out_dims
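

# Illustrative sketch of the generate_vmap_rule=True path that
# vmapify_autograd_function implements (hypothetical example, not part of
# this module):
#
#     class MySquare(torch.autograd.Function):
#         generate_vmap_rule = True
#
#         @staticmethod
#         def forward(x):
#             return x ** 2
#
#         @staticmethod
#         def setup_context(ctx, inputs, output):
#             ctx.save_for_backward(inputs[0])
#
#         @staticmethod
#         def backward(ctx, gy):
#             x, = ctx.saved_tensors
#             return 2 * x * gy
#
#     x = torch.randn(5, 3)
#     # Under vmap, custom_function_call_vmap_generate_rule builds a
#     # VmappedMySquare via vmapify_autograd_function, runs it one level
#     # down, then rewraps the output with the recorded out_dims.
#     y = torch.func.vmap(MySquare.apply)(x)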


# tangents might be None, so we need to replace
# the corresponding in_dims with None.
def get_tangents_in_dims(input_dims, tangents):
    flat_in_dims, spec = pytree.tree_flatten(input_dims)
    flat_tangents, _ = pytree.tree_flatten(tangents)
    result = [None if tangent is None else in_dim
              for in_dim, tangent in zip(flat_in_dims, flat_tangents)]
    return pytree.tree_unflatten(result, spec)
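

# For example (illustrative values): with input_dims=(0, 0) and
# tangents=(t, None), this returns (0, None), since a missing tangent has
# no batch dimension to track.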


# NOTE: [Why do we need to run setup_context under a vmap?]
# Consider the following autograd.Function
#
# class Sum(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return x.sum()
#     @staticmethod
#     def setup_context(ctx, inputs, outputs):
#         ctx.x_shape = inputs[0].shape
#     @staticmethod
#     def backward(ctx, gy):
#         return gy.expand(ctx.x_shape)
#
# x = torch.randn(B, 4)
# in_dims = 0
# vmap(Sum.apply, in_dims)(x)
#
# Let's assume for a moment that we didn't vmap setup_context in VmappedSum:
#
# class VmappedSum(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return vmap(Sum.forward, in_dims)(x)
#
#     @staticmethod
#     def setup_context(ctx, inputs, outputs):
#         Sum.setup_context(ctx, inputs, outputs)
#
#     @staticmethod
#     def backward(ctx, gy):
#         def backward_no_context(gy):
#             return gy.expand(ctx.x_shape)
#
#         dims = (0,)
#         gx = vmap(backward_no_context, dims)(gy)
#         return gx
#
# We end up saving [B, 4] as x_shape. In the backward, gy has shape [B],
# and we're doing:
#
# def backward_no_context(gy):
#     return gy.expand([B, 4])
#
# gx = vmap(backward_no_context, dims)(gy: "Tensor[B]")
#
# This gives us the wrong result: gx has shape [B, B, 4], but each
# per-example gradient should have shape [4] (for an overall gx of shape
# [B, 4]). Performing vmap over setup_context means the saved x_shape is
# [4], which leads to a correct result shape for gx.


# Wraps a ctx object. Forwards all attr accesses to the underlying object
# except for the attrs in _pt_reserved_attrs.
class WrappedCtx:
    _pt_reserved_attrs: Tuple[str, ...] = ('_pt_reserved_attrs', '_pt_inner_ctx')

    def __init__(self, ctx):
        if not isinstance(ctx, WrappedCtx):
            reserved_attrs = type(self)._pt_reserved_attrs
            for name in reserved_attrs:
                if not hasattr(ctx, name):
                    continue
                raise RuntimeError(
                    f'PyTorch reserves the {reserved_attrs} field on ctx. '
                    'Please name your fields on ctx something else to avoid name '
                    'collision.')
        self._pt_inner_ctx = ctx

    def __getattr__(self, name):
        return getattr(self._pt_inner_ctx, name)

    def __setattr__(self, name, value):
        if name in type(self)._pt_reserved_attrs:
            self.__dict__[name] = value
            return
        return setattr(self._pt_inner_ctx, name, value)


# Wraps ctx to create a new ctx object that overrides saved_tensors.
class CtxWithSavedTensors(WrappedCtx):
    _pt_reserved_attrs = ('_pt_new_saved_tensors', *WrappedCtx._pt_reserved_attrs)

    def __init__(self, ctx, new_saved_tensors):
        super().__init__(ctx)
        self._pt_new_saved_tensors = new_saved_tensors

    @property
    def saved_tensors(self):
        return self._pt_new_saved_tensors


class CtxCustomSave(WrappedCtx):
    _pt_reserved_attrs = ('_pt_saved_tensors_bdims', '_pt_current_level',
                          *WrappedCtx._pt_reserved_attrs)

    def __init__(self, ctx, current_level):
        super().__init__(ctx)
        self._pt_saved_tensors_bdims = ()
        self._pt_current_level = current_level

    def save_for_backward(self, *tensors):
        unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_backward(*unwrapped_tensors)
        self._pt_saved_tensors_bdims = bdims

    def save_for_forward(self, *tensors):
        unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_forward(*unwrapped_tensors)
        self._pt_saved_tensors_bdims = bdims
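

# Illustrative sketch of how the two ctx wrappers above compose (hypothetical
# values, not part of this module): CtxCustomSave records the bdims of the
# saved tensors at setup_context time, and CtxWithSavedTensors later swaps in
# the per-example saved_tensors for the backward/jvp bodies run under
# restore_vmap.
#
#     wrapped = CtxCustomSave(ctx, current_level())
#     wrapped.save_for_backward(batched_x)   # saves unwrapped x on ctx,
#                                            # records e.g. (0,) as the bdims
#
#     ctx2 = CtxWithSavedTensors(ctx, (x_slice,))
#     ctx2.saved_tensors                     # -> (x_slice,), not ctx.saved_tensors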


def reductify(grad_input, grad_input_bdim, input_bdim, batch_size,
              target_shape_without_bdim_to_reduce_to=None):
    if not isinstance(grad_input, tuple):
        grad_input = (grad_input,)
    if not isinstance(grad_input_bdim, tuple):
        grad_input_bdim = (grad_input_bdim,)
    if not isinstance(input_bdim, tuple):
        input_bdim = (input_bdim,)
    if target_shape_without_bdim_to_reduce_to is None:
        target_shape_without_bdim_to_reduce_to = len(grad_input) * (None,)
    result = tuple(
        reductify_leaf(gi, gi_bdim, i_bdim, batch_size, maybe_ishape)
        for gi, gi_bdim, i_bdim, maybe_ishape in
        zip(grad_input, grad_input_bdim, input_bdim, target_shape_without_bdim_to_reduce_to)
    )
    return result


def reductify_leaf(grad_input, grad_input_bdim, input_bdim, batch_size,
                   target_shape_without_bdim_to_reduce_to=None):
    if grad_input is None:
        return None

    if grad_input_bdim is None and input_bdim is None:
        return grad_input

    if grad_input_bdim is not None and input_bdim is None:
        return grad_input.sum(grad_input_bdim)

    # NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
    # For reverse-mode AD,
    # given a grad_input and input, it is valid for the user to return a
    # grad_input that has a broadcasted shape when compared to the input.
    # In this situation, autograd automatically reduces the grad_input to
    # the shape of the input.
    #
    # However, when input_bdim is not None, we have problems.
    #
    # [example 1]
    # grad_input: Tensor[3, 4], input: Tensor[B, 4]
    # We can expand grad_input to Tensor[B, 3, 4], but that isn't broadcastable
    # from [B, 4].
    #
    # [example 2]
    # grad_input: Tensor[3, B, 4], input: Tensor[B, 4]
    # We can swizzle grad_input to Tensor[B, 3, 4], but that isn't broadcastable
    # from [B, 4].
    #
    # This means that we need to also reduce the grad_input to the shape of the
    # input. This behavior is controlled by `target_shape_without_bdim_to_reduce_to`:
    # if it is not None, we manually reduce the grad_input down to that shape;
    # otherwise, we do not do a reduction.
    assert input_bdim is not None

    if grad_input_bdim is None:
        grad_input = grad_input.unsqueeze(input_bdim)
        new_shape = list(grad_input.shape)
        new_shape[input_bdim] = batch_size
        grad_input = grad_input.expand(new_shape)
        grad_input_bdim = input_bdim

    if target_shape_without_bdim_to_reduce_to is not None:
        return vmap(torch.Tensor.sum_to_size, in_dims=(grad_input_bdim, None), out_dims=input_bdim)(
            grad_input, target_shape_without_bdim_to_reduce_to)

    if input_bdim != grad_input_bdim:
        grad_input = grad_input.movedim(grad_input_bdim, input_bdim)
    return grad_input
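

# Illustrative shape walkthrough (hypothetical values): suppose backward
# returned a grad_input of shape [3, 4] with grad_input_bdim=None for an
# input that was batched at dim 0 with per-example shape [4].
#
#     gi = torch.randn(3, 4)
#     out = reductify_leaf(gi, None, 0, B,
#                          target_shape_without_bdim_to_reduce_to=(4,))
#     # gi is first expanded to [B, 3, 4], then vmap(sum_to_size) reduces each
#     # per-example gradient to [4], so out has shape [B, 4].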