import torch
from torch._ops import PyOperator
from torch._C._functorch import TransformType
from torch._functorch.utils import enable_single_level_autograd_function
import torch.utils._pytree as pytree
from torch._C._functorch import (
    _wrap_for_grad,
    _unwrap_for_grad,
    current_level,
)
from torch._functorch.vmap import (
    wrap_batched,
    unwrap_batched,
    vmap,
    restore_vmap,
    _add_batch_dim,
)
from torch._functorch.vmap import _broadcast_to_and_flatten
from torch.autograd.forward_ad import _set_fwd_grad_enabled
from typing import Any, NamedTuple, Tuple


# autograd.Function technically runs before the regular PyTorch dispatcher.
# This is how features like autocast and torch_dispatch (e.g. PythonTLSSnapshot)
# work with it. One day we might decide to change this, but until then,
# we need to give the illusion that autograd.Function runs before those things.
#
# We do this by creating a custom PyOperator that only functorch
# dispatches specially.
class CustomFunctionPyOperator(PyOperator):
    def __init__(self):
        super().__init__('custom_function_call')

    def __call__(self, autograd_function, *args, **kwargs):
        # When custom_function_call is done dispatching through functorch,
        # it should just invoke the autograd.Function. This is consistent
        # with the autograd.Function behavior of being invoked before the
        # PyTorch dispatcher.
        #
        # This will lead us into trouble later down the line, but this is
        # pre-existing. There is an invariant that a function traced by
        # make_fx should have the same behavior when provided the same
        # Tensor. However, make_fx sees autograd.Function as a composite
        # (because autograd.Function happens before the Python dispatch key)
        # and only traces the forward pass.
        if torch._C._are_functorch_transforms_active():
            return super().__call__(autograd_function, *args, **kwargs)
        return autograd_function.apply(*args, **kwargs)


# "custom_function_call"
# This is the mechanism for an autograd.Function that works with functorch transforms.
# It wraps an autograd.Function; interactions with functorch transforms are defined
# via PyDispatcher and PyOperator rather than through the traditional PyTorch
# dispatcher.
custom_function_call = CustomFunctionPyOperator()
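
# Illustrative sketch of how calls route through custom_function_call
# (MySin is a hypothetical autograd.Function, not defined in this file):
# with no functorch transforms active it behaves exactly like .apply();
# under a transform, it dispatches to the per-transform rules registered below.
#
#   class MySin(torch.autograd.Function):
#       @staticmethod
#       def forward(x):
#           return x.sin()
#
#       @staticmethod
#       def setup_context(ctx, inputs, output):
#           ctx.save_for_backward(inputs[0])
#
#       @staticmethod
#       def backward(ctx, gy):
#           x, = ctx.saved_tensors
#           return gy * x.cos()
#
#   x = torch.randn([])
#   custom_function_call(MySin, x)       # same as MySin.apply(x) here
#   torch.func.grad(MySin.apply)(x)      # under grad, apply re-routes into custom_function_call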

# The grad rule for custom_function_call is to construct a new _SingleLevelFunction
# (an autograd.Function that only works with a single layer (level) of functorch) that:
# - unwraps the inputs
# - redispatches to custom_function_call
# - wraps the outputs
# and whose backward pass calls the original autograd.Function's backward.
#
# Why do we need to redispatch to custom_function_call?
# -----------------------------------------------------
# This is consistent with how ATen operators work with functorch's grad transform:
# they always redispatch to the original operator.
# Consider torch.sin, and let's say we do grad0(grad1(torch.sin))(x)
#
# grad1 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin (*)
# - rewrap the outputs on the return
#
# On the redispatch in (*), grad0 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin
# - rewrap the outputs on the return
#
# To "set up the autograd graph", we generate a _SingleLevelFunction
# and apply it.
@custom_function_call.py_impl(TransformType.Grad)
@custom_function_call.py_impl(TransformType.Jvp)
def custom_function_call_grad(interpreter, autograd_function, *operands):
    Generated = generate_single_level_function(interpreter, autograd_function)
    with enable_single_level_autograd_function():
        flat_out = Generated.apply(*operands)
    return flat_out
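
# Illustrative sketch (using the hypothetical MySin from above): under
# grad(grad(MySin.apply))(x), this rule fires once per grad level. The outer
# level builds a MySinGenerated node on its own autograd graph, unwraps x by
# one level, and redispatches to custom_function_call; the inner level then
# repeats the same steps before MySin.forward finally runs on plain tensors.
#
#   x = torch.randn([])
#   torch.func.grad(torch.func.grad(MySin.apply))(x)   # == -torch.sin(x)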

def generate_single_level_function(interpreter, autograd_function):
    level = interpreter.level()

    def forward(*operands):
        unwrapped_operands = pytree.tree_map_only(
            torch.Tensor,
            lambda x: _unwrap_for_grad(x, level),
            operands)
        # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
        # the transform. _SingleLevelFunction will turn off both fwd and bwd
        # gradient computation and we need to turn it back on here.
        with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
            unwrapped_output = custom_function_call(autograd_function, *unwrapped_operands)

        # See NOTE [mark_dirty object identity check]
        def wrap_fn(output):
            return _wrap_for_grad(output, level)

        return wrap_outputs_maintaining_identity(
            unwrapped_output,
            unwrapped_operands,
            operands,
            wrap_fn)

    def setup_context(ctx, inputs, output):
        return autograd_function.setup_context(ctx, inputs, output)

    # backward is only used if the transform is TransformType.Grad
    def backward(ctx, *grads):
        result = autograd_function.backward(ctx, *grads)
        return result

    # jvp is only used if the transform is TransformType.Jvp
    def jvp(ctx, *tangents):
        result = autograd_function.jvp(ctx, *tangents)
        return result

    # This is the sequence of magic words to dynamically generate a Subclass with
    # a given name. A Tensor's .grad_fn field has a class name that is the original
    # autograd.Function's name + Backward, so we do this to generate some
    # meaningful name.
    name = f'{autograd_function.__name__}Generated'
    Generated = type(
        name,
        (torch.autograd.function._SingleLevelFunction,),
        {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
            'jvp': staticmethod(jvp),
            'setup_context': staticmethod(setup_context),
        },
    )
    return Generated
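
# For a hypothetical MySin, the dynamically generated class is equivalent to:
#
#   class MySinGenerated(torch.autograd.function._SingleLevelFunction):
#       forward = staticmethod(forward)
#       setup_context = staticmethod(setup_context)
#       backward = staticmethod(backward)
#       jvp = staticmethod(jvp)
#
# so a Tensor produced under grad(MySin.apply) has a grad_fn whose class name
# is "MySinGeneratedBackward".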

# wrap_outputs_maintaining_identity handles outputs from the vmap,
# backward (vjp), and jvp staticmethods. The way it distinguishes
# between the vmap case and the {backward, jvp} case is whether or not
# out_dims is specified.
#
# NB: we cannot use out_dims=None as the deciding factor. This is because
# out_dims=None can still happen in the vmap staticmethod! What the
# user is saying in that case is that their output does not have a
# dimension that is being vmapped over, which is valid.
NO_OUT_DIMS = "not specified"


# NOTE [mark_dirty object identity check]
# autograd.Function's ctx.mark_dirty expects a returned input
# to have the same object identity as the input.
# Mode-only functorch will greatly simplify this logic.
def wrap_outputs_maintaining_identity(
        outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS):
    flat_unwrapped_inputs, _ = pytree.tree_flatten(unwrapped_inputs)
    flat_orig_inputs, _ = pytree.tree_flatten(orig_inputs)

    unwrapped_input_to_orig_input = {
        id(unwrapped): orig
        for unwrapped, orig in zip(flat_unwrapped_inputs, flat_orig_inputs)
    }

    flat_outputs, spec = pytree.tree_flatten(outputs)
    result = []

    out_dims_specified = out_dims != NO_OUT_DIMS

    if out_dims_specified:
        flat_out_dims = _broadcast_to_and_flatten(out_dims, spec)
        # _broadcast_to_and_flatten returns None if it is unable to broadcast.
        # TODO: update following link from master to stable once that's out
        if flat_out_dims is None:
            raise RuntimeError(
                f"The autograd.Function's vmap staticmethod returned an "
                f"incompatible (output, out_dims) tuple. "
                f"Expected out_dims={out_dims} "
                f"to be compatible with the structure of `output`. "
                f"out_dims has structure {pytree.tree_flatten(out_dims)[1]} "
                f"but output has structure {spec}. "
                f"For more details, please see "
                f"https://pytorch.org/docs/master/notes/extending.func.html"
            )

    for i, output in enumerate(flat_outputs):
        if not isinstance(output, torch.Tensor):
            result.append(output)
            continue
        if id(output) in unwrapped_input_to_orig_input:
            result.append(unwrapped_input_to_orig_input[id(output)])
            continue
        if out_dims_specified:
            result.append(wrap_fn(output, flat_out_dims[i]))  # type: ignore[index]
        else:
            result.append(wrap_fn(output))

    return pytree.tree_unflatten(result, spec)
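
# Illustrative sketch of the identity check above (hypothetical MyInplaceSin
# that marks its input dirty; backward omitted for brevity): if the user's
# staticmethod returns one of the unwrapped inputs unchanged, we must hand
# back the *original* wrapped input object rather than a freshly wrapped
# copy, or ctx.mark_dirty would no longer see the same object.
#
#   class MyInplaceSin(torch.autograd.Function):
#       @staticmethod
#       def forward(x):
#           return x.sin_()
#
#       @staticmethod
#       def setup_context(ctx, inputs, output):
#           ctx.mark_dirty(inputs[0])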

# NOTE: [functorch vjp and autograd interaction]
# There's an edge case with the functorch vjp and autograd interaction
# that will eventually be fixed by mode-only functorch.
# The TL;DR is that there's no way to unwrap a dead GradTensorWrapper,
# so we (the framework) need to do it manually. Regular PyTorch operators
# automatically do so, so this is consistent.
#
#   class MyExp(torch.autograd.Function):
#       @staticmethod
#       def forward(x):
#           return x.exp()
#
#       @staticmethod
#       def setup_context(ctx, inputs, output):
#           y = output
#           ctx.save_for_backward(y)
#
#       @staticmethod
#       def backward(ctx, gy):
#           y, = ctx.saved_tensors
#           return MyMul.apply(gy, y)
#
#   x = torch.randn([], requires_grad=True)
#   gy = torch.randn([], requires_grad=True)
#   _, vjp_fn = vjp(MyExp.apply, x)
#   result = vjp_fn(gy)
#
# MyMul is an autograd.Function that is not shown here.
# It saves a `y` for backward (since gy requires grad).
#
# In vjp_fn(gy), we get:
# > MyMul.apply(gy, GradTensorWrapper(y, level=dead))
# because the y that was saved for backward by MyExp is a GradTensorWrapper
# but is now dead since we are outside the vjp context.
#
# PyTorch dispatcher operations, upon seeing a dead GradTensorWrapper,
# will automatically unwrap the GradTensorWrapper when applied.
# But since autograd.Function technically sits above the regular PyTorch
# dispatcher, it doesn't get this treatment. So we manually do
# the unwrapping to be consistent with regular PyTorch dispatcher operations.

class VmapInfo(NamedTuple):
    batch_size: int
    randomness: str


def has_overriden_vmap_rule(autograd_function):
    return autograd_function.vmap is not torch.autograd.Function.vmap


def validate_vmap_returns_tuple_of_two_elements(result):
    base_error_msg = (
        "Expected the vmap staticmethod to have two returns, an output "
        "and out_dims with pytree structure compatible with the output. "
    )
    if not isinstance(result, tuple):
        raise RuntimeError(base_error_msg + f"Got a {type(result)} instead")
    if not len(result) == 2:
        raise RuntimeError(base_error_msg + f"Got {len(result)} returns instead")

@custom_function_call.py_impl(TransformType.Vmap)
def custom_function_call_vmap(interpreter, autograd_function, *operands):
    if autograd_function.generate_vmap_rule:
        if has_overriden_vmap_rule(autograd_function):
            # TODO: Update link to stable once that's out
            # https://github.com/pytorch/pytorch/issues/92029
            raise RuntimeError(
                f"You tried to vmap over {autograd_function.__name__}, but "
                f"it has both generate_vmap_rule=True and an overridden vmap "
                f"staticmethod. Please set generate_vmap_rule=False or delete "
                f"the overridden vmap staticmethod to avoid ambiguity. "
                f"For more details, please see "
                f"https://pytorch.org/docs/master/notes/extending.func.html")
        return custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands)

    if not has_overriden_vmap_rule(autograd_function):
        # TODO: Update link to stable once that's out
        # https://github.com/pytorch/pytorch/issues/92029
        raise RuntimeError(
            f"You tried to vmap over {autograd_function.__name__}, but "
            f"it does not have vmap support. Please override and implement the "
            f"vmap staticmethod or set generate_vmap_rule=True. "
            f"For more details, please see "
            f"https://pytorch.org/docs/master/notes/extending.func.html")

    current_level = interpreter.level()
    info = VmapInfo(
        batch_size=interpreter.batch_size(),
        randomness=interpreter.randomness(),
    )
    unwrapped_operands, in_dims = unwrap_batched(operands, current_level)

    # If none of the tensors are batched at the current level, then we skip the
    # current level. This saves the user from needing to handle this case in
    # their vmap staticmethod (and is consistent with our C++ batching rule API).
    if pytree.tree_all(lambda dim: dim is None, in_dims):
        with interpreter.lower():
            return custom_function_call(autograd_function, *operands)

    with interpreter.lower():
        result = autograd_function.vmap(info, in_dims, *unwrapped_operands)
    validate_vmap_returns_tuple_of_two_elements(result)
    unwrapped_output, out_dims = result

    # See NOTE [mark_dirty object identity check]
    def wrap_fn(output, out_dim):
        return output if out_dim is None else _add_batch_dim(output, out_dim, current_level)

    return wrap_outputs_maintaining_identity(
        unwrapped_output,
        unwrapped_operands,
        operands,
        wrap_fn,
        out_dims=out_dims)
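
# Illustrative sketch of a user-defined vmap staticmethod that the rule above
# dispatches to (hypothetical MyRelu). The staticmethod receives unwrapped
# tensors plus their batch dims and must return an (output, out_dims) pair:
#
#   class MyRelu(torch.autograd.Function):
#       @staticmethod
#       def forward(x):
#           return x.relu()
#
#       @staticmethod
#       def setup_context(ctx, inputs, output):
#           ctx.save_for_backward(output)
#
#       @staticmethod
#       def backward(ctx, gy):
#           y, = ctx.saved_tensors
#           return gy * (y > 0).type_as(gy)
#
#       @staticmethod
#       def vmap(info, in_dims, x):
#           # relu is elementwise, so the output's batch dim is the input's batch dim
#           return MyRelu.apply(x), in_dims[0]
#
#   vmap(MyRelu.apply)(torch.randn(3, 4))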

def custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands):
    unwrapped_operands, in_dims = unwrap_batched(operands, interpreter.level())
    vmapped_function, get_out_dims = vmapify_autograd_function(
        autograd_function, in_dims, interpreter.batch_size(), interpreter.randomness())

    with interpreter.lower():
        output = custom_function_call(vmapped_function, *unwrapped_operands)

    out_dims = get_out_dims()
    return wrap_batched(output, out_dims, interpreter.level())
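
# Illustrative sketch (hypothetical MyExp): with generate_vmap_rule=True the
# user writes no vmap staticmethod at all; the rule above builds a
# "VmappedMyExp" autograd.Function on the fly (see vmapify_autograd_function)
# whose forward/setup_context/backward re-enter vmap over the user's
# staticmethods.
#
#   class MyExp(torch.autograd.Function):
#       generate_vmap_rule = True
#
#       @staticmethod
#       def forward(x):
#           return x.exp()
#
#       @staticmethod
#       def setup_context(ctx, inputs, output):
#           ctx.save_for_backward(output)
#
#       @staticmethod
#       def backward(ctx, gy):
#           y, = ctx.saved_tensors
#           return gy * y
#
#   vmap(MyExp.apply)(torch.randn(3))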

@custom_function_call.py_impl(TransformType.Functionalize)
def custom_function_call_functionalize(interpreter, autograd_function, generate_vmap_rule, *operands):
    raise RuntimeError("NYI: Functionalize rule for custom_function_call")

def vmapify_autograd_function(autograd_function, in_dims, batch_size, randomness):
    # The following values are saved from the forward() and setup_context()
    # and used in backward().
    # Why do we save the values out here instead of on the ctx object?
    # - out_dims: There's no way to retrieve this from forward()
    # - input_shapes, saved_tensors_bdims: I'm a bit scared of nesting
    #   vmap(vmap( but not completely sure if it is a problem. If we
    #   assigned those fields to the ctx object, the worry is that they
    #   get overwritten.
    out_dims = "not populated"
    input_shapes: Any = "not populated"
    saved_tensors_bdims: Any = "not populated"

    def forward(*operands):
        nonlocal out_dims
        outputs, out_dims = restore_vmap(
            autograd_function.forward, in_dims, batch_size, randomness)(*operands)
        return outputs

    def setup_context(ctx, inputs, outputs):
        input_shapes_ = None
        saved_tensors_bdims_ = None

        def inner(inputs, outputs):
            # wrapped_ctx.save_for_backward will:
            # - unwrap batchedtensors into (tensor, bdim)
            # - save_for_backward(*unwrapped_tensors)
            # - assign the bdims to wrapped_ctx._pt_saved_tensors_bdims
            wrapped_ctx = CtxCustomSave(ctx, current_level())
            autograd_function.setup_context(wrapped_ctx, inputs, outputs)

            # input_shapes are used for reductify later to reduce expanded gradients
            # to the correct shape.
            # See NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
            # for more details
            nonlocal input_shapes_
            input_shapes_ = tuple(inp.shape if isinstance(inp, torch.Tensor) else None
                                  for inp in inputs)
            nonlocal saved_tensors_bdims_
            saved_tensors_bdims_ = wrapped_ctx._pt_saved_tensors_bdims

        # See NOTE: [Why do we need to run setup_context under a vmap?]
        restore_vmap(
            inner,
            (in_dims, out_dims),
            batch_size,
            randomness,
        )(inputs, outputs)

        nonlocal input_shapes
        input_shapes = input_shapes_
        nonlocal saved_tensors_bdims
        saved_tensors_bdims = saved_tensors_bdims_

    def jvp(ctx, *tangents):
        assert out_dims != "not populated"
        assert saved_tensors_bdims != "not populated"

        def jvp_no_context(saved_tensors, tangents):
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.jvp(wrapped_ctx, *tangents)

        tangent_in_dims = get_tangents_in_dims(in_dims, tangents)
        out_tangents, out_tangents_dims = restore_vmap(
            jvp_no_context, (saved_tensors_bdims, tangent_in_dims), batch_size, randomness)(
                ctx.saved_tensors, tangents)

        result = reductify(out_tangents, out_tangents_dims, out_dims, batch_size)
        return result

    def backward(ctx, *grad_outputs):
        assert out_dims != "not populated"
        assert input_shapes != "not populated"
        assert saved_tensors_bdims != "not populated"

        def backward_no_context(inputs):
            saved_tensors, grad_outputs = inputs
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.backward(wrapped_ctx, *grad_outputs)

        grad_ins, grad_ins_dims = restore_vmap(
            backward_no_context, ((saved_tensors_bdims, out_dims),), batch_size, randomness)(
                (ctx.saved_tensors, grad_outputs))
        result = reductify(grad_ins, grad_ins_dims, in_dims, batch_size, input_shapes)
        return result

    name = f'Vmapped{autograd_function.__name__}'
    Generated = type(
        name,
        (torch.autograd.Function,),
        {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
            'jvp': staticmethod(jvp),
            'setup_context': staticmethod(setup_context),
            'generate_vmap_rule': True
        }
    )

    def get_out_dims():
        assert out_dims != "not populated"
        return out_dims

    return Generated, get_out_dims


# tangents might be None, so we need to replace
# the corresponding in_dims with None.
def get_tangents_in_dims(input_dims, tangents):
    flat_in_dims, spec = pytree.tree_flatten(input_dims)
    flat_tangents, _ = pytree.tree_flatten(tangents)
    result = [None if tangent is None else in_dim
              for in_dim, tangent in zip(flat_in_dims, flat_tangents)]
    return pytree.tree_unflatten(result, spec)
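
# For example (values are illustrative): if in_dims is (0, 0) but the second
# input is non-differentiable so its tangent is None, then
# get_tangents_in_dims((0, 0), (t, None)) returns (0, None), so restore_vmap
# in jvp() above does not try to treat the missing tangent as batched.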

# NOTE: [Why do we need to run setup_context under a vmap?]
# Consider the following autograd.Function
#
#   class Sum(torch.autograd.Function):
#       @staticmethod
#       def forward(x):
#           return x.sum()
#
#       @staticmethod
#       def setup_context(ctx, inputs, outputs):
#           ctx.x_shape = inputs[0].shape
#
#       @staticmethod
#       def backward(ctx, gy):
#           return gy.expand(ctx.x_shape)
#
#   x = torch.randn(B, 4)
#   in_dims = 0
#   vmap(Sum.apply, in_dims)(x)
#
# Let's assume for a moment that we didn't vmap setup_context in VmappedSum:
#
#   class VmappedSum(torch.autograd.Function):
#       @staticmethod
#       def forward(x):
#           return vmap(Sum.forward, in_dims)(x)
#
#       @staticmethod
#       def setup_context(ctx, inputs, outputs):
#           Sum.setup_context(ctx, inputs, outputs)
#
#       @staticmethod
#       def backward(ctx, gy):
#           def backward_no_context(gy):
#               return gy.expand(ctx.x_shape)
#
#           dims = (0,)
#           gx = vmap(backward_no_context, dims)(gy)
#           return gx
#
# We end up saving [B, 4] as x_shape. In the backward, gy has shape [B],
# and we're doing:
#
#   def backward_no_context(gy):
#       return gy.expand([B, 4])
#
#   gx = vmap(backward_no_context, dims)(gy: "Tensor[B]")
#
# This gives us the wrong result (gx has shape [B, B, 4], but it should
# have shape [4]). Performing vmap over setup_context means the shape
# saved has shape [4] and leads to a correct result shape for gx.

# Wraps a ctx object. Forwards all attr accesses to the underlying object
# except for the attrs in _pt_reserved_attrs.
class WrappedCtx:
    _pt_reserved_attrs: Tuple[str, ...] = ('_pt_reserved_attrs', '_pt_inner_ctx')

    def __init__(self, ctx):
        if not isinstance(ctx, WrappedCtx):
            reserved_attrs = type(self)._pt_reserved_attrs
            for name in reserved_attrs:
                if not hasattr(ctx, name):
                    continue
                raise RuntimeError(
                    f'PyTorch reserves the {reserved_attrs} field on ctx. '
                    'Please name your fields on ctx something else to avoid name '
                    'collision.')
        self._pt_inner_ctx = ctx

    def __getattr__(self, name):
        return getattr(self._pt_inner_ctx, name)

    def __setattr__(self, name, value):
        if name in type(self)._pt_reserved_attrs:
            self.__dict__[name] = value
            return
        return setattr(self._pt_inner_ctx, name, value)

# Wraps ctx to create a new ctx object that overrides saved_tensors.
class CtxWithSavedTensors(WrappedCtx):
    _pt_reserved_attrs = ('_pt_new_saved_tensors', *WrappedCtx._pt_reserved_attrs)

    def __init__(self, ctx, new_saved_tensors):
        super().__init__(ctx)
        self._pt_new_saved_tensors = new_saved_tensors

    @property
    def saved_tensors(self):
        return self._pt_new_saved_tensors
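
# Illustrative sketch: inside the vmapped backward/jvp above, the real ctx
# holds unwrapped saved tensors, so we hand the user's staticmethod a view of
# ctx whose saved_tensors are the per-sample slices produced under
# restore_vmap (per_sample_y and some_user_field are hypothetical names):
#
#   wrapped = CtxWithSavedTensors(ctx, (per_sample_y,))
#   wrapped.saved_tensors      # -> (per_sample_y,)
#   wrapped.some_user_field    # -> forwarded to ctx.some_user_field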

class CtxCustomSave(WrappedCtx):
    _pt_reserved_attrs = ('_pt_saved_tensors_bdims', '_pt_current_level',
                          *WrappedCtx._pt_reserved_attrs)

    def __init__(self, ctx, current_level):
        super().__init__(ctx)
        self._pt_saved_tensors_bdims = ()
        self._pt_current_level = current_level

    def save_for_backward(self, *tensors):
        unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_backward(*unwrapped_tensors)
        self._pt_saved_tensors_bdims = bdims

    def save_for_forward(self, *tensors):
        unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_forward(*unwrapped_tensors)
        self._pt_saved_tensors_bdims = bdims
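
# Illustrative sketch: when the user's setup_context calls
# ctx.save_for_backward(y) under vmap, y is a BatchedTensor. CtxCustomSave
# intercepts the call, saves the plain tensor on the real ctx, and records
# where y's batch dim was so backward()/jvp() can re-batch it later
# (batched_y is a hypothetical name):
#
#   wrapped_ctx = CtxCustomSave(ctx, current_level())
#   wrapped_ctx.save_for_backward(batched_y)
#   wrapped_ctx._pt_saved_tensors_bdims    # e.g. (0,)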

def reductify(grad_input, grad_input_bdim, input_bdim, batch_size,
              target_shape_without_bdim_to_reduce_to=None):
    if not isinstance(grad_input, tuple):
        grad_input = (grad_input,)
    if not isinstance(grad_input_bdim, tuple):
        grad_input_bdim = (grad_input_bdim,)
    if not isinstance(input_bdim, tuple):
        input_bdim = (input_bdim,)

    if target_shape_without_bdim_to_reduce_to is None:
        target_shape_without_bdim_to_reduce_to = len(grad_input) * (None,)
    result = tuple(
        reductify_leaf(gi, gi_bdim, i_bdim, batch_size, maybe_ishape)
        for gi, gi_bdim, i_bdim, maybe_ishape in
        zip(grad_input, grad_input_bdim, input_bdim, target_shape_without_bdim_to_reduce_to)
    )
    return result


def reductify_leaf(grad_input, grad_input_bdim, input_bdim, batch_size,
                   target_shape_without_bdim_to_reduce_to=None):
    if grad_input is None:
        return None

    if grad_input_bdim is None and input_bdim is None:
        return grad_input

    if grad_input_bdim is not None and input_bdim is None:
        return grad_input.sum(grad_input_bdim)

    # NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
    # For reverse-mode AD,
    # given a grad_input and input, it is valid for the user to return a
    # grad_input that has a broadcasted shape when compared to the input.
    # In this situation, autograd automatically reduces the grad_input to
    # the shape of the input.
    #
    # However, when input_bdim is not None, we have problems.
    #
    # [example 1]
    # grad_input: Tensor[3, 4], input: Tensor[B, 4]
    # We can expand grad_input to Tensor[B, 3, 4], but that isn't broadcastable
    # from [B, 4].
    #
    # [example 2]
    # grad_input: Tensor[3, B, 4], input: Tensor[B, 4]
    # We can swizzle grad_input to Tensor[B, 3, 4], but that isn't broadcastable
    # from [B, 4].
    #
    # This means that we need to also reduce the grad_input to the shape of the
    # input. This behavior is controlled by the `target_shape_without_bdim_to_reduce_to`
    # argument: if it is not None, we do the reduction manually; otherwise, we do not reduce.
    assert input_bdim is not None

    if grad_input_bdim is None:
        grad_input = grad_input.unsqueeze(input_bdim)
        new_shape = list(grad_input.shape)
        new_shape[input_bdim] = batch_size
        grad_input = grad_input.expand(new_shape)
        grad_input_bdim = input_bdim

    if target_shape_without_bdim_to_reduce_to is not None:
        return vmap(torch.Tensor.sum_to_size, in_dims=(grad_input_bdim, None), out_dims=input_bdim)(
            grad_input, target_shape_without_bdim_to_reduce_to)

    if input_bdim != grad_input_bdim:
        grad_input = grad_input.movedim(grad_input_bdim, input_bdim)
    return grad_input
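
# Worked example (values are illustrative): take [example 1] from the NOTE
# above with B = 2, input_bdim = 0, and
# target_shape_without_bdim_to_reduce_to = torch.Size([4]):
#
#   grad_input = torch.ones(3, 4)   # grad_input_bdim is None
#   # reductify_leaf first materializes the batch dim:
#   #   unsqueeze(0) -> [1, 3, 4], expand -> [2, 3, 4], grad_input_bdim = 0
#   # then vmaps sum_to_size over the batch dim:
#   #   each [3, 4] slice is reduced to [4], giving a [2, 4] result
#   out = reductify_leaf(grad_input, None, 0, 2, torch.Size([4]))
#   out.shape   # torch.Size([2, 4])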