# jit_metaprogramming_utils.py

# Torch
from torch.jit.annotations import BroadcastingList2, BroadcastingList3  # noqa: F401
import torch.nn.functional as F
import torch
import torch.cuda
import torch.jit
import torch.jit._logging
import torch.jit.frontend
from torch.testing._internal.common_nn import module_tests, new_module_tests
from torch.testing._internal.common_utils import is_iterable_of_tensors, noncontiguous_like

import collections
from copy import deepcopy
from typing import Any, Dict, List, Union
import math  # noqa: F401

# Testing utils
from torch import inf

# TODO: include files like this should not set the default dtype
torch.set_default_dtype(torch.double)

L = 20
M = 10
S = 5

def unpack_variables(args):
    if isinstance(args, tuple):
        return tuple(unpack_variables(elem) for elem in args)
    else:
        return args

class dont_convert(tuple):
    pass

non_differentiable = collections.namedtuple('non_differentiable', ['tensor'])

def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwargs=None, dtype=torch.double, device=None):
    if not isinstance(call_args, tuple):
        call_args = (call_args,)

    def map_arg(arg):
        def maybe_non_contig(tensor):
            if not non_contiguous or tensor.numel() < 2:
                return tensor.clone()
            return noncontiguous_like(tensor)

        def conjugate(tensor):
            return tensor.conj()

        if isinstance(arg, (torch.Size, dont_convert)):
            return arg
        elif isinstance(arg, tuple) and len(arg) == 0:
            var = conjugate(torch.randn((), dtype=dtype, device=device))
            var.requires_grad = requires_grad
            return var
        elif isinstance(arg, tuple) and not isinstance(arg[0], torch.Tensor):
            return conjugate(maybe_non_contig(torch.randn(*arg, dtype=dtype, device=device))).requires_grad_(requires_grad)
        # double check casting
        elif isinstance(arg, non_differentiable):
            if isinstance(arg.tensor, torch.Tensor):
                if arg.tensor.dtype == torch.float:
                    return maybe_non_contig(arg.tensor.to(dtype=torch.double, device=device))
                if arg.tensor.dtype == torch.cfloat:
                    return conjugate(maybe_non_contig(arg.tensor.to(dtype=torch.cdouble, device=device)))
                return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
            return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
        elif isinstance(arg, torch.Tensor):
            if arg.dtype == torch.float:
                arg = arg.double()
            if arg.dtype == torch.cfloat:
                arg = arg.to(torch.cdouble)
            if arg.is_complex() != dtype.is_complex:
                raise RuntimeError("User provided tensor is real for a test that runs with complex dtype, "
                                   "which is not supported for now")
            # NOTE: We do clone() after detach() here because we need to be able to change size/storage of v afterwards
            v = conjugate(maybe_non_contig(arg)).detach().to(device=device).clone()
            v.requires_grad = requires_grad and (v.is_floating_point() or v.is_complex())
            return v
        elif callable(arg):
            return map_arg(arg(dtype=dtype, device=device))
        else:
            return arg
    args_out = tuple(map_arg(arg) for arg in call_args)
    kwargs_out = {k: map_arg(v) for k, v in call_kwargs.items()} if call_kwargs else {}
    return args_out, kwargs_out
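
# Example sketch (illustrative helper, not part of the test suite): size tuples
# become random double tensors that require grad, while non_differentiable-wrapped
# tensors are passed through without requires_grad.
def _example_create_input():
    (t, w), _ = create_input(((S, S), non_differentiable(torch.randn(S))))
    return t.requires_grad, w.requires_grad  # (True, False)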

# NB: JIT scripting tests all nn functional interfaces; script mode does not
# support in-place operations yet, so no in-place operation tests are added.
# All deprecated functions have been removed.
#
# Each entry has the following form (see the annotated example after the list below):
# (
#   method name,
#   input size/constructing fn,
#   args (tuple represents shape of a tensor arg),
#   test variant name (will be used as the test name suffix;
#                      'inplace' skips grad tests),                // optional
#   (True, nonfusible_nodes, fusible_nodes) for autodiff,          // optional
#   fn to determine if test should be skipped,                     // optional
#   fn mapping output to part that should be gradcheck'ed,         // optional
#   kwargs for function,                                           // optional
# )
nn_functional_tests = [
    ('conv1d', (S, S, S), ((S, S, S),)),
    ('conv2d', (S, S, S, S), ((S, S, S, S),)),
    ('conv3d', (S, S, S, S, S), ((S, S, S, S, S),)),
    ('conv_transpose1d', (S, S, S), ((S, S, S),)),
    ('conv_transpose2d', (S, S, S, S), ((S, S, S, S),)),
    ('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)),
    ('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)),
    ('avg_pool1d', (S, S, S), (3,)),
    ('avg_pool2d', (S, S, S, S), (3,), '', (True,)),
    ('avg_pool3d', (S, S, S, S, S), (3,)),
    ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)),
    ('max_pool1d', (S, S, S), (2, 1)),
    ('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'),
    ('max_pool2d', (S, S, S, S), (2, 1), '', (True, 'aten::max_pool2d_with_indices')),
    ('max_pool2d', (S, S, S, S), (2, 1, 1, 1, False, True), 'with_indices', (True, 'aten::max_pool2d_with_indices')),
    ('max_pool3d', (S, S, S, S, S), (2, 1)),
    ('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
    ('max_unpool2d', torch.tensor([[[[2., 4]]]]), (torch.tensor([[[[1, 3]]]]), 2, 2, 0)),
    ('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), (torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)),
    ('lp_pool1d', (S, S, S), (2., 3, 2,)),
    ('lp_pool2d', (S, S, S, S), (2., 3, 2,)),
    ('adaptive_max_pool1d', (S, S, S), (5,)),
    ('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
    ('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
    ('adaptive_avg_pool1d', (S, S, S), (5,), '', (True,)),
    ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],), '', (True,)),
    ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],), '', (True,)),
    ('dropout', (S, S, S), (0.5,), '', (True, 'aten::native_dropout')),
    ('alpha_dropout', (S, S, S), (0.5,)),
    ('dropout2d', (S, S, S), (0.5,)),
    ('dropout2d', (S, S, S, S), (0.5,), 'batched'),
    ('dropout3d', (S, S, S, S), (0.5,)),
    ('dropout3d', (S, S, S, S, S), (0.5,), 'batched'),
    ('feature_alpha_dropout', (S, S, S), (0.5,)),
    ('threshold', (S, S, S), (0.1, 2.), '', (True,)),
    ('threshold', (S, S, S), (0.1, 2., True), 'inplace'),
    ('relu', (S, S, S), (), '', (True,)),
    ('relu', (S, S, S), (), 'inplace'),
    ('glu', (S - 1, S - 1, S - 1), (),),
    ('hardtanh', (S, S, S), (-0.5, 0.5), '', (True,)),
    ('hardtanh', (S, S, S), (-0.5, 0.5, True), 'inplace'),
    ('relu6', (S, S, S), (), '', (True,)),
    ('relu6', (S, S, S), (True), 'inplace'),
    ('elu', (S, S, S), (0.9,),),
    ('elu', (S, S, S), (0.9, True), 'inplace'),
    ('selu', (S, S, S), (),),
    ('selu', (S, S, S), (True), 'inplace'),
    ('celu', (S, S, S), (0.9,),),
    ('celu', (S, S, S), (0.9, True), 'inplace'),
    ('leaky_relu', (S, S, S), (0.02,), '', (True,)),
    ('leaky_relu', (S, S, S), (0.02,), 'inplace'),
    ('rrelu', (S, S), (0.1, 0.3, False),),
    ('rrelu', (S, S), (0.1, 0.3, False, True), 'inplace'),
    ('hardshrink', (S, S, S), (0.4,), '', (True,)),
    ('tanhshrink', (S, S, S), (),),
    ('softsign', (S, S, S), (),),
    ('softplus', (S, S, S), (), '', (True,)),
    ('softmin', (S, S, S), (0,),),
    ('softmax', (S, S, S), (0,), '', (True,)),
    ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args', (True,)),
    ('tanh', (S, S, S), (), '', (True,)),
    ('sigmoid', (S, S, S), (), '', (True,)),
    ('silu', (S, S, S), (), '', (True,)),
    ('log_softmax', (S, S, S), (0,), '', (True,)),
    ('linear', (S, S), ((M, S),), '', (True, ['aten::linear'])),
    ('linear', (S, S), ((M, S), (M,)), 'addmm', (True, ['aten::linear'])),
    ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
    ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
    ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
    ('batch_norm', (S, S),
        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), None, None, True, ),
        'training', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (0, S, S, S),
        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
        'size_zero', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (0, S, S, S),
        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
        'size_zero_inference', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S),
        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
        'with_weight_and_bias_training', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
                            None, non_differentiable(torch.ones(S)), True, ),
        'with_only_bias_training', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
                            non_differentiable(torch.randn(S)), None, True, ),
        'with_only_weight_training', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
                            None, None, False, ),
        'inference', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
                            non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), False, ),
        'with_weight_and_bias_inference', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
                            None, non_differentiable(torch.ones(S)), False, ),
        'with_only_bias_inference', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
                            non_differentiable(torch.randn(S)), None, False, ),
        'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')),
    ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
    ('layer_norm', (S, S, S, S), ([5],), '',
        (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
    ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
        (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
    ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
        (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
    ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
                                  non_differentiable(torch.rand(S))), 'with_weight_and_bias',
        (False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])),
    ('group_norm', (S, S, S), (1, torch.rand(5),),),
    ('local_response_norm', (S, S, S), (2, ),),
    ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',),
    ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),),
    ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'),
    ('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),),
    ('cross_entropy', (3, S), (torch.randint(S, (3,), dtype=torch.int64),),),
    ('binary_cross_entropy_with_logits', (3,), (torch.empty(3).random_(2), ),),
    ('smooth_l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('huber_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('mse_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('smooth_l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
    ('huber_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
    ('l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
    ('mse_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
    ('margin_ranking_loss', (S,), ((S,), (S,)),),
    ('hinge_embedding_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),),
    ('pixel_shuffle', (1, 9, 4, 4), (3,),),
    ('pixel_unshuffle', (1, 1, 12, 12), (3,),),
    ('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),),
    ('pad', (3, 3, 4, 2), ([1, 1],),),
    ('pairwise_distance', (S, S), ((S, S),),),
    ('pdist', (S, S), (),),
    ('cosine_similarity', (S, S), ((S, S),),),
    ('triplet_margin_loss', (S, S), ((S, S), (S, S)),),
    ('normalize', (S, S, S), (),),
    ('unfold', (S, S, S, S), ([2, 3]),),
    ('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),),
    ('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),),
    ('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
    ('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
    ('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),),
    ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)),
                                   1, 1., non_differentiable(torch.randn(S))),),
    ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)),
                                                           non_differentiable(torch.randn(3, 2))),),
    ('binary_cross_entropy', torch.randn(3, 2).sigmoid(),
        (non_differentiable(torch.rand(3, 2)),
         non_differentiable(torch.randn(3, 2)), None, None, 'mean'), 'size_average'),
    ('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(),
        (torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long),
         torch.randint(1, S, (S,), dtype=torch.long))),
    ('upsample', torch.randn(S, S, M, M), (None, 2.), 'with_scale'),
    ('upsample', torch.randn(S, S, M, M), (4,), 'with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'),
    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'),
    ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'),
    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'),
    ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'),
    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'),
    ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'),
    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'),
    ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'),
    ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'),
    ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'),
    ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'),
    ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'),
    ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'),
    ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'),
    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'),
    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'),
    ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'),
    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'),
    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'),
    ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'),
    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'),
    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2, None, 'nearest', None, False),
        'nearest_4d_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, S, M, M), (4, None, 'nearest', None, False),
        'nearest_4d_with_size_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bilinear', None, False),
        'bilinear_4d_with_scale_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, S, M, M), (4, None, 'bilinear', None, False),
        'bilinear_4d_with_size_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bicubic', None, False),
        'bicubic_4d_with_scale_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, S, M, M), (4, None, 'bicubic', None, False),
        'bicubic_4d_with_size_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M), (None, 2., 'nearest', None, False),
        'nearest_3d_with_scale_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M), (4, None, 'nearest', None, False),
        'nearest_3d_with_size_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M), (None, 2., 'linear', None, False),
        'linear_3d_with_scale_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M), (4, None, 'linear', None, False),
        'linear_3d_with_size_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'nearest', None, False),
        'nearest_5d_with_scale_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'nearest', None, False),
        'nearest_5d_with_size_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'trilinear', None, False),
        'trilinear_5d_with_scale_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'trilinear', None, False),
        'trilinear_5d_with_size_not_recompute_scale_factor'),
]
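
# Reading one entry of the list against the format described above (this
# annotation is editorial and is not itself a test entry):
#     ('max_pool2d', (S, S, S, S), (2, 1), '', (True, 'aten::max_pool2d_with_indices'))
# means: call torch.nn.functional.max_pool2d on a random (S, S, S, S) tensor with
# positional args (2, 1), use no variant-name suffix, and run the autodiff check
# expecting 'aten::max_pool2d_with_indices' among the non-fusible nodes.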

script_template = '''
def the_method({}):
    return {}
'''

def value_to_literal(value):
    if isinstance(value, str):
        # Quotes string and escapes special characters
        return ascii(value)
    if isinstance(value, torch.Tensor):
        return 'torch.' + str(value)
    else:
        return str(value)

def get_call(method_name, func_type, args, kwargs):
    kwargs_str = ', '.join([k + '=' + value_to_literal(v) for k, v in kwargs.items()])
    self_arg = args[0]
    if func_type == 'method':
        args = args[1:]

    argument_str = ', '.join(args)
    argument_str += ', ' if len(args) and len(kwargs) else ''
    argument_str += kwargs_str

    if func_type == 'functional' or func_type == 'function':
        call = 'torch.{}({})'.format(method_name, argument_str)
    elif func_type == 'method':
        call = '{}.{}({})'.format(self_arg, method_name, argument_str)
    elif func_type == 'nn_functional':
        call = 'torch.nn.functional.{}({})'.format(method_name, argument_str)
    else:
        raise TypeError('Unsupported function type')

    return call
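
# Illustration of the string construction only (hedged example, values made up):
# with actuals ['i0', 'i1'],
#     get_call('add', 'method', ['i0', 'i1'], {'alpha': 2})
# builds the source fragment "i0.add(i1, alpha=2)", while func_type
# 'nn_functional' would build "torch.nn.functional.add(i0, i1, alpha=2)".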

def get_constant(x):
    if x == inf:
        return 'math.inf'
    if x == -inf:
        return '-math.inf'
    return x

def get_script_args(args):
    formals: List[str] = []
    tensors: List[Union[torch.Tensor, List[torch.Tensor]]] = []
    actuals: List[str] = []
    for arg in args:
        if isinstance(arg, torch.Tensor):
            name = 'i{}'.format(len(formals))
            formals.append(name)
            actuals.append(name)
            tensors.append(arg)
        elif is_iterable_of_tensors(arg):
            name = 'i{}'.format(len(formals))
            formals.append(name + ': List[torch.Tensor]')
            actuals.append(name)
            tensors.append(list(arg))
        elif isinstance(arg, str):
            actuals.append("'{}'".format(arg))
        else:
            actuals.append(str(get_constant(arg)))
    return (formals, tensors, actuals)
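
# Illustrative behaviour (hedged example): for a mix of tensor and non-tensor args,
#     get_script_args((torch.randn(S), 2.5, 'mean'))
# returns formals ['i0'], tensors [<the tensor>] and actuals ['i0', '2.5', "'mean'"],
# so only tensors become graph inputs while scalars and strings are inlined as literals.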

# creates a script function from (name, func_type, output_process_fn),
# and returns the compiled function and example inputs
def gen_script_fn_and_args(method_name, func_type, *args, **kwargs):
    formals, tensors, actuals = get_script_args(args)
    call = get_call(method_name, func_type, actuals, kwargs)
    script = script_template.format(', '.join(formals), call)
    CU = torch.jit.CompilationUnit(script)
    return CU.the_method, tensors
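
# Example sketch (illustrative helper, not part of the test suite): script F.relu
# through the template above and run the compiled method on the captured tensors.
def _example_gen_script_fn_and_args():
    fn, tensors = gen_script_fn_and_args('relu', 'nn_functional', torch.randn(S, S, S))
    return fn(*tensors)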

# creates a script function from (name, func_type),
# returns a function that takes in (args, kwargs) and runs the compiled function
def create_script_fn(self, method_name, func_type):
    # function returns tuple containing original output and
    # filtered output to be used in checking gradients
    def script_fn(*args, **kwargs):
        fn, tensors = gen_script_fn_and_args(method_name, func_type, *args, **kwargs)
        self.assertExportImport(fn.graph, tensors)
        output = fn(*tensors)
        # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
        script_fn.last_graph = fn.graph_for(*tensors)  # type: ignore[attr-defined]
        return output
    return script_fn
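
# Typical usage (sketch): called from inside a JitTestCase subclass, where `self`
# provides assertExportImport; the names below are illustrative only.
#     script_fn = create_script_fn(self, 'relu', 'nn_functional')
#     out = script_fn(torch.randn(S, S, S))
#     last_graph = script_fn.last_graph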

class SplitInputs():
    all_tensors: List[Any]
    tensor_args: List[Any]
    nontensor_args: List[Any]
    arg_types: List[str]
    tensor_kwargs: Dict[str, Any]
    kwarg_order: List[str]
    nontensor_kwargs: Dict[str, Any]
    kwarg_types: Dict[str, Any]

    @staticmethod
    def _is_tensor_input(arg):
        return isinstance(arg, torch.Tensor) or is_iterable_of_tensors(arg)

    def __init__(self, args, kwargs):
        self.arg_types = ['t' if self._is_tensor_input(arg) else 's' for arg in args]
        self.kwarg_types = {k: 't' if self._is_tensor_input(v) else 's' for k, v in kwargs.items()}
        self.tensor_args = [arg for arg in args if self._is_tensor_input(arg)]
        self.nontensor_args = [arg for arg in args if not self._is_tensor_input(arg)]
        self.tensor_kwargs = {k: v for k, v in kwargs.items() if self._is_tensor_input(v)}
        self.nontensor_kwargs = {k: v for k, v in kwargs.items() if not self._is_tensor_input(v)}
        self.all_tensors = [*self.tensor_args, *[v for k, v in self.tensor_kwargs.items()]]
        self.kwarg_order = [k for k, v in kwargs.items()]

    def nontensors_match(self, other: 'SplitInputs'):
        if self.arg_types != other.arg_types:
            return False
        if self.kwarg_types != other.kwarg_types:
            return False
        if self.kwarg_order != other.kwarg_order:
            return False
        if self.nontensor_args != other.nontensor_args:
            return False
        if self.nontensor_kwargs != other.nontensor_kwargs:
            return False
        return True

# make a new function where all non-tensor arguments in 'args' have been partially
# applied, and all tensor arguments remain.
# used to trace functions when some arguments are not tensors
def partial_apply_nontensors(fn, args, kwargs):
    inputs = SplitInputs(args, kwargs)

    def new_fn(*tensors_):
        tensors = iter(tensors_)
        full_args = [args[i] if s == 's' else next(tensors) for i, s in enumerate(inputs.arg_types)]
        full_kwargs = {k: kwargs[k] if s == 's' else next(tensors) for k, s in inputs.kwarg_types.items()}
        return fn(*full_args, **full_kwargs)

    return new_fn, inputs
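
# Example sketch (illustrative helper, not part of the test suite): split tensor
# and non-tensor arguments so only tensors are fed to the traced function, as
# create_traced_fn does below.
def _example_partial_apply_nontensors():
    args = (torch.randn(S, S), torch.randn(S, S))
    fn_tensors, split_inputs = partial_apply_nontensors(torch.add, args, {'alpha': 2})
    # split_inputs.all_tensors holds just the two tensors; alpha=2 is closed over
    return fn_tensors(*split_inputs.all_tensors)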

# create a trace function from input fn
def create_traced_fn(self, fn, cache_traced_fn=False):
    def traced_fn(*inputs, **kwargs):
        # `check_trace` is set to False because check_trace is run with @no_grad
        # Also, `check_against_reference` already does all the checks
        # against python function
        fn_tensors, split_inputs = partial_apply_nontensors(fn, inputs, kwargs)
        if not cache_traced_fn or not hasattr(traced_fn, 'traced'):
            traced = torch.jit.trace(fn_tensors, split_inputs.all_tensors, check_trace=False)
            self.assertExportImport(traced.graph, split_inputs.all_tensors)
            output = traced(*split_inputs.all_tensors)
            if cache_traced_fn:
                traced_fn.traced = traced
                traced_fn.split_inputs = split_inputs
        else:
            # Guard to check that nontensor inputs are the same as during tracing
            self.assertTrue(traced_fn.split_inputs.nontensors_match(split_inputs))
            output = traced_fn.traced(*split_inputs.all_tensors)
            traced = traced_fn.traced
        # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
        traced_fn.last_graph = traced.graph_for(*split_inputs.all_tensors)  # type: ignore[attr-defined]
        traced_fn.graph = traced.graph  # type: ignore[attr-defined]
        return output
    return traced_fn
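
# Typical usage (sketch): called from inside a JitTestCase subclass, where `self`
# provides assertExportImport; the names below are illustrative only.
#     traced = create_traced_fn(self, F.leaky_relu)
#     out = traced(torch.randn(S, S), 0.1)
#     graph = traced.graph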

# known to be failing in script
EXCLUDE_SCRIPT = {
    'test_norm_fro_default',
    'test_norm_fro_cpu',
    'test_norm_nuc',
    'test_norm_fro',
    'test_norm_nuc_batched',

    # aten op has additional cudnn argument
    'test_nn_unfold',

    # flaky test - TODO fix
    'test_nn_ctc_loss',

    # unknown builtin op
    'test_nn_fold',

    # jit doesn't support sparse tensors.
    'test_to_sparse',
    'test_to_sparse_dim',
}

# generates a script function and set of example inputs
# from a specified test in the format of nn_functional_tests
def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name='', *extra_args):
    test_name = 'test_nn_' + name
    if variant_name != '':
        test_name = test_name + '_' + variant_name

    no_grad = variant_name == 'inplace'

    self_variable = create_input((self_size,))[0][0]
    kwargs = None

    # need to record this because methods can change the size (e.g. unsqueeze)
    args_variable, kwargs_variable = create_input(args)

    self_tensor = deepcopy(self_variable.data)
    args_tensor = deepcopy(unpack_variables(args_variable))

    f_args_variable = (self_variable,) + args_variable
    f_args_tensor = (self_tensor,) + args_tensor
    with torch._jit_internal._disable_emit_hooks():
        script_fn, inputs = gen_script_fn_and_args(name, "nn_functional", *f_args_variable)
    return script_fn, inputs
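
# Example sketch (illustrative helper, not part of the test suite): compile a
# simple entry in the nn_functional_tests format into a scripted function plus
# its example inputs.
def _example_get_nn_functional_compiled_fn_and_inputs():
    fn, inputs = get_nn_functional_compiled_fn_and_inputs('relu', (S, S, S), ())
    return fn(*inputs)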

# additional modules test
# TODO: delete this list once we make all nn_tests work
additional_module_tests = [
    {
        'module_name': 'Bilinear',
        'constructor_args': (S, S, M),
        'input_size': (S, S),
        'extra_args': ((S, S),)
    },
    {
        'module_name': 'RNNCell',
        'constructor_args': (S, S),
        'input_size': (S, S),
    },
    {
        'module_name': 'LSTMCell',
        'constructor_args': (S, S),
        'input_size': (S, S),
    },
    {
        'module_name': 'GRUCell',
        'constructor_args': (S, S),
        'input_size': (S, S),
    },
    {
        'module_name': 'MultiheadAttention',
        'constructor_args': (128, 8),
        'input_size': (10, 8, 128),
        'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)),
        'slowTest': True
    },
    {
        'module_name': 'Transformer',
        'constructor_args': (1, 1, 1, 1, 2),
        'input_size': (3, 1, 1),
        'extra_args': (torch.randn(1, 1, 1),),
        'slowTest': True
    }
]

EXCLUDE_SCRIPT_MODULES = {
    'test_nn_AdaptiveAvgPool2d_tuple_none',
    'test_nn_AdaptiveAvgPool3d_tuple_none',
    'test_nn_AdaptiveMaxPool2d_tuple_none',
    'test_nn_AdaptiveMaxPool3d_tuple_none',

    # Doesn't use future division, so this is not supported
    'test_nn_CrossMapLRN2d',
}

script_method_template = '''
def forward({}):
    return {}
'''

def create_script_module(self, nn_module, constructor_args, *args, **kwargs):
    def script_module(*args, **kwargs):
        formals, tensors, actuals = get_script_args(args)

        method_args = ', '.join(['self'] + actuals)
        call_args_str = ', '.join(actuals)
        call = "self.submodule({})".format(call_args_str)
        script = script_method_template.format(method_args, call)

        submodule_constants = []
        if kwargs.get('is_constant'):
            submodule_constants = ['submodule']

        # Create module to use the script method
        class TheModule(torch.jit.ScriptModule):
            __constants__ = submodule_constants

            def __init__(self):
                super().__init__()
                self.submodule = nn_module(*constructor_args)

        def make_module(script):
            module = TheModule()
            # check __repr__
            str(module)
            module.define(script)
            return module

        module = make_module(script)
        if self:
            self.assertExportImportModule(module, tensors)
            module(*args)
        # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
        create_script_module.last_graph = module.graph  # type: ignore[attr-defined]
        return module
    return script_module
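
# Example sketch (assumed usage, mirroring try_get_nn_module_compiled_mod_and_inputs
# below): wrap torch.nn.Linear in a ScriptModule whose forward dispatches to it.
def _example_create_script_module():
    inp = torch.randn(2, S)
    mod = create_script_module(None, torch.nn.Linear, (S, M))(inp)
    return mod(inp)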

def check_alias_annotation(method_name, args, kwargs, *, aten_name, func_type='method'):
    formals, tensors, actuals = get_script_args(args)
    call = get_call(method_name, func_type, actuals, kwargs)

    script = script_template.format(', '.join(formals), call)

    CU = torch.jit.CompilationUnit(script)
    # to clean up IR
    torch._C._jit_pass_inline(CU.the_method.graph)
    torch._C._jit_pass_constant_propagation(CU.the_method.graph)
    torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), aten_name)
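
# Example sketch (assumed usage, modeled on how OpInfo-based JIT tests call this
# helper): verify the alias annotation of aten::add against observed behaviour.
def _example_check_alias_annotation():
    check_alias_annotation('add', (torch.randn(S), torch.randn(S)), {}, aten_name='aten::add')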

def get_nn_module_name_from_kwargs(**kwargs):
    if 'module_name' in kwargs:
        return kwargs['module_name']
    elif 'fullname' in kwargs:
        return kwargs['fullname']
    elif 'constructor' in kwargs:
        return kwargs['constructor'].__name__

def get_nn_mod_test_name(**kwargs):
    if 'fullname' in kwargs:
        test_name = kwargs['fullname']
    else:
        test_name = get_nn_module_name_from_kwargs(**kwargs)
        if 'desc' in kwargs:
            test_name = "{}_{}".format(test_name, kwargs['desc'])
    return 'test_nn_{}'.format(test_name)
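
# Illustrative behaviour (hedged example, values made up):
#     get_nn_mod_test_name(module_name='Bilinear', desc='1d')
# returns 'test_nn_Bilinear_1d', the naming scheme checked against
# EXCLUDE_SCRIPT_MODULES above.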

def get_nn_module_class_from_kwargs(**kwargs):
    name = get_nn_module_name_from_kwargs(**kwargs)
    index = name.find("_")
    if index == -1:
        return name
    else:
        return name[0:name.find("_")]

def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs):
    name = get_nn_module_name_from_kwargs(**kwargs)

    if 'desc' in kwargs and 'eval' in kwargs['desc']:
        # eval() is not supported, so skip these tests
        return

    test_name = name
    if 'desc' in kwargs:
        test_name = "{}_{}".format(test_name, kwargs['desc'])
    test_name = get_nn_mod_test_name(**kwargs)

    if test_name in EXCLUDE_SCRIPT_MODULES:
        return
    if 'constructor' in kwargs:
        nn_module = kwargs['constructor']
    else:
        nn_module = getattr(torch.nn, name)

    if "FunctionalModule" in str(nn_module):
        return

    if 'constructor_args_fn' in kwargs:
        constructor_args = kwargs['constructor_args_fn']()
    else:
        constructor_args = kwargs.get('constructor_args', ())

    # Set up inputs from tuple of sizes or constructor fn
    input_dtype = torch.double
    if 'input_fn' in kwargs:
        input = kwargs['input_fn']()
        if isinstance(input, torch.Tensor):
            input = (input,)

        if all(tensor.is_complex() for tensor in input):
            input_dtype = torch.cdouble
    else:
        input = (kwargs['input_size'],)

    # Extra parameters to forward()
    if 'extra_args' in kwargs:
        input = input + kwargs['extra_args']

    if 'target_size' in kwargs:
        input = input + (kwargs['target_size'],)
    elif 'target_fn' in kwargs:
        if torch.is_tensor(input):
            input = (input,)
        input = input + (kwargs['target_fn'](),)

    args_variable, kwargs_variable = create_input(input, dtype=input_dtype)
    f_args_variable = deepcopy(unpack_variables(args_variable))
    out_var = deepcopy(f_args_variable)

    args, mod = f_args_variable, create_script_module(None, nn_module, constructor_args, *f_args_variable)(*f_args_variable)

    return mod, out_var
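
# Example sketch (illustrative helper, not part of the test suite): compile the
# 'Bilinear' entry from additional_module_tests above into a ScriptModule and
# its example inputs.
def _example_try_get_nn_module_compiled_mod_and_inputs():
    return try_get_nn_module_compiled_mod_and_inputs(**additional_module_tests[0])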

def get_all_nn_module_tests():
    return module_tests + new_module_tests + additional_module_tests