# decompositions_for_jvp.py

import inspect
from typing import Callable, Dict, List, Optional, Tuple

import torch
import torch._decomp
from torch import Tensor

decomposition_table = torch._decomp.decomposition_table
decomposition_table_for_jvp: Dict[torch._ops.OpOverload, Callable] = {}
register_decomposition = torch._decomp.register_decomposition
aten = torch.ops.aten

# NOTE: [forward-mode AD decompositions mechanism]
#
# The mechanism lives in VariableType:
#   IF any inputs have forward grad
#   AND there is no forward AD formula implemented
#   AND the function is actually differentiable,
#   run the decomposition.
# See run_jit_decomposition_with_args_for_jvp.
# We currently use Python decompositions that we torchscript.
#
# Note that we would be building the backward graph at the decomposed level
# too, but that is OK, because we would've errored out otherwise anyway.
#
# TODO: The mechanism we are using to register decompositions doesn't
# seem to be exclusively used for jvp. So the open question here is whether
# torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things.
# If that is the case, we may go down the decomposition path unexpectedly
# (and possibly produce an unintelligible error) vs. erroring out earlier and
# printing that the forward AD formula is not implemented.
#
# The solution to this may be to have an explicit allowlist controlling when
# to enable the decomposition.
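

# A minimal sketch of how the mechanism above is exercised (illustrative only;
# this helper is not part of the original module and is never called here).
# It uses only the public torch.autograd.forward_ad API. Running forward-mode
# AD through an op without a native forward AD formula, such as aten.trace,
# should fall back to the decomposition registered further down in this file.
def _example_jvp_through_trace():
    import torch.autograd.forward_ad as fwAD

    primal = torch.randn(3, 3)
    tangent = torch.randn(3, 3)
    with fwAD.dual_level():
        dual = fwAD.make_dual(primal, tangent)
        out = torch.trace(dual)
        _, jvp_out = fwAD.unpack_dual(out)
    # For trace, the jvp is simply the trace of the tangent.
    return jvp_out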


def maybe_register_decomposition(op):
    def decorator(f):
        try:
            return register_decomposition(op)(f)
        except Exception:
            return f

    return decorator


# Functions where we need a special decomposition for jvp but there's another
# version that should be used more generally (e.g. for jvp we need to recompute
# the mean and variance for the backwards of a normalization function; without
# jvp, it should use the saved values).
decomposition_table_for_jvp = {}


def register_decomposition_for_jvp(fn):
    return register_decomposition(fn, registry=decomposition_table_for_jvp)


def _register_jit_decomposition_for_jvp(decomp, use_python=False):
    if decomp in decomposition_table_for_jvp:
        decomposition_table_used = decomposition_table_for_jvp
    elif decomp in decomposition_table:
        decomposition_table_used = decomposition_table
    else:
        raise RuntimeError(f"could not find decomposition for {decomp}")
    decomp_fn = decomposition_table_used[decomp]
    if use_python:
        decomp_fn = torch.jit.ignore(decomp_fn)
        sig = inspect.signature(decomp_fn)

        # Create a string wrapping the function from the signature
        # example output:
        # def wrapped_decomp(x: torch.Tensor, y: int, z: int):
        #     return decomp_fn(x, y, z)
        # Thanks copilot!
        def get_function_def(sig):
            param_def = [f"{param_str}" for param_str in sig.parameters.values()]
            param_use = [f"{param_str}" for param_str in sig.parameters.keys()]

            return f"def wrapped_decomp({', '.join(param_def)}):\n return decomp_fn({', '.join(param_use)})\n"

        f_str = get_function_def(sig)
        graph = torch.jit.CompilationUnit(f_str).wrapped_decomp.graph
    else:
        graph = torch.jit.script(decomp_fn).graph
    torch.jit._register_decomposition(decomp, graph)
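

# A small, self-contained sketch (hypothetical helper, not used by this module)
# of the signature-based string wrapping done above: given an annotated
# function, build the source of a "wrapped_decomp" that forwards its arguments.
def _example_get_function_def_output():
    def toy(x: Tensor, alpha: float = 1.0) -> Tensor:
        return x * alpha

    sig = inspect.signature(toy)
    param_def = ", ".join(str(p) for p in sig.parameters.values())
    param_use = ", ".join(sig.parameters.keys())
    # Produces roughly:
    # def wrapped_decomp(x: torch.Tensor, alpha: float = 1.0):
    #  return decomp_fn(x, alpha)
    return f"def wrapped_decomp({param_def}):\n return decomp_fn({param_use})\n"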


# The only decompositions here are temporary or hacks for the purposes of jvp
# TODO: do these also belong here?
@maybe_register_decomposition(aten.trace.default)
def trace(self: Tensor) -> Tensor:
    return torch.sum(torch.diag(self))


@maybe_register_decomposition(aten.log_sigmoid_forward.default)
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
    min = torch.minimum(self.new_zeros(()), self)
    z = torch.exp(-torch.abs(self))
    if self.is_cuda:
        buffer = self.new_zeros((0,))
    else:
        buffer = z
    return min - torch.log1p(z), buffer


def recompute_mean_var(
    input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool
):
    # For most norm decompositions, the jvp version is identical to the core
    # version except for this helper: we recompute the mean and variance so
    # that they track gradients through the input.
    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
    eps = eps.detach()
    rstd = 1 / torch.sqrt(var + eps)
    return mean, rstd
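

# A minimal numerical sketch (illustrative helper, not called by this module)
# of what recompute_mean_var does: eps is recovered from the saved rstd via
# eps = 1 / rstd**2 - var, and rstd is rebuilt as 1 / sqrt(var + eps), so the
# recomputed statistics match the saved ones but are differentiable w.r.t. input.
def _example_recompute_mean_var():
    x = torch.randn(4, 8, requires_grad=True)
    saved_mean = torch.mean(x, dim=[1], keepdim=True).detach()
    saved_var = torch.var(x, dim=[1], unbiased=False, keepdim=True).detach()
    saved_rstd = torch.rsqrt(saved_var + 1e-5)
    mean, rstd = recompute_mean_var(x, saved_rstd, [1], keepdim=True)
    # Same numerical values as the saved stats, but they now track gradients.
    assert torch.allclose(mean, saved_mean)
    assert torch.allclose(rstd, saved_rstd, atol=1e-6)
    assert mean.requires_grad and rstd.requires_grad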


@register_decomposition_for_jvp(aten.native_layer_norm_backward)
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_ndim = input.dim()

    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices = list(range(axis, input_ndim))
    outer_dim_indices = list(range(0, axis))

    N = 1
    for i in inner_dims:
        N *= i
    M = 1
    for i in outer_dims:
        M *= i
    if M <= 0 or N <= 0:
        return (
            input.new_zeros(input_shape),
            input.new_zeros(input_shape[axis:]),
            input.new_zeros(input_shape[axis:]),
        )

    mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)

    x_hat = (input - mean_) * rstd_
    if weight is not None:
        grad_x_hat = grad_out * weight
    else:
        grad_x_hat = grad_out
    a = grad_x_hat * N
    b = torch.sum(grad_x_hat, inner_dim_indices, True)
    c1 = torch.mul(grad_x_hat, x_hat)
    c2 = torch.sum(c1, inner_dim_indices, True)
    c3 = torch.mul(x_hat, c2)
    inner = a - b - c3

    if output_mask[0]:
        d_input: Optional[Tensor] = (rstd_ / N) * inner
    else:
        d_input = torch.zeros_like(input)  # should be None but doesn't work with vjp

    if output_mask[1] and weight is not None:
        if len(outer_dim_indices) > 0:
            d_weight: Optional[Tensor] = torch.sum(
                grad_out * x_hat, outer_dim_indices, False
            )
        else:
            d_weight = grad_out * x_hat
    elif weight is not None:
        d_weight = torch.zeros_like(weight)  # should be None but doesn't work with vjp
    else:
        d_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2] and bias is not None:
        if len(outer_dim_indices) > 0:
            d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
        else:
            d_bias = grad_out.clone()
    elif bias is not None:
        d_bias = torch.zeros_like(bias)  # should be None but doesn't work with vjp
    else:
        d_bias = torch.zeros(())  # should be None but doesn't work with vjp

    return (d_input, d_weight, d_bias)
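

# A hedged sanity-check sketch (hypothetical helper, not called by this module):
# with the mean/rstd that aten's forward actually saves, the jvp decomposition
# above should agree numerically with the aten backward kernel.
def _example_layer_norm_backward_matches_aten():
    x = torch.randn(2, 3, 4)
    w = torch.randn(4)
    b = torch.randn(4)
    out, mean, rstd = torch.ops.aten.native_layer_norm(x, [4], w, b, 1e-5)
    go = torch.randn_like(out)
    mask = [True, True, True]
    ref = torch.ops.aten.native_layer_norm_backward(go, x, [4], mean, rstd, w, b, mask)
    dec = native_layer_norm_backward(go, x, [4], mean, rstd, w, b, mask)
    for r, d in zip(ref, dec):
        assert torch.allclose(r, d, atol=1e-4)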


def prod(x: List[int]):
    r = 1
    for i in x:
        r *= i
    return r


@register_decomposition_for_jvp(aten.native_batch_norm_backward)
def native_batch_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    weight: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_invstd: Optional[Tensor],
    train: bool,
    eps: float,
    output_mask: List[bool],
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_rank = input.dim()
    assert input_rank >= 2, "rank of the input must be at least 2"

    axis = 1
    num_features = prod(input_shape) / input_shape[axis]  # type: ignore[arg-type]
    mean = save_mean
    invstd = save_invstd
    if train:
        assert (
            save_mean is not None and save_invstd is not None
        ), "when train=True, save_mean and save_invstd are required"

        reduction_dims = [0] + list(range(2, input.dim()))
        assert invstd is not None  # for typing
        mean, invstd = recompute_mean_var(input, invstd, reduction_dims, keepdim=False)
    else:
        assert running_mean is not None and running_var is not None
        mean = running_mean
        invstd = torch.rsqrt(running_var + eps)

    assert invstd is not None and mean is not None

    broadcast_mask = [1] * input_rank
    broadcast_mask[axis] = input_shape[axis]

    reduction_axes: List[int] = []
    for i in range(input_rank):
        if i != axis:
            reduction_axes.append(i)

    mean = torch.reshape(mean, broadcast_mask)
    norm = 1.0 / num_features
    grad_output_sum = torch.sum(grad_out, reduction_axes)
    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)

    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)

    if weight is None:
        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
    else:
        grad_scale = torch.reshape(invstd * weight, broadcast_mask)

    if train:
        proj = (input - mean) * proj_scale
        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
    else:
        grad_input = grad_out * grad_scale

    if output_mask[1]:
        grad_weight = dot_p * invstd
    elif weight is not None:
        grad_weight = torch.zeros_like(
            weight
        )  # should be None but doesn't work with vjp
    else:
        grad_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2]:
        grad_bias = grad_output_sum
    else:
        grad_bias = torch.zeros_like(
            grad_output_sum
        )  # should be None but doesn't work with vjp

    return (grad_input, grad_weight, grad_bias)
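

# A hedged sanity-check sketch (hypothetical helper, not called by this module):
# in training mode, the jvp decomposition above should agree numerically with
# the aten backward kernel when given the stats saved by aten's forward.
def _example_batch_norm_backward_matches_aten():
    x = torch.randn(4, 3, 5)
    w = torch.randn(3)
    b = torch.randn(3)
    running_mean = torch.zeros(3)
    running_var = torch.ones(3)
    out, save_mean, save_invstd = torch.ops.aten.native_batch_norm(
        x, w, b, running_mean, running_var, True, 0.1, 1e-5
    )
    go = torch.randn_like(out)
    mask = [True, True, True]
    ref = torch.ops.aten.native_batch_norm_backward(
        go, x, w, running_mean, running_var, save_mean, save_invstd, True, 1e-5, mask
    )
    dec = native_batch_norm_backward(
        go, x, w, running_mean, running_var, save_mean, save_invstd, True, 1e-5, mask
    )
    for r, d in zip(ref, dec):
        assert torch.allclose(r, d, atol=1e-4)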


_register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten._log_softmax_backward_data.default)
_register_jit_decomposition_for_jvp(torch.ops.aten._softmax_backward_data.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default)