import inspect
from typing import Callable, Dict, List, Optional, Tuple

import torch
import torch._decomp
from torch import Tensor

decomposition_table = torch._decomp.decomposition_table
decomposition_table_for_jvp: Dict[torch._ops.OpOverload, Callable] = {}
register_decomposition = torch._decomp.register_decomposition
aten = torch.ops.aten


# NOTE: [forward-mode AD decompositions mechanism]
#
# The mechanism is in VariableType,
#   IF any inputs have forward grad
#      AND there is no forward AD formula implemented
#      AND the function is actually differentiable
#   run the decomposition
#      See run_jit_decomposition_with_args_for_jvp
#      We currently use python decompositions that we torchscript.
#
# Note that we would be building the backward graph at the decomposed level
# too, but that is OK, because we would've errored out otherwise anyway.
#
# TODO: The mechanism we are using to register decompositions doesn't
# seem to be exclusively used for jvp. So an open question here is whether
# torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things.
# If that is the case, we may go down the decomposition path unexpectedly
# (and possibly produce an unintelligible error) vs erroring out earlier and
# printing that the forward AD formula is not implemented.
#
# The solution to this may be to have an explicit whitelist that controls when
# to enable the decomposition.


def maybe_register_decomposition(op):
    def decorator(f):
        try:
            return register_decomposition(op)(f)
        except Exception:
            return f

    return decorator


# Functions where we need a special decomposition for jvp but there's another version that
# should be used more generally (ex. for jvp we need to recompute the mean and variance for
# the backwards of a normalization function. Without jvp, it should use the saved value)
decomposition_table_for_jvp = {}


def register_decomposition_for_jvp(fn):
    return register_decomposition(fn, registry=decomposition_table_for_jvp)


def _register_jit_decomposition_for_jvp(decomp, use_python=False):
    if decomp in decomposition_table_for_jvp:
        decomposition_table_used = decomposition_table_for_jvp
    elif decomp in decomposition_table:
        decomposition_table_used = decomposition_table
    else:
        raise RuntimeError(f"could not find decomposition for {decomp}")
    decomp_fn = decomposition_table_used[decomp]
    if use_python:
        decomp_fn = torch.jit.ignore(decomp_fn)
        sig = inspect.signature(decomp_fn)

        # Create a string wrapping the function from the signature
        # example output:
        # def wrapped_decomp(x: torch.Tensor, y: int, z: int):
        #   return decomp_fn(x, y, z)
        # Thanks copilot!
        def get_function_def(sig):
            param_def = [f"{param_str}" for param_str in sig.parameters.values()]
            param_use = [f"{param_str}" for param_str in sig.parameters.keys()]

            return f"def wrapped_decomp({', '.join(param_def)}):\n return decomp_fn({', '.join(param_use)})\n"

        f_str = get_function_def(sig)
        graph = torch.jit.CompilationUnit(f_str).wrapped_decomp.graph
    else:
        graph = torch.jit.script(decomp_fn).graph
    torch.jit._register_decomposition(decomp, graph)
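

# Hedged illustration (not part of the original module; the helper name below is
# made up for this sketch): the mechanism described in the NOTE above can be
# exercised from the user side with forward-mode AD on an op that, at the time of
# writing, has no dedicated forward AD formula, e.g. aten.trace. In that case the
# scripted decomposition registered at the bottom of this file runs instead. The
# check holds either way, because trace is linear: its jvp along direction t is
# simply trace(t). The helper is never called on import.
def _example_trace_jvp_via_decomposition() -> None:
    import torch.autograd.forward_ad as fwAD

    x = torch.randn(3, 3)
    t = torch.randn(3, 3)
    with fwAD.dual_level():
        dual = fwAD.make_dual(x, t)
        out = torch.trace(dual)
        _, tangent = fwAD.unpack_dual(out)
        assert torch.allclose(tangent, torch.trace(t))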


# The only decompositions here are temporary or hacks for the purposes of jvp

# TODO: do these also belong here?
@maybe_register_decomposition(aten.trace.default)
def trace(self: Tensor) -> Tensor:
    return torch.sum(torch.diag(self))


@maybe_register_decomposition(aten.log_sigmoid_forward.default)
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
    min = torch.minimum(self.new_zeros(()), self)
    z = torch.exp(-torch.abs(self))
    if self.is_cuda:
        buffer = self.new_zeros((0,))
    else:
        buffer = z
    return min - torch.log1p(z), buffer


def recompute_mean_var(
    input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool
):
    # for most norm decompositions, it will be the same as the core version except for here.
    # We recompute the mean and variance so that they track gradients through input
    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
    eps = eps.detach()
    rstd = 1 / torch.sqrt(var + eps)
    return mean, rstd


@register_decomposition_for_jvp(aten.native_layer_norm_backward)
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_ndim = input.dim()

    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices = list(range(axis, input_ndim))
    outer_dim_indices = list(range(0, axis))

    N = 1
    for i in inner_dims:
        N *= i
    M = 1
    for i in outer_dims:
        M *= i
    if M <= 0 or N <= 0:
        return (
            input.new_zeros(input_shape),
            input.new_zeros(input_shape[axis:]),
            input.new_zeros(input_shape[axis:]),
        )

    mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)

    x_hat = (input - mean_) * rstd_
    if weight is not None:
        grad_x_hat = grad_out * weight
    else:
        grad_x_hat = grad_out
    a = grad_x_hat * N
    b = torch.sum(grad_x_hat, inner_dim_indices, True)
    c1 = torch.mul(grad_x_hat, x_hat)
    c2 = torch.sum(c1, inner_dim_indices, True)
    c3 = torch.mul(x_hat, c2)
    inner = a - b - c3

    if output_mask[0]:
        d_input: Optional[Tensor] = (rstd_ / N) * inner
    else:
        d_input = torch.zeros_like(input)  # should be None but doesn't work with vjp

    if output_mask[1] and weight is not None:
        if len(outer_dim_indices) > 0:
            d_weight: Optional[Tensor] = torch.sum(
                grad_out * x_hat, outer_dim_indices, False
            )
        else:
            d_weight = grad_out * x_hat
    elif weight is not None:
        d_weight = torch.zeros_like(weight)  # should be None but doesn't work with vjp
    else:
        d_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2] and bias is not None:
        if len(outer_dim_indices) > 0:
            d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
        else:
            d_bias = grad_out.clone()
    elif bias is not None:
        d_bias = torch.zeros_like(bias)  # should be None but doesn't work with vjp
    else:
        d_bias = torch.zeros(())  # should be None but doesn't work with vjp

    return (d_input, d_weight, d_bias)


def prod(x: List[int]):
    r = 1
    for i in x:
        r *= i
    return r
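

# Hedged sketch (not part of the original module; the helper name is made up):
# shows why recompute_mean_var exists. Given a saved rstd that does not require
# grad, it reproduces the same statistics numerically while reconnecting them to
# the autograd graph of `input`, which is what the jvp-specific norm backward
# decompositions in this file rely on. Never called on import.
def _check_recompute_mean_var_sketch() -> None:
    x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
    var, mean = torch.var_mean(x, dim=[1], unbiased=False, keepdim=True)
    saved_rstd = torch.rsqrt(var + 1e-5).detach()

    mean_, rstd_ = recompute_mean_var(x, saved_rstd, [1], keepdim=True)

    # numerically the same statistics ...
    assert torch.allclose(mean_, mean.detach())
    assert torch.allclose(rstd_, saved_rstd)
    # ... but the recomputed ones track gradients through x
    assert mean_.requires_grad and rstd_.requires_grad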


@register_decomposition_for_jvp(aten.native_batch_norm_backward)
def native_batch_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    weight: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_invstd: Optional[Tensor],
    train: bool,
    eps: float,
    output_mask: List[bool],
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_rank = input.dim()
    assert input_rank >= 2, "rank of the input must be at least 2"

    axis = 1
    num_features = prod(input_shape) / input_shape[axis]  # type: ignore[arg-type]
    mean = save_mean
    invstd = save_invstd
    if train:
        assert (
            save_mean is not None and save_invstd is not None
        ), "when train=True, save_mean and save_invstd are required"

        reduction_dims = [0] + list(range(2, input.dim()))
        assert invstd is not None  # for typing
        mean, invstd = recompute_mean_var(input, invstd, reduction_dims, keepdim=False)
    else:
        assert running_mean is not None and running_var is not None
        mean = running_mean
        invstd = torch.rsqrt(running_var + eps)

    assert invstd is not None and mean is not None

    broadcast_mask = [1] * input_rank
    broadcast_mask[axis] = input_shape[axis]

    reduction_axes: List[int] = []
    for i in range(input_rank):
        if i != axis:
            reduction_axes.append(i)

    mean = torch.reshape(mean, broadcast_mask)
    norm = 1.0 / num_features
    grad_output_sum = torch.sum(grad_out, reduction_axes)
    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)

    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)

    if weight is None:
        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
    else:
        grad_scale = torch.reshape(invstd * weight, broadcast_mask)

    if train:
        proj = (input - mean) * proj_scale
        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
    else:
        grad_input = grad_out * grad_scale

    if output_mask[1]:
        grad_weight = dot_p * invstd
    elif weight is not None:
        grad_weight = torch.zeros_like(
            weight
        )  # should be None but doesn't work with vjp
    else:
        grad_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2]:
        grad_bias = grad_output_sum
    else:
        grad_bias = torch.zeros_like(
            grad_output_sum
        )  # should be None but doesn't work with vjp

    return (grad_input, grad_weight, grad_bias)


_register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten._log_softmax_backward_data.default)
_register_jit_decomposition_for_jvp(torch.ops.aten._softmax_backward_data.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default)
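

# Hedged, optional smoke test (not part of the original module): running this
# file directly executes the illustrative sketches above and compares the
# jvp-specific native_layer_norm_backward with the ATen kernel on a small
# double-precision example. The two are expected to agree numerically; the
# intended difference is only that mean/rstd are recomputed so they track
# gradients through `input`. Guarded so nothing runs on import.
if __name__ == "__main__":
    _example_trace_jvp_via_decomposition()
    _check_recompute_mean_var_sketch()

    x = torch.randn(4, 8, dtype=torch.double)
    w = torch.randn(8, dtype=torch.double)
    b = torch.randn(8, dtype=torch.double)
    go = torch.randn(4, 8, dtype=torch.double)
    _, mean, rstd = torch.ops.aten.native_layer_norm(x, [8], w, b, 1e-5)
    ref = torch.ops.aten.native_layer_norm_backward(
        go, x, [8], mean, rstd, w, b, [True, True, True]
    )
    got = native_layer_norm_backward(go, x, [8], mean, rstd, w, b, [True, True, True])
    for r, g in zip(ref, got):
        assert r is not None and g is not None
        assert torch.allclose(r, g, atol=1e-6)
    print("decompositions_for_jvp sketches passed")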