common_modules.py 78 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550
  1. import torch
  2. import unittest
  3. from copy import deepcopy
  4. from enum import Enum
  5. from functools import wraps, partial
  6. from itertools import chain, product
  7. import itertools
  8. import torch.nn.functional as F
  9. from torch.nn.utils.rnn import pack_padded_sequence
  10. from torch.testing import make_tensor
  11. from torch.testing._internal.common_cuda import TEST_CUDNN
  12. from torch.testing._internal.common_dtype import floating_types, floating_and_complex_types_and
  13. from torch.testing._internal.common_device_type import (
  14. _TestParametrizer, _update_param_kwargs, toleranceOverride, tol,
  15. skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta, skipCUDAVersionIn)
  16. from torch.testing._internal.common_methods_invocations import DecorateInfo
  17. from torch.testing._internal.common_nn import nllloss_reference, get_reduction
  18. from torch.testing._internal.common_utils import (
  19. freeze_rng_state, set_single_threaded_if_parallel_tbb, skipIfMps, GRADCHECK_NONDET_TOL, TEST_WITH_ROCM)
  20. from types import ModuleType
  21. from typing import List, Tuple, Type, Set, Dict
  22. # List of all namespaces containing modules to test.
  23. MODULE_NAMESPACES: List[ModuleType] = [
  24. torch.nn.modules,
  25. torch.ao.nn.qat.modules,
  26. torch.ao.nn.quantizable.modules,
  27. torch.ao.nn.quantized.modules,
  28. torch.ao.nn.quantized.modules,
  29. ]
  30. # Modules that shouldn't be tested for one reason or another.
  31. MODULES_TO_SKIP: Set[Type] = {
  32. torch.nn.Module, # abstract base class
  33. torch.nn.Container, # deprecated
  34. torch.nn.NLLLoss2d, # deprecated
  35. torch.ao.nn.quantized.MaxPool2d, # aliases to nn.MaxPool2d
  36. torch.ao.nn.quantized.MaxPool2d, # aliases to nn.MaxPool2d
  37. }
  38. # List of all module classes to test.
  39. MODULE_CLASSES: List[Type] = list(chain(*[
  40. [getattr(namespace, module_name) for module_name in namespace.__all__] # type: ignore[attr-defined]
  41. for namespace in MODULE_NAMESPACES]))
  42. MODULE_CLASSES = [cls for cls in MODULE_CLASSES if cls not in MODULES_TO_SKIP]
  43. # Dict of module class -> common name. Useful for making test names more intuitive.
  44. # Example: torch.nn.modules.linear.Linear -> "nn.Linear"
  45. MODULE_CLASS_NAMES: Dict[Type, str] = {}
  46. for namespace in MODULE_NAMESPACES:
  47. for module_name in namespace.__all__: # type: ignore[attr-defined]
  48. module_cls = getattr(namespace, module_name)
  49. namespace_name = namespace.__name__.replace('torch.', '').replace('.modules', '')
  50. # Deal with any aliases by preferring earlier names.
  51. if module_cls not in MODULE_CLASS_NAMES:
  52. MODULE_CLASS_NAMES[module_cls] = f'{namespace_name}.{module_name}'
  53. # Specifies the modes (i.e. train, eval) to test over.
  54. TrainEvalMode = Enum('TrainEvalMode', ('train_only', 'eval_only', 'train_and_eval'))
  55. class modules(_TestParametrizer):
  56. """ PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """
  57. def __init__(self, module_info_iterable, allowed_dtypes=None, train_eval_mode=TrainEvalMode.train_and_eval):
  58. self.module_info_list = list(module_info_iterable)
  59. self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None
  60. self.train_eval_mode = train_eval_mode
  61. def _get_training_flags(self, module_info):
  62. training_flags = []
  63. if (self.train_eval_mode == TrainEvalMode.train_only or
  64. self.train_eval_mode == TrainEvalMode.train_and_eval):
  65. training_flags.append(True)
  66. if (self.train_eval_mode == TrainEvalMode.eval_only or
  67. self.train_eval_mode == TrainEvalMode.train_and_eval):
  68. training_flags.append(False)
  69. # If train and eval modes don't differ for the module, don't bother using more than one.
  70. if not module_info.train_and_eval_differ:
  71. training_flags = training_flags[:1]
  72. return training_flags
  73. def _parametrize_test(self, test, generic_cls, device_cls):
  74. if device_cls is None:
  75. raise RuntimeError('The @modules decorator is only intended to be used in a device-specific '
  76. 'context; use it with instantiate_device_type_tests() instead of '
  77. 'instantiate_parametrized_tests()')
  78. for module_info in self.module_info_list:
  79. dtypes = set(module_info.dtypes)
  80. if self.allowed_dtypes is not None:
  81. dtypes = dtypes.intersection(self.allowed_dtypes)
  82. training_flags = self._get_training_flags(module_info)
  83. for (training, dtype) in product(training_flags, dtypes):
  84. # Construct the test name; device / dtype parts are handled outside.
  85. # See [Note: device and dtype suffix placement]
  86. test_name = module_info.formatted_name
  87. if len(training_flags) > 1:
  88. test_name += f"_{'train_mode' if training else 'eval_mode'}"
  89. # Construct parameter kwargs to pass to the test.
  90. param_kwargs = {'module_info': module_info}
  91. _update_param_kwargs(param_kwargs, 'dtype', dtype)
  92. _update_param_kwargs(param_kwargs, 'training', training)
  93. try:
  94. @wraps(test)
  95. def test_wrapper(*args, **kwargs):
  96. return test(*args, **kwargs)
  97. decorator_fn = partial(module_info.get_decorators, generic_cls.__name__,
  98. test.__name__, device_cls.device_type, dtype)
  99. yield (test_wrapper, test_name, param_kwargs, decorator_fn)
  100. except Exception as ex:
  101. # Provides an error message for debugging before rethrowing the exception
  102. print("Failed to instantiate {0} for module {1}!".format(test_name, module_info.name))
  103. raise ex
  104. def get_module_common_name(module_cls):
  105. if module_cls in MODULE_CLASS_NAMES:
  106. # Example: "nn.Linear"
  107. return MODULE_CLASS_NAMES[module_cls]
  108. else:
  109. return module_cls.__name__
  110. class FunctionInput:
  111. """ Contains args and kwargs to pass as input to a function. """
  112. __slots__ = ['args', 'kwargs']
  113. def __init__(self, *args, **kwargs):
  114. self.args = args
  115. self.kwargs = kwargs
  116. class ModuleInput:
  117. """ Contains args / kwargs for module instantiation + forward pass. """
  118. __slots__ = ['constructor_input', 'forward_input', 'desc', 'reference_fn']
  119. def __init__(self, constructor_input, forward_input=None, desc='', reference_fn=None):
  120. self.constructor_input = constructor_input # Inputs to pass during construction
  121. self.forward_input = forward_input # Inputs to pass to forward()
  122. self.desc = desc # Description for this set of inputs
  123. self.reference_fn = reference_fn # Reference with signature: reference_fn(module, parameters, *args, **kwargs)
  124. if reference_fn is not None:
  125. @wraps(reference_fn)
  126. def copy_reference_fn(m, *args, **kwargs):
  127. # Copy inputs to avoid undesired side effects from calling the reference.
  128. args, kwargs = deepcopy(args), deepcopy(kwargs)
  129. # Note that module parameters are passed in for convenience.
  130. return reference_fn(m, list(m.parameters()), *args, **kwargs)
  131. self.reference_fn = copy_reference_fn
  132. class ModuleInfo:
  133. """ Module information to be used in testing. """
  134. def __init__(self,
  135. module_cls, # Class object for the module under test
  136. *,
  137. module_inputs_func, # Function to generate module inputs
  138. skips=(), # Indicates which tests to skip
  139. decorators=None, # Additional decorators to apply to generated tests
  140. dtypes=floating_types(), # dtypes this function is expected to work with
  141. supports_gradgrad=True, # whether the op supports second order gradients
  142. gradcheck_nondet_tol=0.0, # tolerance for nondeterminism while performing gradcheck
  143. module_memformat_affects_out=False, # whether converting module to channels last will generate
  144. # channels last output
  145. train_and_eval_differ=False, # whether the module has differing behavior between train and eval
  146. ):
  147. self.module_cls = module_cls
  148. self.module_inputs_func = module_inputs_func
  149. self.decorators = (*(decorators if decorators else []), *(skips if skips else []))
  150. self.dtypes = dtypes
  151. self.supports_gradgrad = supports_gradgrad
  152. self.gradcheck_nondet_tol = gradcheck_nondet_tol
  153. self.module_memformat_affects_out = module_memformat_affects_out
  154. self.train_and_eval_differ = train_and_eval_differ
  155. def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
  156. result = [set_single_threaded_if_parallel_tbb]
  157. for decorator in self.decorators:
  158. if isinstance(decorator, DecorateInfo):
  159. if decorator.is_active(test_class, test_name, device, dtype, param_kwargs):
  160. result.extend(decorator.decorators)
  161. else:
  162. result.append(decorator)
  163. return result
  164. @property
  165. def name(self):
  166. return get_module_common_name(self.module_cls)
  167. @property
  168. def formatted_name(self):
  169. return self.name.replace('.', '_')
  170. def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, training, **kwargs):
  171. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  172. module_inputs = [
  173. ModuleInput(constructor_input=FunctionInput(10, 8),
  174. forward_input=FunctionInput(input=make_input((4, 10))),
  175. reference_fn=lambda m, p, input: torch.mm(input, p[0].t()) + p[1].view(1, -1).expand(4, 8)),
  176. ModuleInput(constructor_input=FunctionInput(10, 8, bias=False),
  177. forward_input=FunctionInput(make_input((4, 10))),
  178. desc='no_bias',
  179. reference_fn=lambda m, p, i: torch.mm(i, p[0].t())),
  180. ModuleInput(constructor_input=FunctionInput(3, 5),
  181. forward_input=FunctionInput(make_input(3)),
  182. desc='no_batch_dim',
  183. reference_fn=lambda m, p, i: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1])
  184. ]
  185. return module_inputs
  186. def module_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, training, **kwargs):
  187. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  188. def bilinear_reference_fn(m, p, x1, x2, bias=True):
  189. result = torch.einsum('bn,anm,bm->ba', x1, p[0], x2)
  190. if bias:
  191. if x1.shape[0] == 1:
  192. result = result.view(-1) + p[1]
  193. else:
  194. result = result + p[1].view(1, -1).expand(x1.shape[0], p[0].shape[0])
  195. return result
  196. module_inputs = [
  197. ModuleInput(constructor_input=FunctionInput(2, 3, 4),
  198. forward_input=FunctionInput(make_input((8, 2)), make_input((8, 3))),
  199. reference_fn=lambda m, p, x1, x2: bilinear_reference_fn(m, p, x1, x2)),
  200. ModuleInput(constructor_input=FunctionInput(2, 3, 4, bias=False),
  201. forward_input=FunctionInput(make_input((8, 2)), make_input((8, 3))),
  202. desc='no_bias',
  203. reference_fn=lambda m, p, x1, x2: bilinear_reference_fn(m, p, x1, x2, bias=False)),
  204. ModuleInput(constructor_input=FunctionInput(2, 3, 4),
  205. forward_input=FunctionInput(make_input((2)), make_input((3))),
  206. desc='no_batch_dim',
  207. reference_fn=lambda m, p, x1, x2: bilinear_reference_fn(m, p, x1.view(1, -1), x2.view(1, -1))),
  208. ]
  209. return module_inputs
  210. def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
  211. def make_input(shape, device=device, dtype=dtype, requires_grad=requires_grad):
  212. return make_tensor(shape, device=device, dtype=dtype,
  213. requires_grad=False).log_softmax(dim=1).requires_grad_(requires_grad)
  214. make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
  215. cases: List[Tuple[str, dict]] = [
  216. ('', {}),
  217. ('reduction_sum', {'reduction': 'sum'}),
  218. ('reduction_none', {'reduction': 'none'}),
  219. ('ignore_index', {'ignore_index': 2}),
  220. ('weights', {'weight': make_weight(10).abs()}),
  221. ('weights_ignore_index', {'weight': make_weight(10).abs(), 'ignore_index': 2}),
  222. ('weights_ignore_index_neg', {'weight': make_weight(10).abs(), 'ignore_index': -1})
  223. ]
  224. # TODO: Uncomment when negative weights is supported.
  225. # negative_weight = make_weight(10)
  226. # negative_weight[0] = -1
  227. # cases.append(('weights_negative', {'weight': negative_weight}))
  228. module_inputs = []
  229. for desc, constructor_kwargs in cases:
  230. def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
  231. return nllloss_reference(i, t, **constructor_kwargs)
  232. module_inputs.append(
  233. ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
  234. forward_input=FunctionInput(make_input((15, 10)),
  235. torch.empty(15, device=device).uniform_().mul(10).floor().long()),
  236. desc=desc,
  237. reference_fn=reference_fn)
  238. )
  239. return module_inputs
  240. def module_inputs_torch_nn_GaussianNLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
  241. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  242. make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
  243. cases: List[Tuple[str, dict]] = [
  244. ('', {}),
  245. ('reduction_sum', {'reduction': 'sum'}),
  246. ('reduction_mean', {'reduction': 'mean'}),
  247. ('reduction_none', {'reduction': 'none'}),
  248. ]
  249. module_inputs = []
  250. for desc, constructor_kwargs in cases:
  251. module_inputs.append(
  252. ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
  253. forward_input=FunctionInput(make_input((3)),
  254. make_target((3)),
  255. make_input((1)).abs()),
  256. desc=desc,
  257. reference_fn=no_batch_dim_reference_fn)
  258. )
  259. return module_inputs
  260. def no_batch_dim_reference_fn(m, p, *args, **kwargs):
  261. """Reference function for modules supporting no batch dimensions.
  262. Unbatched inputs are unsqueezed to form a
  263. single batch input before passing them to the module.
  264. The output is squeezed to compare with the
  265. output of unbatched input to the module.
  266. Currently it only supports modules which return a single Tensor as output.
  267. You can bind the following kwargs.
  268. Kwargs:
  269. batch_first[bool] : If True, all the Tensors in `args` while be unsqueezed at dim `0` .
  270. and output will be squeezed at dim `0` else dim `1` for both.
  271. kwargs_to_batchify[dict] : Dictionary specifying the name of the argument and dimension to unsqueeze.
  272. Useful if there are few arguments whose batch dimension are different
  273. from the ones selected by `batch_first`.
  274. is_criterion[bool] : Specify if the module is a criterion and handle the reduction for output accordingly.
  275. """
  276. def get_and_pop(key, default):
  277. v = kwargs.get(key, default)
  278. if key in kwargs:
  279. kwargs.pop(key)
  280. return v
  281. batch_dim = 0 if get_and_pop('batch_first', True) else 1
  282. kwargs_to_batchify = get_and_pop('kwargs_to_batchify', None)
  283. is_criterion = get_and_pop('is_criterion', False)
  284. if kwargs_to_batchify is not None:
  285. assert isinstance(kwargs_to_batchify, dict)
  286. for k, v in kwargs.items():
  287. if k in kwargs_to_batchify and v is not None:
  288. bdim = kwargs_to_batchify[k]
  289. kwargs[k] = v.unsqueeze(bdim)
  290. single_batch_input_args = [input.unsqueeze(batch_dim) for input in args]
  291. with freeze_rng_state():
  292. output = m(*single_batch_input_args, **kwargs).squeeze(batch_dim)
  293. if is_criterion:
  294. reduction = get_reduction(m)
  295. if reduction == 'none':
  296. return output.squeeze(0)
  297. return output
  298. def no_batch_dim_reference_mha(m, p, *args, **kwargs):
  299. """Reference function for MultiheadAttention supporting no batch dimensions.
  300. Unbatched inputs are unsqueezed to form a
  301. single batch input before passing them to the module.
  302. The output is squeezed to compare with the
  303. output of unbatched input to the module.
  304. """
  305. batch_dim = 0 if kwargs.get('batch_first', True) else 1
  306. if 'batch_first' in kwargs:
  307. kwargs.pop('batch_first')
  308. if 'key_padding_mask' in kwargs and kwargs['key_padding_mask'] is not None:
  309. kwargs['key_padding_mask'] = kwargs['key_padding_mask'].unsqueeze(0)
  310. single_batch_input_args = [input.unsqueeze(batch_dim) for input in args]
  311. with freeze_rng_state():
  312. output = m(*single_batch_input_args, **kwargs)
  313. return (output[0].squeeze(batch_dim), output[1].squeeze(0))
  314. def no_batch_dim_reference_rnn_gru(m, p, *args, **kwargs):
  315. """Reference function for RNN and GRU supporting no batch dimensions.
  316. Unbatched inputs are unsqueezed to form a
  317. single batch input before passing them to the module.
  318. The output is squeezed to compare with the
  319. output of unbatched input to the module.
  320. """
  321. if len(args) == 1:
  322. inp, = args
  323. h = None
  324. elif len(args) == 2:
  325. inp, h = args
  326. h = h.unsqueeze(1)
  327. batch_dim = 0 if kwargs['batch_first'] else 1
  328. kwargs.pop('batch_first')
  329. inp = inp.unsqueeze(batch_dim)
  330. single_batch_input_args = (inp, h)
  331. with freeze_rng_state():
  332. output = m(*single_batch_input_args, **kwargs)
  333. return (output[0].squeeze(batch_dim), output[1].squeeze(1))
  334. def no_batch_dim_reference_lstm(m, p, *args, **kwargs):
  335. """Reference function for LSTM supporting no batch dimensions.
  336. Unbatched inputs are unsqueezed to form a
  337. single batch input before passing them to the module.
  338. The output is squeezed to compare with the
  339. output of unbatched input to the module.
  340. """
  341. if len(args) == 1:
  342. inp, = args
  343. h = None
  344. elif len(args) == 2:
  345. inp, h = args
  346. h = (h[0].unsqueeze(1), h[1].unsqueeze(1))
  347. batch_dim = 0 if kwargs['batch_first'] else 1
  348. kwargs.pop('batch_first')
  349. inp = inp.unsqueeze(batch_dim)
  350. single_batch_input_args = (inp, h)
  351. with freeze_rng_state():
  352. output = m(*single_batch_input_args, **kwargs)
  353. return (output[0].squeeze(batch_dim), (output[1][0].squeeze(1), output[1][1].squeeze(1)))
  354. def no_batch_dim_reference_lstmcell(m, p, *args, **kwargs):
  355. """Reference function for LSTMCell supporting no batch dimensions.
  356. The module is passed the input and target in batched form with a single item.
  357. The output is squeezed to compare with the no-batch input.
  358. """
  359. inp, (h, c) = args
  360. single_batch_input_args = (inp.unsqueeze(0), (h.unsqueeze(0), c.unsqueeze(0)))
  361. with freeze_rng_state():
  362. output = m(*single_batch_input_args, **kwargs)
  363. return (output[0].squeeze(0), output[1].squeeze(0))
  364. def generate_regression_criterion_inputs(make_input):
  365. return [
  366. ModuleInput(
  367. constructor_input=FunctionInput(reduction=reduction),
  368. forward_input=FunctionInput(make_input((4, )), make_input(4,)),
  369. reference_fn=partial(no_batch_dim_reference_fn, is_criterion=True),
  370. desc='no_batch_dim_{}'.format(reduction)
  371. ) for reduction in ['none', 'mean', 'sum']]
  372. def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
  373. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  374. return [
  375. ModuleInput(constructor_input=FunctionInput(kernel_size=2),
  376. forward_input=FunctionInput(make_input((3, 6))),
  377. desc='no_batch_dim',
  378. reference_fn=no_batch_dim_reference_fn)]
  379. def module_inputs_torch_nn_AdaptiveAvgPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
  380. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  381. return [
  382. ModuleInput(constructor_input=FunctionInput(3,),
  383. forward_input=FunctionInput(make_input((1, 3, 5, 6))),
  384. desc='single')]
  385. def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad, training, **kwargs):
  386. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  387. return [
  388. ModuleInput(constructor_input=FunctionInput(3,),
  389. forward_input=FunctionInput(make_input((2, 3, 6, 6))))]
  390. def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad, training, **kwargs):
  391. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  392. return [
  393. ModuleInput(constructor_input=FunctionInput(3,),
  394. forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))))]
  395. def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, training, **kwargs):
  396. N = kwargs['N']
  397. lazy = kwargs.get('lazy', False)
  398. transposed = kwargs.get('transposed', False)
  399. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  400. conv_kwargs_list = [{}] if transposed else [{}, {'padding': 'same'}]
  401. kernel_size, C_in, C_out = 3, 4, 5
  402. input_no_batch_shape = (C_in,) + tuple((i + 3 for i in range(N)))
  403. input_batch_shape = (2,) + input_no_batch_shape
  404. return [
  405. ModuleInput(constructor_input=(FunctionInput(C_out, kernel_size, **conv_kwargs) if lazy else
  406. FunctionInput(C_in, C_out, kernel_size, **conv_kwargs)),
  407. forward_input=FunctionInput(make_input(
  408. input_batch_shape if with_batch else input_no_batch_shape)),
  409. desc=('' if with_batch else 'no_batch_dim'),
  410. reference_fn=(None if with_batch else no_batch_dim_reference_fn))
  411. for with_batch, conv_kwargs in itertools.product([True, False], conv_kwargs_list)
  412. ]
  413. def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, training, **kwargs):
  414. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  415. return [
  416. ModuleInput(constructor_input=FunctionInput(alpha=2.),
  417. forward_input=FunctionInput(make_input((3, 2, 5))),
  418. reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1))),
  419. ModuleInput(constructor_input=FunctionInput(alpha=2.),
  420. forward_input=FunctionInput(make_input(())),
  421. desc='scalar'),
  422. ModuleInput(constructor_input=FunctionInput(),
  423. forward_input=FunctionInput(make_input((3,))),
  424. desc='no_batch_dim',
  425. reference_fn=no_batch_dim_reference_fn),
  426. ModuleInput(constructor_input=FunctionInput(alpha=2.),
  427. forward_input=FunctionInput(make_input((2, 3, 2, 5))),
  428. desc='4d_input')]
  429. def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, training, **kwargs):
  430. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  431. return [
  432. ModuleInput(constructor_input=FunctionInput(alpha=2.),
  433. forward_input=FunctionInput(make_input((3, 2, 5))),
  434. reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1))),
  435. ModuleInput(constructor_input=FunctionInput(alpha=2.),
  436. forward_input=FunctionInput(make_input(())),
  437. reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1)),
  438. desc='scalar'),
  439. ModuleInput(constructor_input=FunctionInput(alpha=2.),
  440. forward_input=FunctionInput(make_input((3,))),
  441. desc='no_batch_dim',
  442. reference_fn=no_batch_dim_reference_fn)]
  443. def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad, training, **kwargs):
  444. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  445. return [
  446. ModuleInput(constructor_input=FunctionInput(),
  447. forward_input=FunctionInput(make_input(4)),
  448. desc='no_batch_dim'),
  449. ModuleInput(constructor_input=FunctionInput(),
  450. forward_input=FunctionInput(make_input((2, 3, 4, 5))),
  451. desc='channels_last_mem_format'),
  452. ModuleInput(constructor_input=FunctionInput(),
  453. forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))),
  454. desc='channels_last_3d_mem_format')]
  455. def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, training, **kwargs):
  456. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  457. return [
  458. ModuleInput(constructor_input=FunctionInput(),
  459. forward_input=FunctionInput(make_input((2, 3, 4)),
  460. make_input((2, 3, 4))),
  461. reference_fn=lambda m, p, i, t: 1. / i.numel() * sum((a - b).abs().sum()
  462. for a, b in zip(i, t))),
  463. ModuleInput(constructor_input=FunctionInput(),
  464. forward_input=FunctionInput(make_input(()), make_input(())),
  465. reference_fn=lambda m, p, i, t: 1. / i.numel() * (i - t).abs().sum(),
  466. desc='scalar')] + generate_regression_criterion_inputs(make_input)
  467. def module_inputs_torch_nn_CrossEntropyLoss(module_info, device, dtype, requires_grad, training, **kwargs):
  468. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  469. make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
  470. make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
  471. reductions = ['sum', 'mean', 'none']
  472. samples = []
  473. # Samples below are for validating the no-batch-dim support.
  474. for reduction in reductions:
  475. samples.append(
  476. ModuleInput(constructor_input=FunctionInput(reduction=reduction),
  477. forward_input=FunctionInput(make_input((9,)), make_target((), low=0, high=9)),
  478. reference_fn=partial(no_batch_dim_reference_fn, is_criterion=True))
  479. )
  480. samples.append(
  481. ModuleInput(constructor_input=FunctionInput(reduction=reduction, weight=make_weight((9,))),
  482. forward_input=FunctionInput(make_input((9,)), make_target((), low=0, high=9)),
  483. reference_fn=partial(no_batch_dim_reference_fn, is_criterion=True))
  484. )
  485. samples.append(
  486. ModuleInput(constructor_input=FunctionInput(reduction=reduction, label_smoothing=0.5),
  487. forward_input=FunctionInput(make_input((9,)), make_target((), low=0, high=9)),
  488. reference_fn=partial(no_batch_dim_reference_fn, is_criterion=True))
  489. )
  490. samples.append(
  491. ModuleInput(constructor_input=FunctionInput(reduction=reduction, label_smoothing=0.5,
  492. weight=make_weight((9,))),
  493. forward_input=FunctionInput(make_input((9,)), make_target((), low=0, high=9)),
  494. reference_fn=partial(no_batch_dim_reference_fn, is_criterion=True))
  495. )
  496. return samples
  497. def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad, training, **kwargs):
  498. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  499. return [
  500. ModuleInput(
  501. constructor_input=FunctionInput(),
  502. forward_input=FunctionInput(make_input(4)),
  503. reference_fn=no_batch_dim_reference_fn,
  504. desc='no_batch_dim',
  505. ),
  506. ModuleInput(
  507. constructor_input=FunctionInput(),
  508. forward_input=FunctionInput(make_input((2, 3, 2, 5))),
  509. desc='4d_input')
  510. ]
  511. def module_inputs_torch_nn_MaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
  512. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  513. return [
  514. ModuleInput(
  515. constructor_input=FunctionInput((3, 3), (2, 2), (1, 1)),
  516. forward_input=FunctionInput(make_input(((3, 7, 7)))),
  517. desc='3d_input'),
  518. ModuleInput(
  519. constructor_input=FunctionInput((3, 3), (2, 2), (1, 1)),
  520. forward_input=FunctionInput(make_input((1, 3, 7, 7))),
  521. desc='4d_input'),
  522. ModuleInput(
  523. constructor_input=FunctionInput((3, 3), (2, 2), (1, 1), return_indices=True),
  524. forward_input=FunctionInput(make_input((1, 3, 7, 7))),
  525. desc='return_indices'),
  526. ]
  527. def module_inputs_torch_nn_Sigmoid(module_info, device, dtype, requires_grad, training, **kwargs):
  528. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  529. return [
  530. ModuleInput(
  531. constructor_input=FunctionInput(),
  532. forward_input=FunctionInput(make_input((2, 3, 4, 5))),
  533. desc='channels_last_mem_format'
  534. ),
  535. ModuleInput(
  536. constructor_input=FunctionInput(),
  537. forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))),
  538. desc='channels_last_3d_mem_format'
  539. )
  540. ]
  541. def module_inputs_torch_nn_TransformerEncoder(module_info, device, dtype, requires_grad, training, **kwargs):
  542. # Reuse the TransformerEncoderLayer samples since the forward args are nearly the same.
  543. for layer_module_input in module_inputs_torch_nn_TransformerEncoderLayer(
  544. None, device, dtype, requires_grad, training):
  545. # Construct a TransformerEncoderLayer object to pass to TransformerEncoder.
  546. l_args, l_kwargs = (layer_module_input.constructor_input.args,
  547. layer_module_input.constructor_input.kwargs)
  548. encoder_layer = torch.nn.TransformerEncoderLayer(*l_args, **l_kwargs)
  549. num_layers = 2
  550. # Note: TransformerEncoderLayer takes a "src_mask" while
  551. # TransformerEncoder takes a "mask"; rename kwarg appropriately.
  552. forward_input = layer_module_input.forward_input
  553. if 'src_mask' in forward_input.kwargs:
  554. forward_input.kwargs['mask'] = forward_input.kwargs['src_mask']
  555. del forward_input.kwargs['src_mask']
  556. yield ModuleInput(
  557. constructor_input=FunctionInput(encoder_layer, num_layers),
  558. forward_input=forward_input,
  559. desc=layer_module_input.desc
  560. )
  561. def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, training, **kwargs):
  562. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  563. samples = [
  564. ModuleInput(
  565. constructor_input=FunctionInput(4, 2, 16, 0.0),
  566. forward_input=FunctionInput(
  567. make_input((2, 3, 4))
  568. ),
  569. desc='relu_activation'
  570. ),
  571. ModuleInput(
  572. constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu),
  573. forward_input=FunctionInput(
  574. make_input((2, 3, 4))
  575. ),
  576. desc='gelu_activation'
  577. ), ]
  578. # Samples below are for validating the no-batch-dim support.
  579. key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
  580. attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
  581. for src_mask, src_key_padding_mask, norm_first in itertools.product(attn_masks, key_padding_masks, (True, False)):
  582. samples.append(
  583. ModuleInput(
  584. constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
  585. dropout=0.0, batch_first=True, norm_first=norm_first),
  586. forward_input=FunctionInput(
  587. make_input((3, 4)), src_mask=src_mask, src_key_padding_mask=src_key_padding_mask
  588. ),
  589. reference_fn=partial(no_batch_dim_reference_fn,
  590. batch_first=True, kwargs_to_batchify={'src_key_padding_mask': 0}),
  591. desc='no_batch_dim_batch_first'
  592. ))
  593. samples.append(
  594. ModuleInput(
  595. constructor_input=FunctionInput(4, 2, 8, dropout=0.0, batch_first=False, norm_first=norm_first),
  596. forward_input=FunctionInput(
  597. make_input((3, 4)), src_mask=src_mask, src_key_padding_mask=src_key_padding_mask
  598. ),
  599. reference_fn=partial(no_batch_dim_reference_fn,
  600. batch_first=False, kwargs_to_batchify={'src_key_padding_mask': 0}),
  601. desc='no_batch_dim'
  602. ))
  603. def fast_path_reference_fn(module, parameters, *args, **kwargs):
  604. assert not module.training
  605. module = module.train(True)
  606. output = module(*args, **kwargs)
  607. module = module.train(False)
  608. return output
  609. if not training:
  610. for norm_first in (True, False):
  611. samples.append(
  612. ModuleInput(
  613. constructor_input=FunctionInput(4, 2, 8, dropout=0.0, batch_first=True, norm_first=norm_first),
  614. forward_input=FunctionInput(
  615. make_input((2, 3, 4)),
  616. ),
  617. reference_fn=fast_path_reference_fn,
  618. desc="fast_path_norm_first" if norm_first else "fast_path"
  619. )
  620. )
  621. return samples
  622. def module_inputs_torch_nn_TransformerDecoderLayer(module_info, device, dtype, requires_grad, training, **kwargs):
  623. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  624. samples = [
  625. ModuleInput(
  626. constructor_input=FunctionInput(4, 2, 16, 0.0),
  627. forward_input=FunctionInput(
  628. make_input((2, 3, 4)), make_input((2, 3, 4))
  629. ),
  630. desc='relu_activation'
  631. ),
  632. ModuleInput(
  633. constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu),
  634. forward_input=FunctionInput(
  635. make_input((2, 3, 4)), make_input((2, 3, 4))
  636. ),
  637. desc='gelu_activation'
  638. ), ]
  639. # Samples below are for validating the no-batch-dim support.
  640. key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
  641. attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
  642. for tgt_mask, tgt_key_padding_mask, norm_first in itertools.product(attn_masks, key_padding_masks, (True, False)):
  643. # Using same mask for tgt and memory
  644. memory_mask = tgt_mask
  645. memory_key_padding_mask = tgt_key_padding_mask
  646. samples.append(
  647. ModuleInput(
  648. constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
  649. dropout=0.0, batch_first=True, norm_first=norm_first),
  650. forward_input=FunctionInput(
  651. make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, memory_mask=memory_mask,
  652. tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask
  653. ),
  654. reference_fn=partial(no_batch_dim_reference_fn,
  655. batch_first=True,
  656. kwargs_to_batchify={'tgt_key_padding_mask': 0, 'memory_key_padding_mask': 0}),
  657. desc='no_batch_dim_batch_first'
  658. ))
  659. samples.append(
  660. ModuleInput(
  661. constructor_input=FunctionInput(4, 2, 8, dropout=0.0, batch_first=False, norm_first=norm_first),
  662. forward_input=FunctionInput(
  663. make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, memory_mask=memory_mask,
  664. tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask
  665. ),
  666. reference_fn=partial(no_batch_dim_reference_fn,
  667. batch_first=False,
  668. kwargs_to_batchify={'tgt_key_padding_mask': 0, 'memory_key_padding_mask': 0}),
  669. desc='no_batch_dim'
  670. ))
  671. return samples
  672. def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad, training, **kwargs):
  673. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  674. samples = []
  675. # Samples below are for validating the no-batch-dim support.
  676. key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
  677. attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
  678. for mask, key_padding_mask, norm_first in itertools.product(attn_masks, key_padding_masks, (True, False)):
  679. # Using same mask for tgt and memory
  680. src_mask , tgt_mask = (mask,) * 2
  681. src_key_padding_mask, tgt_key_padding_mask = (key_padding_mask,) * 2
  682. samples.append(
  683. ModuleInput(
  684. constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
  685. num_encoder_layers=1, num_decoder_layers=1,
  686. dropout=0.0, batch_first=True, norm_first=norm_first),
  687. forward_input=FunctionInput(
  688. make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, src_mask=src_mask,
  689. tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask
  690. ),
  691. reference_fn=partial(no_batch_dim_reference_fn,
  692. batch_first=True,
  693. kwargs_to_batchify={'tgt_key_padding_mask': 0, 'src_key_padding_mask': 0}),
  694. desc='no_batch_dim_batch_first'
  695. ))
  696. samples.append(
  697. ModuleInput(
  698. constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
  699. num_encoder_layers=1, num_decoder_layers=1,
  700. dropout=0.0, batch_first=False, norm_first=norm_first),
  701. forward_input=FunctionInput(
  702. make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, src_mask=src_mask,
  703. tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask
  704. ),
  705. reference_fn=partial(no_batch_dim_reference_fn,
  706. batch_first=False,
  707. kwargs_to_batchify={'tgt_key_padding_mask': 0, 'src_key_padding_mask': 0}),
  708. desc='no_batch_dim'
  709. ))
  710. return samples
  711. def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, training, **kwargs):
  712. make_empty = partial(torch.empty, device=device, dtype=torch.long, requires_grad=False)
  713. return [
  714. ModuleInput(
  715. constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
  716. forward_input=FunctionInput(make_empty(2, 3).random_(4))
  717. ),
  718. ModuleInput(
  719. constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
  720. forward_input=FunctionInput(make_empty(1, 512).random_(4).expand(7, 512)),
  721. desc='discontiguous'
  722. ),
  723. ]
  724. def module_inputs_torch_nn_MultiheadAttention(module_info, device, dtype, requires_grad, training, **kwargs):
  725. # Currently all samples below are for validating the no-batch-dim support.
  726. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  727. samples = []
  728. bool_vals = (True, False)
  729. key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
  730. attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3, 3)))
  731. products = itertools.product(bool_vals, bool_vals, bool_vals, key_padding_masks, attn_masks)
  732. for bias, add_bias_kv, add_zero_attn, key_padding_mask, attn_mask in products:
  733. samples.append(
  734. ModuleInput(
  735. constructor_input=FunctionInput(embed_dim=3, num_heads=3, batch_first=True,
  736. bias=bias, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn),
  737. forward_input=FunctionInput(make_input((3, 3)), make_input((3, 3)), make_input((3, 3)),
  738. key_padding_mask=key_padding_mask, attn_mask=attn_mask),
  739. reference_fn=no_batch_dim_reference_mha,
  740. )
  741. )
  742. samples.append(
  743. ModuleInput(
  744. constructor_input=FunctionInput(embed_dim=3, num_heads=3, batch_first=False,
  745. bias=bias, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn),
  746. forward_input=FunctionInput(make_input((3, 3)), make_input((3, 3)), make_input((3, 3)),
  747. key_padding_mask=key_padding_mask, attn_mask=attn_mask),
  748. reference_fn=partial(no_batch_dim_reference_mha, batch_first=False),
  749. )
  750. )
  751. return samples
  752. def module_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs):
  753. # Currently all samples below are for validating the no-batch-dim support.
  754. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  755. samples = [
  756. ModuleInput(
  757. constructor_input=FunctionInput(5, 10),
  758. forward_input=FunctionInput(make_input(5), make_input(10)),
  759. reference_fn=no_batch_dim_reference_fn,
  760. ),
  761. ModuleInput(
  762. constructor_input=FunctionInput(5, 10, bias=True),
  763. forward_input=FunctionInput(make_input(5), make_input(10)),
  764. reference_fn=no_batch_dim_reference_fn,
  765. )
  766. ]
  767. is_rnn = kwargs.get('is_rnn', False)
  768. if is_rnn:
  769. # RNN also supports `nonlinearity` argument.
  770. # `tanh` is the default, so we check with `relu`
  771. samples.append(
  772. ModuleInput(
  773. constructor_input=FunctionInput(5, 10, bias=True, nonlinearity='relu'),
  774. forward_input=FunctionInput(make_input(5), make_input(10)),
  775. reference_fn=no_batch_dim_reference_fn,
  776. )
  777. )
  778. return samples
  779. def module_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, training, **kwargs):
  780. # Currently all samples below are for validating the no-batch-dim support.
  781. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  782. samples = (
  783. ModuleInput(
  784. constructor_input=FunctionInput(5, 10),
  785. forward_input=FunctionInput(make_input(5), (make_input(10), make_input(10))),
  786. reference_fn=no_batch_dim_reference_lstmcell,
  787. ),
  788. ModuleInput(
  789. constructor_input=FunctionInput(5, 10, bias=True),
  790. forward_input=FunctionInput(make_input(5), (make_input(10), make_input(10))),
  791. reference_fn=no_batch_dim_reference_lstmcell,
  792. ),
  793. )
  794. return samples
  795. def make_packed_sequence(inp, batch_sizes):
  796. required_grad = inp.requires_grad
  797. inp.requires_grad_(False) # user won't have access to inp so won't be able to get its grads
  798. seq = pack_padded_sequence(inp, batch_sizes)
  799. seq.data.requires_grad_(required_grad)
  800. return seq
  801. def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, with_packed_sequence=False, **kwargs):
  802. # Currently all samples below are for validating the no-batch-dim support.
  803. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  804. is_rnn = kwargs['is_rnn']
  805. nonlinearity = ('relu', 'tanh')
  806. bias = (False, True)
  807. batch_first = (False, True)
  808. bidirectional = (False, True)
  809. samples = []
  810. if is_rnn:
  811. prod_gen = product(nonlinearity, bias, batch_first, bidirectional)
  812. else:
  813. prod_gen = product(bias, batch_first, bidirectional)
  814. for args in prod_gen:
  815. if is_rnn:
  816. nl, b, b_f, bidir = args
  817. else:
  818. b, b_f, bidir = args
  819. cons_args = {'input_size': 2, 'hidden_size': 2, 'num_layers': 2,
  820. 'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
  821. cons_args_hidden = {'input_size': 2, 'hidden_size': 3, 'num_layers': 2,
  822. 'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
  823. if is_rnn:
  824. cons_args['nonlinearity'] = nl
  825. cons_args_hidden['nonlinearity'] = nl
  826. samples.append(
  827. ModuleInput(
  828. constructor_input=FunctionInput(**cons_args),
  829. forward_input=FunctionInput(make_input((3, 2))),
  830. reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
  831. )
  832. )
  833. samples.append(
  834. ModuleInput(
  835. constructor_input=FunctionInput(**cons_args_hidden),
  836. forward_input=FunctionInput(make_input((3, 2)), make_input((4 if bidir else 2, 3))),
  837. reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
  838. )
  839. )
  840. if with_packed_sequence:
  841. samples.append(
  842. ModuleInput(
  843. constructor_input=FunctionInput(**cons_args),
  844. forward_input=FunctionInput(make_packed_sequence(make_input((5, 2, 2)), torch.tensor([5, 3]))),
  845. reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
  846. )
  847. )
  848. samples.append(
  849. ModuleInput(
  850. constructor_input=FunctionInput(**cons_args),
  851. forward_input=FunctionInput(make_packed_sequence(make_input((5, 5, 2)), torch.tensor([5, 3, 3, 2, 2]))),
  852. reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
  853. )
  854. )
  855. return samples
  856. def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, training, **kwargs):
  857. # Currently all samples below are for validating the no-batch-dim support.
  858. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  859. bias = (False, True)
  860. batch_first = (False, True)
  861. bidirectional = (False, True)
  862. proj_sizes = (0, 2)
  863. samples = []
  864. prod_gen = product(bias, batch_first, bidirectional, proj_sizes)
  865. for args in prod_gen:
  866. b, b_f, bidir, proj_size = args
  867. hidden_size = 3
  868. cons_args = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size,
  869. 'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
  870. cons_args_hidden = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size,
  871. 'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
  872. samples.append(
  873. ModuleInput(
  874. constructor_input=FunctionInput(**cons_args),
  875. forward_input=FunctionInput(make_input((2, 2))),
  876. reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f),
  877. )
  878. )
  879. h_out = proj_size if proj_size > 0 else hidden_size
  880. hx = (make_input((4 if bidir else 2, h_out)), make_input((4 if bidir else 2, hidden_size)))
  881. samples.append(
  882. ModuleInput(
  883. constructor_input=FunctionInput(**cons_args_hidden),
  884. forward_input=FunctionInput(make_input((3, 2)), hx),
  885. reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f),
  886. )
  887. )
  888. return samples
# All these operators share similar issues on cuDNN and MIOpen.
# Shared decorator tuple for the RNN/GRU/LSTM ModuleInfo entries. The
# cuDNN-specific entries are gated on `TEST_CUDNN and not TEST_WITH_ROCM`; the
# MIOpen entry is gated on `TEST_CUDNN and TEST_WITH_ROCM`.
rnn_gru_lstm_module_info_decorators = (
    # RuntimeError: Batching rule not implemented for aten::_cudnn_rnn_backward.
    # We could not generate a fallback
    DecorateInfo(
        unittest.expectedFailure, "TestModule", "test_grad",
        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
    ),
    # NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented.
    # Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API
    DecorateInfo(
        unittest.expectedFailure, "TestModule", "test_gradgrad",
        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
    ),
    # CUDNN GRU doesn't accept non-contiguous hx
    DecorateInfo(
        unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
    ),
    # MIOPEN GRU doesn't accept non-contiguous hx (this is dispatched to miopen only for float).
    DecorateInfo(
        unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
        active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda'
    ),
    # Skipped when running against CUDA 11.7. The reason is not recorded here.
    # NOTE(review): link the tracking issue for this CUDA-version skip.
    DecorateInfo(
        skipCUDAVersionIn([(11, 7)]), "TestExpandedWeightModule", "test_module",
        device_type='cuda'
    ),
    # Same CUDA 11.7 skip for the RNN decomposition test; reason likewise not
    # recorded here. NOTE(review): confirm whether it is still needed.
    DecorateInfo(
        skipCUDAVersionIn([(11, 7)]), "TestDecomp", "test_rnn_decomp_module",
        device_type='cuda'
    )
)
  922. # Database of ModuleInfo entries in alphabetical order.
  923. module_db: List[ModuleInfo] = [
  924. ModuleInfo(torch.nn.AdaptiveAvgPool2d,
  925. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  926. module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool2d,
  927. skips=(
  928. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  929. ),
  930. ModuleInfo(torch.nn.AvgPool1d,
  931. module_inputs_func=module_inputs_torch_nn_AvgPool1d,
  932. skips=(
  933. # No channels_last support for AvgPool1d as it does not take 4D inputs
  934. DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
  935. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  936. ),
  937. ModuleInfo(torch.nn.BatchNorm2d,
  938. train_and_eval_differ=True,
  939. module_inputs_func=module_inputs_torch_nn_BatchNorm2d,
  940. skips=(
  941. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  942. ),
  943. ModuleInfo(torch.nn.BatchNorm3d,
  944. train_and_eval_differ=True,
  945. module_inputs_func=module_inputs_torch_nn_BatchNorm3d,
  946. skips=(
  947. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  948. ),
  949. ModuleInfo(torch.nn.Conv1d,
  950. module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=False),
  951. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  952. module_memformat_affects_out=True,
  953. skips=(
  954. # channels_last support on cuda requires cudnn >= 7603
  955. DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
  956. # Failure on ROCM for float32 issue #70125
  957. DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
  958. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64])
  959. ),
  960. decorators=(
  961. DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
  962. )),
  963. ModuleInfo(torch.nn.Conv2d,
  964. module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False),
  965. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  966. module_memformat_affects_out=True,
  967. skips=(
  968. # channels_last support on cuda requires cudnn >= 7603
  969. DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
  970. # Failure on ROCM for float32 issue #70125
  971. DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
  972. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
  973. # This was wrongly being skipped before and needs investigation.
  974. # See https://github.com/pytorch/pytorch/issues/80247
  975. DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
  976. device_type='cuda', dtypes=[torch.float64]),
  977. ),
  978. decorators=(
  979. DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
  980. )),
  981. ModuleInfo(torch.nn.Conv3d,
  982. module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False),
  983. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  984. module_memformat_affects_out=True,
  985. skips=(
  986. # channels_last support on cuda requires cudnn >= 8005
  987. DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
  988. # Failure on ROCM for float32 issue #70125
  989. DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
  990. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
  991. # This was wrongly being skipped before and needs investigation.
  992. # See https://github.com/pytorch/pytorch/issues/80247
  993. DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
  994. ),
  995. decorators=(
  996. DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
  997. )),
  998. ModuleInfo(torch.nn.ConvTranspose1d,
  999. module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=False, transposed=True),
  1000. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  1001. module_memformat_affects_out=True,
  1002. dtypes=floating_and_complex_types_and(torch.chalf),
  1003. skips=(
  1004. # channels_last support on cuda requires cudnn >= 7603
  1005. DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
  1006. # Failure on ROCM for float32 issue #70125
  1007. DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
  1008. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
  1009. # Not implmented for chalf on CPU
  1010. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_forward',
  1011. dtypes=(torch.chalf,), device_type='cpu'),
  1012. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format',
  1013. dtypes=(torch.chalf,), device_type='cpu'),
  1014. DecorateInfo(unittest.expectedFailure, 'TestModule',
  1015. 'test_if_train_and_eval_modes_differ', dtypes=(torch.chalf,), device_type='cpu'),
  1016. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_non_contiguous_tensors',
  1017. dtypes=(torch.chalf,), device_type='cpu'),
  1018. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
  1019. dtypes=(torch.chalf,), device_type='cuda'),
  1020. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_multiple_device_transfer',
  1021. dtypes=(torch.chalf,), device_type='cuda'),
  1022. # Ref: https://github.com/pytorch/pytorch/issues/73502
  1023. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_pickle', dtypes=(torch.chalf,)),
  1024. ),
  1025. decorators=(
  1026. DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
  1027. )),
  1028. ModuleInfo(torch.nn.ConvTranspose2d,
  1029. module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False, transposed=True),
  1030. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  1031. module_memformat_affects_out=True,
  1032. dtypes=floating_and_complex_types_and(torch.chalf),
  1033. skips=(
  1034. # channels_last support on cuda requires cudnn >= 7603
  1035. DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
  1036. # Failure on ROCM for float32 issue #70125
  1037. DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
  1038. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
  1039. # This was wrongly being skipped before and needs investigation.
  1040. # See https://github.com/pytorch/pytorch/issues/80247
  1041. DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
  1042. dtypes=[torch.float64, torch.complex128]),
  1043. # These fail only on ROCm
  1044. DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
  1045. dtypes=[torch.complex32], active_if=TEST_WITH_ROCM),
  1046. # Not implmented for chalf on CPU
  1047. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_forward',
  1048. dtypes=(torch.chalf,), device_type='cpu'),
  1049. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format',
  1050. dtypes=(torch.chalf,), device_type='cpu'),
  1051. DecorateInfo(unittest.expectedFailure, 'TestModule',
  1052. 'test_if_train_and_eval_modes_differ', dtypes=(torch.chalf,), device_type='cpu'),
  1053. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_non_contiguous_tensors',
  1054. dtypes=(torch.chalf,), device_type='cpu'),
  1055. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
  1056. dtypes=(torch.chalf,), device_type='cuda'),
  1057. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_multiple_device_transfer',
  1058. dtypes=(torch.chalf,), device_type='cuda'),
  1059. # Ref: https://github.com/pytorch/pytorch/issues/73502
  1060. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_pickle', dtypes=(torch.chalf,)),
  1061. ),
  1062. decorators=(
  1063. DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
  1064. )),
  1065. ModuleInfo(torch.nn.ConvTranspose3d,
  1066. module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False, transposed=True),
  1067. dtypes=floating_and_complex_types_and(torch.chalf),
  1068. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  1069. module_memformat_affects_out=True,
  1070. skips=(
  1071. # channels_last support on cuda requires cudnn >= 8005
  1072. DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
  1073. # Failure on ROCM for float32 issue #70125
  1074. DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
  1075. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
  1076. # This was wrongly being skipped before and needs investigation.
  1077. # See https://github.com/pytorch/pytorch/issues/80247
  1078. DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
  1079. # These fail only on ROCm
  1080. DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
  1081. dtypes=[torch.complex32, torch.complex64], active_if=TEST_WITH_ROCM),
  1082. # Not implmented for chalf on CPU
  1083. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_forward',
  1084. dtypes=(torch.chalf,), device_type='cpu'),
  1085. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format',
  1086. dtypes=(torch.chalf,), device_type='cpu'),
  1087. DecorateInfo(unittest.expectedFailure, 'TestModule',
  1088. 'test_if_train_and_eval_modes_differ', dtypes=(torch.chalf,), device_type='cpu'),
  1089. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_non_contiguous_tensors',
  1090. dtypes=(torch.chalf,), device_type='cpu'),
  1091. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
  1092. dtypes=(torch.chalf,), device_type='cuda'),
  1093. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_multiple_device_transfer',
  1094. dtypes=(torch.chalf,), device_type='cuda'),
  1095. # Ref: https://github.com/pytorch/pytorch/issues/73502
  1096. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_pickle', dtypes=(torch.chalf,)),
  1097. ),
  1098. decorators=(
  1099. DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
  1100. DecorateInfo(precisionOverride({torch.complex64: 1e-04}), 'TestModule', 'test_cpu_gpu_parity'),
  1101. )),
  1102. ModuleInfo(torch.nn.ELU,
  1103. module_inputs_func=module_inputs_torch_nn_ELU,
  1104. skips=(
  1105. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  1106. ),
  1107. ModuleInfo(torch.nn.L1Loss,
  1108. module_inputs_func=module_inputs_torch_nn_L1Loss,
  1109. skips=(
  1110. # No channels_last support for loss functions.
  1111. DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
  1112. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  1113. ),
  # torch.nn.LazyConv1d: lazy 1-D conv sample inputs from
  # module_inputs_torch_nn_ConvNd (N=1, lazy=True).
  # test_memory_format is gated on cuDNN >= 7603, skipped on ROCm for float32,
  # and run with a float32 precisionOverride of 1e-04.
  # NOTE(review): DecorateInfo(skipMeta) is given no class/test filter, so it
  # presumably applies to every test of this module — confirm DecorateInfo defaults.
  # NOTE(review): module_memformat_affects_out=True presumably tells the memory
  # format test that the output layout follows the input — confirm in ModuleInfo.
  1114. ModuleInfo(torch.nn.LazyConv1d,
  1115. module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=True),
  1116. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  1117. module_memformat_affects_out=True,
  1118. skips=(
  1119. # channels_last support on cuda requires cudnn >= 7603
  1120. DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
  1121. # Failure on ROCM for float32 issue #70125
  1122. DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
  1123. # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
  1124. # See https://github.com/pytorch/pytorch/issues/70505 for more info.
  1125. DecorateInfo(skipMeta),
  # float64 TestModule tests are wrapped in skipIfMps.
  1126. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
  1127. ),
  1128. decorators=(
  1129. DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
  1130. )),
  # torch.nn.LazyConv2d: lazy 2-D conv sample inputs from
  # module_inputs_torch_nn_ConvNd (N=2, lazy=True).
  # test_memory_format is gated on cuDNN >= 7603, skipped on ROCm for float32,
  # run with a float32 precisionOverride of 1e-04, and marked as an expected
  # failure only for device_type='cuda' with float64 (issue #80247).
  # NOTE(review): DecorateInfo(skipMeta) carries no class/test filter, so it
  # presumably applies to every test of this module — confirm.
  1131. ModuleInfo(torch.nn.LazyConv2d,
  1132. module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=True),
  1133. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  1134. module_memformat_affects_out=True,
  1135. skips=(
  1136. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
  1137. # channels_last support on cuda requires cudnn >= 7603
  1138. DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
  1139. # Failure on ROCM for float32 issue #70125
  1140. DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
  1141. # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
  1142. # See https://github.com/pytorch/pytorch/issues/70505 for more info.
  1143. DecorateInfo(skipMeta),
  1144. # This was wrongly being skipped before and needs investigation.
  1145. # See https://github.com/pytorch/pytorch/issues/80247
  1146. DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
  1147. device_type='cuda', dtypes=[torch.float64]),
  1148. ),
  1149. decorators=(
  1150. DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
  1151. )),
  # torch.nn.LazyConv3d: lazy 3-D conv sample inputs from
  # module_inputs_torch_nn_ConvNd (N=3, lazy=True).
  # test_memory_format is gated on cuDNN >= 8005, skipped on ROCm for float32,
  # run with a float32 precisionOverride of 1e-04, and marked expectedFailure
  # (issue #80247).
  # NOTE(review): unlike the LazyConv2d entry, this expectedFailure has no
  # device_type/dtypes filter, so it presumably covers every device/dtype
  # combination — confirm this is intended.
  1152. ModuleInfo(torch.nn.LazyConv3d,
  1153. module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=True),
  1154. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  1155. module_memformat_affects_out=True,
  1156. skips=(
  1157. # channels_last support on cuda requires cudnn >= 8005
  1158. DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
  1159. # Failure on ROCM for float32 issue #70125
  1160. DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
  1161. # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
  1162. # See https://github.com/pytorch/pytorch/issues/70505 for more info.
  1163. DecorateInfo(skipMeta),
  # float64 TestModule tests are wrapped in skipIfMps.
  1164. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
  1165. # This was wrongly being skipped before and needs investigation.
  1166. # See https://github.com/pytorch/pytorch/issues/80247
  1167. DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
  1168. ),
  1169. decorators=(
  1170. DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
  1171. )),
  # torch.nn.LazyConvTranspose1d: lazy transposed 1-D conv sample inputs from
  # module_inputs_torch_nn_ConvNd (N=1, lazy=True, transposed=True).
  # test_memory_format is gated on cuDNN >= 7603, skipped on ROCm for float32,
  # and run with a float32 precisionOverride of 1e-04.
  1172. ModuleInfo(torch.nn.LazyConvTranspose1d,
  1173. module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=True, transposed=True),
  1174. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  1175. module_memformat_affects_out=True,
  1176. skips=(
  1177. # channels_last support on cuda requires cudnn >= 7603
  1178. DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
  1179. # Failure on ROCM for float32 issue #70125
  1180. DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
  1181. # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
  1182. # See https://github.com/pytorch/pytorch/issues/70505 for more info.
  1183. DecorateInfo(skipMeta),
  # float64 TestModule tests are wrapped in skipIfMps.
  1184. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
  1185. ),
  1186. decorators=(
  1187. DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
  1188. )),
  # torch.nn.LazyConvTranspose2d: lazy transposed 2-D conv sample inputs from
  # module_inputs_torch_nn_ConvNd (N=2, lazy=True, transposed=True).
  # test_memory_format is gated on cuDNN >= 7603, skipped on ROCm for float32,
  # run with a float32 precisionOverride of 1e-04, and marked as an expected
  # failure only for device_type='cuda' with float64 (issue #80247).
  1189. ModuleInfo(torch.nn.LazyConvTranspose2d,
  1190. module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=True, transposed=True),
  1191. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  1192. module_memformat_affects_out=True,
  1193. skips=(
  1194. # channels_last support on cuda requires cudnn >= 7603
  1195. DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
  1196. # Failure on ROCM for float32 issue #70125
  1197. DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
  1198. # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
  1199. # See https://github.com/pytorch/pytorch/issues/70505 for more info.
  1200. DecorateInfo(skipMeta),
  # float64 TestModule tests are wrapped in skipIfMps.
  1201. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
  1202. # This was wrongly being skipped before and needs investigation.
  1203. # See https://github.com/pytorch/pytorch/issues/80247
  1204. DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
  1205. dtypes=[torch.float64]),
  1206. ),
  1207. decorators=(
  1208. DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
  1209. )),
  # torch.nn.LazyConvTranspose3d: lazy transposed 3-D conv sample inputs from
  # module_inputs_torch_nn_ConvNd (N=3, lazy=True, transposed=True).
  # test_memory_format is gated on cuDNN >= 8005, skipped on ROCm for float32,
  # run with a float32 precisionOverride of 1e-04, and marked expectedFailure
  # (issue #80247).
  # NOTE(review): unlike the LazyConvTranspose2d entry, this expectedFailure has
  # no device_type/dtypes filter, so it presumably covers every device/dtype
  # combination — confirm this is intended.
  1210. ModuleInfo(torch.nn.LazyConvTranspose3d,
  1211. module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=True, transposed=True),
  1212. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  1213. module_memformat_affects_out=True,
  1214. skips=(
  1215. # channels_last support on cuda requires cudnn >= 8005
  1216. DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
  1217. # Failure on ROCM for float32 issue #70125
  1218. DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
  1219. # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
  1220. # See https://github.com/pytorch/pytorch/issues/70505 for more info.
  1221. DecorateInfo(skipMeta),
  # float64 TestModule tests are wrapped in skipIfMps.
  1222. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
  1223. # This was wrongly being skipped before and needs investigation.
  1224. # See https://github.com/pytorch/pytorch/issues/80247
  1225. DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
  1226. ),
  1227. decorators=(
  1228. DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
  1229. )),
  # torch.nn.Linear: float64 TestModule tests are wrapped in skipIfMps and
  # test_memory_format is skipped (no channels_last support, per the inline note).
  1230. ModuleInfo(torch.nn.Linear,
  1231. module_inputs_func=module_inputs_torch_nn_Linear,
  1232. skips=(
  1233. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
  1234. # No channels_last support for Linear currently.
  1235. DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
  1236. ),
  # torch.nn.Bilinear: test_forward on CPU runs with a loosened tolerance
  # (atol=1e-4, rtol=1e-4 for both float32 and float64). test_memory_format is
  # skipped, and float64 TestModule tests are wrapped in skipIfMps.
  1237. ModuleInfo(torch.nn.Bilinear,
  1238. module_inputs_func=module_inputs_torch_nn_Bilinear,
  1239. decorators=[
  1240. DecorateInfo(
  1241. toleranceOverride({
  1242. torch.float32: tol(atol=1e-4, rtol=1e-4),
  1243. torch.float64: tol(atol=1e-4, rtol=1e-4)}),
  1244. 'TestModule', 'test_forward', device_type='cpu')
  1245. ],
  1246. skips=(
  1247. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
  1248. # No channels_last support for Bilinear currently.
  1249. DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
  1250. ),
  # torch.nn.MaxPool2d: test_non_contiguous_tensors and test_cpu_gpu_parity are
  # skipped until those tests handle non-singleton outputs (see the TODOs
  # below); float64 TestModule tests are wrapped in skipIfMps.
  1251. ModuleInfo(torch.nn.MaxPool2d,
  1252. module_inputs_func=module_inputs_torch_nn_MaxPool2d,
  1253. skips=(
  1254. # TODO: test_non_contiguous_tensors doesn't handle case where output is not a singleton (such as
  1255. # return_indices=True for MaxPool2D), submit fix
  1256. DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_non_contiguous_tensors'),
  1257. # TODO: test_cpu_gpu_parity doesn't handle case where output is not a singleton, submit fix
  1258. DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_cpu_gpu_parity'),
  1259. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  1260. ),
  # torch.nn.NLLLoss: test_memory_format is skipped unconditionally (per the
  # inline note, loss functions have no channels_last support); float64
  # TestModule tests are wrapped in skipIfMps.
  1261. ModuleInfo(torch.nn.NLLLoss,
  1262. module_inputs_func=module_inputs_torch_nn_NLLLoss,
  1263. skips=(
  1264. # No channels_last support for loss functions.
  1265. DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
  1266. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  1267. ),
  # torch.nn.GaussianNLLLoss: test_memory_format is skipped unconditionally
  # (loss functions have no channels_last support); float64 TestModule tests
  # are wrapped in skipIfMps.
  1268. ModuleInfo(torch.nn.GaussianNLLLoss,
  1269. module_inputs_func=module_inputs_torch_nn_GaussianNLLLoss,
  1270. skips=(
  1271. # No channels_last support for loss functions.
  1272. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
  # NOTE(review): the "No channels_last support" comment above belongs to this
  # test_memory_format skip, not to the skipIfMps entry it currently precedes.
  1273. DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)),
  # torch.nn.CrossEntropyLoss: the only skip wraps float64 TestModule tests in
  # skipIfMps.
  # NOTE(review): unlike the L1Loss, NLLLoss and GaussianNLLLoss entries,
  # test_memory_format is not skipped for this loss — confirm that is intended.
  1274. ModuleInfo(torch.nn.CrossEntropyLoss,
  1275. module_inputs_func=module_inputs_torch_nn_CrossEntropyLoss,
  1276. skips=(
  1277. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  1278. ),
  # torch.nn.Hardswish: float64 TestModule tests are wrapped in skipIfMps.
  # NOTE(review): supports_gradgrad=False presumably disables second-order
  # gradient checks for this module — confirm in ModuleInfo.
  1279. ModuleInfo(torch.nn.Hardswish,
  1280. module_inputs_func=module_inputs_torch_nn_Hardswish,
  1281. skips=(
  1282. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
  1283. supports_gradgrad=False),
  # torch.nn.TransformerEncoder (train_and_eval_differ=True): test_memory_format
  # is skipped, test_factory_kwargs is an expected failure, and float64
  # TestModule tests are wrapped in skipIfMps.
  1284. # TransformerEncoder takes the same inputs as TransformerEncoderLayer
  1285. ModuleInfo(torch.nn.TransformerEncoder,
  1286. train_and_eval_differ=True,
  1287. module_inputs_func=module_inputs_torch_nn_TransformerEncoder,
  1288. skips=(
  1289. # No channels_last support for TransformerEncoderLayer currently.
  1290. DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
  1291. # Doesn't support device / dtype kwargs directly because it is just a
  1292. # container of TransformerEncoderLayers.
  1293. DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_factory_kwargs'),
  1294. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  1295. ),
  # torch.nn.TransformerEncoderLayer (train_and_eval_differ=True):
  # test_memory_format is skipped; float64 TestModule tests are wrapped in skipIfMps.
  1296. ModuleInfo(torch.nn.TransformerEncoderLayer,
  1297. train_and_eval_differ=True,
  1298. module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer,
  1299. skips=(
  1300. # No channels_last support for TransformerEncoderLayer currently.
  1301. DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
  1302. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  1303. ),
  # torch.nn.TransformerDecoderLayer: test_memory_format is skipped; float64
  # TestModule tests are wrapped in skipIfMps.
  # NOTE(review): unlike the TransformerEncoderLayer entry, train_and_eval_differ
  # is not set here — confirm whether that omission is intended.
  1304. ModuleInfo(torch.nn.TransformerDecoderLayer,
  1305. module_inputs_func=module_inputs_torch_nn_TransformerDecoderLayer,
  1306. skips=(
  1307. # No channels_last support for TransformerDecoderLayer currently.
  1308. DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
  1309. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  1310. ),
  # torch.nn.Transformer: test_memory_format is skipped; float64 TestModule
  # tests are wrapped in skipIfMps.
  # NOTE(review): unlike the TransformerEncoder/TransformerEncoderLayer entries,
  # train_and_eval_differ is not set here — confirm whether that is intended.
  1311. ModuleInfo(torch.nn.Transformer,
  1312. module_inputs_func=module_inputs_torch_nn_Transformer,
  1313. skips=(
  1314. # No channels_last support for Transformer currently.
  1315. DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
  1316. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  1317. ),
  # torch.nn.MultiheadAttention (train_and_eval_differ=True): test_memory_format
  # is skipped; float64 TestModule tests are wrapped in skipIfMps.
  1318. ModuleInfo(torch.nn.MultiheadAttention,
  1319. train_and_eval_differ=True,
  1320. module_inputs_func=module_inputs_torch_nn_MultiheadAttention,
  1321. skips=(
  1322. # No channels_last support for MultiheadAttention currently.
  1323. DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
  1324. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  1325. ),
  # torch.nn.Embedding: test_memory_format is skipped; float64 TestModule tests
  # are wrapped in skipIfMps.
  # NOTE(review): the test_memory_format skip has no explanatory comment —
  # presumably there is no channels_last support for Embedding; confirm.
  1326. ModuleInfo(torch.nn.Embedding,
  1327. module_inputs_func=module_inputs_torch_nn_Embedding,
  1328. skips=(
  1329. DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
  1330. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  1331. ),
  # torch.nn.ReLU: the only skip wraps float64 TestModule tests in skipIfMps.
  1332. ModuleInfo(torch.nn.ReLU,
  1333. module_inputs_func=module_inputs_torch_nn_ReLU,
  1334. skips=(
  1335. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  1336. ),
  # torch.nn.RNNCell: shares its sample-input generator with GRUCell, selecting
  # the RNN variant via is_rnn=True; float64 TestModule tests are wrapped in skipIfMps.
  1337. ModuleInfo(torch.nn.RNNCell,
  1338. module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU_Cell, is_rnn=True),
  1339. skips=(
  1340. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  1341. ),
  # torch.nn.GRUCell: float64 TestModule tests are wrapped in skipIfMps.
  # NOTE(review): module_inputs_torch_nn_RNN_GRU_Cell is passed without is_rnn,
  # so this relies on that function's default (presumably False, i.e. the GRU
  # variant) — confirm, or pass is_rnn=False explicitly.
  1342. ModuleInfo(torch.nn.GRUCell,
  1343. module_inputs_func=module_inputs_torch_nn_RNN_GRU_Cell,
  1344. skips=(
  1345. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  1346. ),
  # torch.nn.LSTMCell: the only skip wraps float64 TestModule tests in skipIfMps.
  1347. ModuleInfo(torch.nn.LSTMCell,
  1348. module_inputs_func=module_inputs_torch_nn_LSTMCell,
  1349. skips=(
  1350. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  1351. ),
  # torch.nn.Sigmoid: the only skip wraps float64 TestModule tests in skipIfMps.
  1352. ModuleInfo(torch.nn.Sigmoid,
  1353. module_inputs_func=module_inputs_torch_nn_Sigmoid,
  1354. skips=(
  1355. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
  1356. ),
  # torch.nn.RNN (train_and_eval_differ=True): shares its sample-input generator
  # with GRU, selecting the RNN variant via is_rnn=True, and reuses the shared
  # rnn_gru_lstm_module_info_decorators; float64 TestModule tests are wrapped
  # in skipIfMps.
  1357. ModuleInfo(torch.nn.RNN,
  1358. train_and_eval_differ=True,
  1359. module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=True),
  1360. skips=(
  1361. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
  1362. decorators=rnn_gru_lstm_module_info_decorators
  1363. ),
  # torch.nn.GRU (train_and_eval_differ=True): shares its sample-input generator
  # with RNN, selecting the GRU variant via is_rnn=False, and reuses the shared
  # rnn_gru_lstm_module_info_decorators; float64 TestModule tests are wrapped
  # in skipIfMps.
  1364. ModuleInfo(torch.nn.GRU,
  1365. train_and_eval_differ=True,
  1366. module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=False),
  1367. skips=(
  1368. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
  1369. decorators=rnn_gru_lstm_module_info_decorators),
  # torch.nn.LSTM (train_and_eval_differ=True): sample inputs from
  # module_inputs_torch_nn_LSTM, with the shared rnn_gru_lstm_module_info_decorators;
  # float64 TestModule tests are wrapped in skipIfMps. Last element of the list,
  # hence no trailing comma.
  1370. ModuleInfo(torch.nn.LSTM,
  1371. train_and_eval_differ=True,
  1372. module_inputs_func=module_inputs_torch_nn_LSTM,
  1373. skips=(
  1374. DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
  1375. decorators=rnn_gru_lstm_module_info_decorators)
  1376. ]