import inspect
from typing import Callable, Dict, List, Optional, Tuple

import torch
import torch._decomp
from torch import Tensor

decomposition_table = torch._decomp.decomposition_table
decomposition_table_for_jvp: Dict[torch._ops.OpOverload, Callable] = {}
register_decomposition = torch._decomp.register_decomposition
aten = torch.ops.aten

# NOTE: [forward-mode AD decompositions mechanism]
#
# The mechanism lives in VariableType:
#   IF any inputs have forward grad
#     AND there is no forward AD formula implemented
#     AND the function is actually differentiable
#   THEN run the decomposition
#     (see run_jit_decomposition_with_args_for_jvp).
# We currently use Python decompositions that we torchscript.
#
# Note that we would be building the backward graph at the decomposed level
# too, but that is OK, because we would've errored out otherwise anyway.
#
# TODO: The mechanism we are using to register decompositions doesn't seem to
# be used exclusively for jvp, so an open question is whether
# torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other
# things. If that is the case, we may go down the decomposition path
# unexpectedly (and possibly produce an unintelligible error) instead of
# erroring out earlier and reporting that the forward AD formula is not
# implemented.
#
# A solution may be an explicit allowlist controlling when to enable the
# decomposition.
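#
# A minimal usage sketch of the path that exercises this mechanism
# (illustrative only; it assumes aten.trace has no dedicated forward AD
# formula, so VariableType falls back to the scripted decomposition
# registered below):
#
#     import torch.autograd.forward_ad as fwAD
#
#     with fwAD.dual_level():
#         x = fwAD.make_dual(torch.randn(3, 3), torch.randn(3, 3))
#         out = torch.trace(x)  # decomposed to sum(diag(x)) for the jvp
#         tangent = fwAD.unpack_dual(out).tangent
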
def maybe_register_decomposition(op):
    # Registration raises if a decomposition for `op` already exists; in that
    # case, fall back to the plain (unregistered) Python function.
    def decorator(f):
        try:
            return register_decomposition(op)(f)
        except Exception:
            return f

    return decorator

# Functions where we need a special decomposition for jvp but there is another
# version that should be used more generally (e.g., for jvp we need to
# recompute the mean and variance for the backward of a normalization
# function; without jvp, it should use the saved values). These live in
# decomposition_table_for_jvp, initialized above.

def register_decomposition_for_jvp(fn):
    return register_decomposition(fn, registry=decomposition_table_for_jvp)

def _register_jit_decomposition_for_jvp(decomp, use_python=False):
    if decomp in decomposition_table_for_jvp:
        decomposition_table_used = decomposition_table_for_jvp
    elif decomp in decomposition_table:
        decomposition_table_used = decomposition_table
    else:
        raise RuntimeError(f"could not find decomposition for {decomp}")
    decomp_fn = decomposition_table_used[decomp]
    if use_python:
        decomp_fn = torch.jit.ignore(decomp_fn)
        sig = inspect.signature(decomp_fn)

        # Create a string wrapping the function from the signature.
        # Example output:
        #     def wrapped_decomp(x: torch.Tensor, y: int, z: int):
        #         return decomp_fn(x, y, z)
        # Thanks copilot!
        def get_function_def(sig):
            param_def = [f"{param_str}" for param_str in sig.parameters.values()]
            param_use = [f"{param_str}" for param_str in sig.parameters.keys()]

            return f"def wrapped_decomp({', '.join(param_def)}):\n return decomp_fn({', '.join(param_use)})\n"

        f_str = get_function_def(sig)
        graph = torch.jit.CompilationUnit(f_str).wrapped_decomp.graph
    else:
        graph = torch.jit.script(decomp_fn).graph
    torch.jit._register_decomposition(decomp, graph)

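# For reference, a sketch of what get_function_def produces for a
# decomposition with signature (self: Tensor) -- e.g. log_sigmoid_forward
# below -- before torch.jit.CompilationUnit compiles it:
#
#     def wrapped_decomp(self: Tensor):
#      return decomp_fn(self)
#
# torch.jit.ignore keeps the Python body opaque to the scripting pass, so the
# call into decomp_fn stays a Python call inside the compiled wrapper.
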
# The only decompositions here are temporary or hacks for the purposes of jvp.
# TODO: do these also belong here?

@maybe_register_decomposition(aten.trace.default)
def trace(self: Tensor) -> Tensor:
    return torch.sum(torch.diag(self))

@maybe_register_decomposition(aten.log_sigmoid_forward.default)
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
    # Numerically stable identity: logsigmoid(x) = min(0, x) - log1p(exp(-|x|))
    min = torch.minimum(self.new_zeros(()), self)
    z = torch.exp(-torch.abs(self))
    if self.is_cuda:
        # The CUDA path does not consume the buffer; return an empty placeholder
        buffer = self.new_zeros((0,))
    else:
        buffer = z
    return min - torch.log1p(z), buffer
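

# Informal sanity check for the decomposition above (a sketch, not a test;
# uses only the public torch.nn.functional.logsigmoid):
#
#     x = torch.randn(4)
#     expected = torch.nn.functional.logsigmoid(x)
#     actual, _ = log_sigmoid_forward(x)
#     torch.testing.assert_close(actual, expected)
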
def recompute_mean_var(
    input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool
):
    # Most norm decompositions are the same as the core version except here:
    # we recompute the mean and variance so that they track gradients through
    # input, and recover eps from the saved rstd so the recomputed rstd
    # matches it numerically.
    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
    eps = eps.detach()
    rstd = 1 / torch.sqrt(var + eps)
    return mean, rstd
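

# Sketch of the eps trick above (illustrative shapes): the recomputed rstd
# matches the saved one numerically, but now requires grad through input.
#
#     x = torch.randn(2, 5, requires_grad=True)
#     var = torch.var(x, dim=1, unbiased=False, keepdim=True)
#     rstd = 1 / torch.sqrt(var.detach() + 1e-5)  # what a forward would save
#     mean2, rstd2 = recompute_mean_var(x, rstd, [1], keepdim=True)
#     torch.testing.assert_close(rstd2, rstd)  # same values
#     assert rstd2.requires_grad  # but gradients now flow through x
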
@register_decomposition_for_jvp(aten.native_layer_norm_backward)
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_ndim = input.dim()

    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices = list(range(axis, input_ndim))
    outer_dim_indices = list(range(0, axis))

    N = 1
    for i in inner_dims:
        N *= i
    M = 1
    for i in outer_dims:
        M *= i
    if M <= 0 or N <= 0:
        return (
            input.new_zeros(input_shape),
            input.new_zeros(input_shape[axis:]),
            input.new_zeros(input_shape[axis:]),
        )

    mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)

    x_hat = (input - mean_) * rstd_
    if weight is not None:
        grad_x_hat = grad_out * weight
    else:
        grad_x_hat = grad_out
    a = grad_x_hat * N
    b = torch.sum(grad_x_hat, inner_dim_indices, True)
    c1 = torch.mul(grad_x_hat, x_hat)
    c2 = torch.sum(c1, inner_dim_indices, True)
    c3 = torch.mul(x_hat, c2)
    inner = a - b - c3

    if output_mask[0]:
        d_input: Optional[Tensor] = (rstd_ / N) * inner
    else:
        d_input = torch.zeros_like(input)  # should be None but doesn't work with vjp

    if output_mask[1] and weight is not None:
        if len(outer_dim_indices) > 0:
            d_weight: Optional[Tensor] = torch.sum(
                grad_out * x_hat, outer_dim_indices, False
            )
        else:
            d_weight = grad_out * x_hat
    elif weight is not None:
        d_weight = torch.zeros_like(weight)  # should be None but doesn't work with vjp
    else:
        d_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2] and bias is not None:
        if len(outer_dim_indices) > 0:
            d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
        else:
            d_bias = grad_out.clone()
    elif bias is not None:
        d_bias = torch.zeros_like(bias)  # should be None but doesn't work with vjp
    else:
        d_bias = torch.zeros(())  # should be None but doesn't work with vjp

    return (d_input, d_weight, d_bias)
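

# Cross-check sketch against the aten kernel (informal; not part of the
# registration machinery):
#
#     x = torch.randn(3, 4)
#     w, b = torch.randn(4), torch.randn(4)
#     out, mean, rstd = torch.ops.aten.native_layer_norm(x, [4], w, b, 1e-5)
#     go = torch.randn_like(out)
#     mask = [True, True, True]
#     ref = torch.ops.aten.native_layer_norm_backward(
#         go, x, [4], mean, rstd, w, b, mask
#     )
#     ours = native_layer_norm_backward(go, x, [4], mean, rstd, w, b, mask)
#     for r, o in zip(ref, ours):
#         torch.testing.assert_close(r, o)
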
def prod(x: List[int]):
    r = 1
    for i in x:
        r *= i
    return r

@register_decomposition_for_jvp(aten.native_batch_norm_backward)
def native_batch_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    weight: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_invstd: Optional[Tensor],
    train: bool,
    eps: float,
    output_mask: List[bool],
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_rank = input.dim()
    assert input_rank >= 2, "rank of the input must be at least 2"

    axis = 1
    num_features = prod(input_shape) / input_shape[axis]  # type: ignore[arg-type]
    mean = save_mean
    invstd = save_invstd
    if train:
        assert (
            save_mean is not None and save_invstd is not None
        ), "when train=True, save_mean and save_invstd are required"

        reduction_dims = [0] + list(range(2, input.dim()))
        assert invstd is not None  # for typing
        mean, invstd = recompute_mean_var(input, invstd, reduction_dims, keepdim=False)
    else:
        assert running_mean is not None and running_var is not None
        mean = running_mean
        invstd = torch.rsqrt(running_var + eps)

    assert invstd is not None and mean is not None

    broadcast_mask = [1] * input_rank
    broadcast_mask[axis] = input_shape[axis]

    reduction_axes: List[int] = []
    for i in range(input_rank):
        if i != axis:
            reduction_axes.append(i)

    mean = torch.reshape(mean, broadcast_mask)
    norm = 1.0 / num_features
    grad_output_sum = torch.sum(grad_out, reduction_axes)
    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)

    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)

    if weight is None:
        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
    else:
        grad_scale = torch.reshape(invstd * weight, broadcast_mask)

    if train:
        proj = (input - mean) * proj_scale
        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
    else:
        grad_input = grad_out * grad_scale

    if output_mask[1]:
        grad_weight = dot_p * invstd
    elif weight is not None:
        grad_weight = torch.zeros_like(
            weight
        )  # should be None but doesn't work with vjp
    else:
        grad_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2]:
        grad_bias = grad_output_sum
    else:
        grad_bias = torch.zeros_like(
            grad_output_sum
        )  # should be None but doesn't work with vjp

    return (grad_input, grad_weight, grad_bias)
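

# Analogous cross-check sketch for the train=True path, which is the one that
# needs recompute_mean_var (informal):
#
#     x = torch.randn(2, 3, 4)
#     w = torch.randn(3)
#     out, save_mean, save_invstd = torch.ops.aten.native_batch_norm(
#         x, w, None, None, None, True, 0.1, 1e-5
#     )
#     go = torch.randn_like(out)
#     grads = native_batch_norm_backward(
#         go, x, w, None, None, save_mean, save_invstd, True, 1e-5,
#         [True, True, True]
#     )
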
_register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten._log_softmax_backward_data.default)
_register_jit_decomposition_for_jvp(torch.ops.aten._softmax_backward_data.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default)
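
# End-to-end sketch of when these registrations fire: forward-over-reverse,
# where a backward op (e.g. _log_softmax_backward_data, reached here via
# nll_loss) sees dual tensors but has no forward AD formula, so its scripted
# decomposition runs instead (hedged illustration, not a guaranteed recipe):
#
#     import torch.autograd.forward_ad as fwAD
#     import torch.nn.functional as F
#
#     x = torch.randn(3, 5, requires_grad=True)
#     target = torch.tensor([0, 1, 2])
#     with fwAD.dual_level():
#         xd = fwAD.make_dual(x, torch.randn_like(x))
#         loss = F.nll_loss(F.log_softmax(xd, dim=1), target)
#         (g,) = torch.autograd.grad(loss, x, create_graph=True)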