# autograd_function_db.py
import torch
from functools import partial
from torch.testing import make_tensor
from torch.testing._internal.opinfo.core import (
    OpInfo,
    SampleInput,
)
from torch.testing._internal.common_dtype import all_types_and
import numpy as np

# Note: [autograd.Function db]
#
# This is a collection of autograd.Function test cases written as OpInfos
# so they can easily be consumed by OpInfo-based tests to check if a subsystem
# supports autograd.Function.
#
# Axes:
# - saves {output, input, intermediate, non-tensor}
# - {inputs, output} x {single tensor, tensors, arbitrary objects}
# - Uses {mark_dirty, mark_non_differentiable, once_differentiable}
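#
# As a rough illustration of how these OpInfos tend to be consumed (the real
# harnesses live in the OpInfo-based test suites, not in this file), a
# device-type test might iterate the db with the `ops` decorator. This is a
# hypothetical sketch: the class name `TestAutogradFunction` and the test name
# `test_smoke_forward` are made up; `ops`, `instantiate_device_type_tests`,
# `OpInfo.sample_inputs`, and `SampleInput` are existing test-infra APIs.
#
#     from torch.testing._internal.common_device_type import (
#         instantiate_device_type_tests,
#         ops,
#     )
#     from torch.testing._internal.common_utils import TestCase, run_tests
#
#     class TestAutogradFunction(TestCase):
#         @ops(autograd_function_db, allowed_dtypes=(torch.double,))
#         def test_smoke_forward(self, device, dtype, op):
#             for sample in op.sample_inputs(device, dtype, requires_grad=True):
#                 op.op(sample.input, *sample.args, **sample.kwargs)
#
#     instantiate_device_type_tests(TestAutogradFunction, globals())
#
#     if __name__ == "__main__":
#         run_tests()
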

def to_numpy(tensor):
    return tensor.cpu().numpy()


class NumpyCube(torch.autograd.Function):
    @staticmethod
    def forward(input):
        input_np = to_numpy(input)
        dinput = torch.tensor(3 * input_np ** 2, device=input.device)
        return torch.tensor(input_np ** 3, device=input.device), dinput

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(inputs[0], output[1])
        ctx.save_for_forward(inputs[0], output[1])

    @staticmethod
    def backward(ctx, grad_output, grad_saved):
        input, dinput = ctx.saved_tensors
        return NumpyMul.apply(grad_output, dinput) + 6 * NumpyMul.apply(grad_saved, input)

    @staticmethod
    def vmap(info, in_dims, input):
        result = NumpyCube.apply(input)
        return result, (in_dims[0], in_dims[0])

    @staticmethod
    def jvp(ctx, input_tangent):
        input, dinput = ctx.saved_tensors
        return NumpyMul.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input)


class CubeGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x ** 3, 3 * x ** 2

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        ctx.save_for_backward(inputs[0], outputs[1])
        ctx.save_for_forward(inputs[0], outputs[1])

    @staticmethod
    def backward(ctx, grad_output, grad_saved):
        input, dinput = ctx.saved_tensors
        # d(x ** 3)/dx = 3 * x ** 2 (saved as dinput); d(3 * x ** 2)/dx = 6 * x
        result = grad_output * dinput + 6 * grad_saved * input
        return result

    @staticmethod
    def jvp(ctx, input_tangent):
        input, dinput = ctx.saved_tensors
        return MulGenVmap.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input)


def sample_inputs_numpy_cube(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg(1, low=0.8, high=2), args=())


class NumpyCubeNotComposable(torch.autograd.Function):
    @staticmethod
    def forward(input):
        input_np = to_numpy(input)
        return torch.tensor(input_np ** 3, device=input.device), input_np

    @staticmethod
    def setup_context(ctx, inputs, output):
        _, input_np = output
        ctx.input_np = input_np
        ctx.device = inputs[0].device

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output, grad_saved):
        result_np = 3 * (ctx.input_np ** 2)
        return torch.tensor(result_np, device=ctx.device)


class NumpyMul(torch.autograd.Function):
    @staticmethod
    def forward(x, y):
        return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(*inputs)
        ctx.save_for_forward(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if ctx.needs_input_grad[0]:
            gx = NumpyMul.apply(grad_output, y)
        gy = None
        if ctx.needs_input_grad[1]:
            gy = NumpyMul.apply(grad_output, x)
        return gx, gy

    @staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        result = NumpyMul.apply(x, y)
        result = result.movedim(-1, 0)
        return result, 0

    @staticmethod
    def jvp(ctx, x_tangent, y_tangent):
        x, y = ctx.saved_tensors
        return x_tangent * y + y_tangent * x


def sample_inputs_numpy_mul(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    # Broadcasting
    yield SampleInput(make_arg(4, low=0.9, high=2), args=(make_arg(3, 4, low=0.9, high=2),))


class MulGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, y):
        return x * y

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        ctx.save_for_backward(*inputs)
        ctx.save_for_forward(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if ctx.needs_input_grad[0]:
            gx = MulGenVmap.apply(grad_output, y)
        gy = None
        if ctx.needs_input_grad[1]:
            gy = MulGenVmap.apply(grad_output, x)
        return gx, gy

    @staticmethod
    def jvp(ctx, x_tangent, y_tangent):
        x, y = ctx.saved_tensors
        return x_tangent * y + y_tangent * x


class NumpyExp_(torch.autograd.Function):
    @staticmethod
    def forward(x):
        x_np = to_numpy(x)
        np.exp(x_np, x_np)
        return x

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.mark_dirty(x)
        ctx.save_for_backward(output)
        ctx.save_for_forward(output)

    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        return NumpyMul.apply(grad_output, output)

    @staticmethod
    def vmap(info, in_dims, x):
        NumpyExp_.apply(x)
        return x, in_dims[0]

    @staticmethod
    def jvp(ctx, x_tangent):
        # Doesn't call numpy operations because I didn't want to write NumpyMul_
        output, = ctx.saved_tensors
        x_tangent.mul_(output)
        return x_tangent


class NumpySort(torch.autograd.Function):
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        # Return the sorted values along with the (non-differentiable) sort
        # indices and their inverse permutation.
        return (
            torch.tensor(result, device=device),
            torch.tensor(ind, device=device),
            torch.tensor(ind_inv, device=device),
        )
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        ctx.save_for_backward(ind, ind_inv)
        ctx.save_for_forward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

    @staticmethod
    def vmap(info, in_dims, x, dim):
        x_bdim, _ = in_dims
        x = x.movedim(x_bdim, 0)
        # wrap dim
        dim = dim if dim >= 0 else dim + x.dim() - 1
        return NumpySort.apply(x, dim + 1), (0, 0, 0)

    @staticmethod
    def jvp(ctx, x_tangent, _):
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim), None, None


class SortGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, dim):
        device = x.device
        ind = torch.argsort(x, dim=dim)
        ind_inv = torch.argsort(ind, axis=dim)
        result = torch.take_along_dim(x, ind, dim=dim)
        return result, ind, ind_inv

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        x, dim = inputs
        _, ind, ind_inv = outputs
        ctx.mark_non_differentiable(ind, ind_inv)
        ctx.save_for_backward(ind, ind_inv)
        ctx.save_for_forward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        ind, ind_inv = ctx.saved_tensors
        return TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim), None

    @staticmethod
    def jvp(ctx, x_tangent, _):
        ind, ind_inv = ctx.saved_tensors
        return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim), None, None


def sample_inputs_numpy_sort(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg(3, 5), args=(1,))


class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.save_for_forward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

    @staticmethod
    def vmap(info, in_dims, x, ind, ind_inv, dim):
        x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
        # wrap dim
        logical_dim = x.dim() if x_bdim is None else x.dim() - 1
        dim = dim if dim >= 0 else dim + logical_dim

        def expand_bdim(x, x_bdim):
            if x_bdim is None:
                return x.expand(info.batch_size, *x.shape)
            return x.movedim(x_bdim, 0)

        x = expand_bdim(x, x_bdim)
        ind = expand_bdim(ind, ind_bdim)
        ind_inv = expand_bdim(ind_inv, ind_inv_bdim)
        return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0

    @staticmethod
    def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _):
        assert ind_tangent is None
        assert ind_inv_tangent is None
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim)


class TakeGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, ind, ind_inv, dim):
        return torch.take_along_dim(x, ind, dim)

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.save_for_forward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

    @staticmethod
    def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _):
        ind, ind_inv = ctx.saved_tensors
        return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim)


class Select(torch.autograd.Function):
    @staticmethod
    def forward(x, idx):
        return x[idx]

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, idx = inputs
        ctx.x_shape = x.shape
        ctx.idx = idx

    @staticmethod
    def backward(ctx, grad_output):
        result = grad_output.new_zeros(ctx.x_shape)
        result[ctx.idx] = grad_output
        return result, None

    @staticmethod
    def vmap(info, in_dims, x, idx):
        x_bdim, _ = in_dims
        x = x.movedim(x_bdim, 1)
        return Select.apply(x, idx), 0

    @staticmethod
    def jvp(ctx, x_tangent, _):
        return Select.apply(x_tangent, ctx.idx)


class SelectGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, idx):
        return x[idx]

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        x, idx = inputs
        ctx.x_shape = x.shape
        ctx.idx = idx

    @staticmethod
    def backward(ctx, grad_output):
        result = grad_output.new_zeros(ctx.x_shape)
        result[ctx.idx] = grad_output
        return result, None

    @staticmethod
    def jvp(ctx, x_tangent, _):
        return SelectGenVmap.apply(x_tangent, ctx.idx)


def sample_inputs_select(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg(3, 5), args=(2,))


class ScaleGradGenVmap(torch.autograd.Function):
    generate_vmap_rule = True
    scale = 3.14

    @staticmethod
    def forward(x):
        return x.clone()

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        pass

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ScaleGradGenVmap.scale

    @staticmethod
    def jvp(ctx, x_tangent):
        return x_tangent * ScaleGradGenVmap.scale


class ZeroGradientsGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, y):
        return x.clone(), y.clone()

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        pass

    @staticmethod
    def backward(ctx, gx, gy):
        # Intentionally returning torch.zeros instead of zeros_like or new_zeros.
        # Also intentionally not None.
        return (
            # Intentionally too-large gradient
            torch.zeros(3, 4, *gx.shape, dtype=gx.dtype, device=gx.device),
            torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
        )

    @staticmethod
    def jvp(ctx, gx, gy):
        # Intentionally returning torch.zeros instead of zeros_like or new_zeros.
        # Also intentionally not None.
        return (
            torch.zeros(gx.shape, dtype=gx.dtype, device=gx.device),
            torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
        )


autograd_function_db = [
    OpInfo(
        'NumpyCubeAutogradFunction',
        op=NumpyCube.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpyExpMarkDirtyAutogradFunction',
        op=lambda x: NumpyExp_.apply(x.clone()),
        inplace_variant=NumpyExp_.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpyMulAutogradFunction',
        op=NumpyMul.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_mul,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpyCubeNotComposableAutogradFunction',
        op=lambda x: NumpyCubeNotComposable.apply(x)[0],
        supports_forward_ad=False,
        supports_fwgrad_bwgrad=False,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpySortAutogradFunction',
        op=NumpySort.apply,
        supports_forward_ad=False,
        supports_fwgrad_bwgrad=False,
        sample_inputs_func=sample_inputs_numpy_sort,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        gradcheck_wrapper=lambda y, ind: y,
    ),
    OpInfo(
        'SelectAutogradFunction',
        op=Select.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_select,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'CubeGenVmapAutogradFunction',
        op=CubeGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'MulGenVmapAutogradFunction',
        op=MulGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_mul,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'SortGenVmapAutogradFunction',
        op=SortGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_sort,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        gradcheck_wrapper=lambda y, ind: y,
    ),
    OpInfo(
        'SelectGenVmapAutogradFunction',
        op=SelectGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_select,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'ScaleGradGenVmapAutogradFunction',
        op=ScaleGradGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'ZeroGradientsGenVmapAutogradFunction',
        op=ZeroGradientsGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_mul,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
]
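

# A minimal, hypothetical smoke test for consuming the db directly, outside any
# OpInfo test harness. The device/dtype choices and the "first sample only"
# policy below are illustrative assumptions, not part of the db itself.
if __name__ == "__main__":
    for opinfo in autograd_function_db:
        # OpInfo.sample_inputs wraps the sample_inputs_func registered above
        # and yields SampleInput objects.
        sample = next(iter(opinfo.sample_inputs("cpu", torch.double, requires_grad=True)))
        # OpInfo.op is the callable under test (e.g. NumpyCube.apply or a lambda).
        out = opinfo.op(sample.input, *sample.args, **sample.kwargs)
        outs = out if isinstance(out, tuple) else (out,)
        print(opinfo.name, [tuple(t.shape) for t in outs])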