# linalg.py

import itertools
import unittest
from functools import partial
from itertools import product
from typing import Iterable, List, Tuple

import numpy as np
from numpy import inf

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
    _get_magma_version,
    _get_torch_cuda_version,
    with_tf32_off,
)
from torch.testing._internal.common_device_type import (
    has_cusolver,
    skipCPUIfNoLapack,
    skipCUDAIf,
    skipCUDAIfNoCusolver,
    skipCUDAIfNoMagma,
    skipCUDAIfNoMagmaAndNoCusolver,
    skipCUDAIfRocm,
    tol,
    toleranceOverride,
)
from torch.testing._internal.common_dtype import (
    all_types_and_complex,
    all_types_and_complex_and,
    floating_and_complex_types,
    floating_and_complex_types_and,
)
from torch.testing._internal.common_utils import (
    GRADCHECK_NONDET_TOL,
    IS_MACOS,
    make_fullrank_matrices_with_distinct_singular_values,
    skipIfSlowGradcheckEnv,
    slowTest,
)
from torch.testing._internal.opinfo.core import (
    clone_sample,
    DecorateInfo,
    ErrorInput,
    gradcheck_wrapper_hermitian_input,
    OpInfo,
    ReductionOpInfo,
    S,
    SampleInput,
)
from torch.testing._internal.opinfo.refs import PythonRefInfo, ReductionPythonRefInfo

def sample_kwargs_vector_norm(t, **kwargs):
    # orders with / without identity
    def ords():
        has_id = (6, 4, 2, 1, 0, 0.9)
        no_id = (inf, -2.1, -inf)
        if t.numel() == 0:
            dim = kwargs.get("dim")
            if dim is None:
                return has_id
            if not isinstance(dim, Iterable):
                dim = (dim,)
            for d in dim:
                if t.size(d) == 0:
                    return has_id
        return has_id + no_id

    return (((), dict(ord=o)) for o in ords())
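
# Illustrative sketch, not part of the original test definitions: the generator above
# only emits orders without an identity (inf and negative orders) when the reduced slice
# is non-empty, because torch.linalg.vector_norm rejects them on empty tensors. The
# helper name `_demo_vector_norm_empty_tensor` is hypothetical and the function is never
# called by the test suite.
def _demo_vector_norm_empty_tensor():
    empty = torch.empty(0)
    # ord=2 reduces with a sum, which has an identity, so the empty case is defined
    assert torch.linalg.vector_norm(empty, ord=2) == 0
    try:
        # ord=inf reduces with a max, which has no identity on an empty tensor
        torch.linalg.vector_norm(empty, ord=inf)
    except RuntimeError:
        pass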

def sample_inputs_svd(op_info, device, dtype, requires_grad=False, **kwargs):
    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
    make_arg = partial(
        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
    )

    is_linalg_svd = "linalg.svd" in op_info.name
    batches = [(), (0,), (3,)]
    ns = [0, 3, 5]

    def uniformize(usv):
        S = usv[1]
        k = S.shape[-1]
        U = usv[0][..., :k]
        Vh = usv[2] if is_linalg_svd else usv[2].mH
        Vh = Vh[..., :k, :]
        return U, S, Vh

    def fn_U(usv):
        U, _, _ = uniformize(usv)
        return U.abs()

    def fn_S(usv):
        return uniformize(usv)[1]

    def fn_Vh(usv):
        # We also return S to test
        _, S, Vh = uniformize(usv)
        return S, Vh.abs()

    def fn_UVh(usv):
        U, S, Vh = uniformize(usv)
        return U @ Vh, S

    fns = (fn_U, fn_S, fn_Vh, fn_UVh)

    fullmat = "full_matrices" if is_linalg_svd else "some"

    for batch, n, k, fullmat_val, fn in product(batches, ns, ns, (True, False), fns):
        shape = batch + (n, k)
        yield SampleInput(
            make_arg(*shape), kwargs={fullmat: fullmat_val}, output_process_fn_grad=fn
        )
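
# Illustrative sketch, not part of the original test definitions: `uniformize` above
# reduces both torch.svd and torch.linalg.svd outputs to the reduced factorization
# (U, S, Vh). The hypothetical helper below (never called by the suite) shows that the
# reduced factors reconstruct the input, which is what makes them a canonical form.
def _demo_reduced_svd_roundtrip():
    A = torch.randn(5, 3, dtype=torch.float64)
    U, S_vals, Vh = torch.linalg.svd(A, full_matrices=False)
    # U: (5, 3), S_vals: (3,), Vh: (3, 3); the reduced factors reproduce A
    assert torch.allclose(U @ torch.diag(S_vals) @ Vh, A, atol=1e-8)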

def sample_inputs_cross(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    yield SampleInput(make_arg((S, 3)), args=(make_arg((S, 3)),))
    yield SampleInput(
        make_arg((S, 3, S)), args=(make_arg((S, 3, S)),), kwargs=dict(dim=1)
    )
    yield SampleInput(make_arg((1, 3)), args=(make_arg((S, 3)),), kwargs=dict(dim=-1))

def error_inputs_cross(op_info, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)

    sample = SampleInput(input=make_arg((S, 3)), args=(make_arg((S, 1)),))
    err = "inputs dimension -1 must have length 3"
    yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)

    sample = SampleInput(input=make_arg((5, S, 3)), args=(make_arg((S, 3)),))
    err = "inputs must have the same number of dimensions"
    yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)

    sample = SampleInput(input=make_arg((S, 2)), args=(make_arg((S, 2)),))
    err = "must have length 3"
    yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)

    sample = SampleInput(
        input=make_arg((S, 2)), args=(make_arg((S, 2)),), kwargs=dict(dim=2)
    )
    err = "Dimension out of range"
    yield ErrorInput(sample, error_regex=err, error_type=IndexError)

def sample_inputs_householder_product(op_info, device, dtype, requires_grad, **kwargs):
    """
    This function generates input for torch.linalg.householder_product (torch.orgqr).
    The first argument is a matrix or batch of matrices, and the second argument is a
    vector of Householder scalars (tau) or a batch of such vectors.
    Empty, square, rectangular, batched square and batched rectangular input is generated.
    """
    make_arg = partial(
        make_tensor,
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        low=-2,
        high=2,
    )
    # Each column of the matrix is getting multiplied many times leading to very large values for
    # the Jacobian matrix entries and making the finite-difference result of grad check less accurate.
    # That's why gradcheck with the default range [-9, 9] fails and [-2, 2] is used here.
    yield SampleInput(make_arg((S, S)), make_arg((S,)))
    yield SampleInput(make_arg((S + 1, S)), make_arg((S,)))
    yield SampleInput(make_arg((2, 1, S, S)), make_arg((2, 1, S)))
    yield SampleInput(make_arg((2, 1, S + 1, S)), make_arg((2, 1, S)))
    yield SampleInput(
        make_arg((0, 0), low=None, high=None),
        make_arg((0,), low=None, high=None),
    )
    yield SampleInput(make_arg((S, S)), make_arg((0,), low=None, high=None))

    # m = n = S, k = S - 2
    yield SampleInput(make_arg((S, S)), make_arg((S - 2,), low=None, high=None))
    # m = S, n = S - 1, k = S - 2
    yield SampleInput(make_arg((S, S - 1)), make_arg((S - 2,), low=None, high=None))
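
# Illustrative sketch, not part of the original test definitions: householder_product
# (a.k.a. orgqr) assembles the Q factor from the reflectors and tau returned by geqrf,
# which is the pairing these samples model. Hypothetical helper, never called.
def _demo_householder_product_from_geqrf():
    A = torch.randn(5, 4, dtype=torch.float64)
    reflectors, tau = torch.geqrf(A)
    Q = torch.linalg.householder_product(reflectors, tau)
    R = reflectors[:4].triu()  # R is stored in the upper triangle of `reflectors`
    assert torch.allclose(Q @ R, A, atol=1e-8)
    assert torch.allclose(Q.mT @ Q, torch.eye(4, dtype=torch.float64), atol=1e-8)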

def sample_inputs_linalg_det_singular(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype)

    def make_singular_matrix_batch_base(size, rank):
        assert size[-1] == size[-2]
        assert rank > 0 and rank < size[-1]

        n = size[-1]
        a = make_arg(size[:-2] + (n, rank)) / 10
        b = make_arg(size[:-2] + (rank, n)) / 10
        x = a @ b
        lu, pivs, _ = torch.linalg.lu_factor_ex(x)
        p, l, u = torch.lu_unpack(lu, pivs)
        u_diag_abs = u.diagonal(0, -2, -1).abs()
        u_diag_abs_largest = u_diag_abs.max(dim=-1, keepdim=True).values
        u_diag_abs_smallest_idxs = torch.topk(
            u_diag_abs, k=(n - rank), largest=False
        ).indices
        u.diagonal(0, -2, -1).div_(u_diag_abs_largest)
        u.diagonal(0, -2, -1)[..., u_diag_abs_smallest_idxs] = torch.finfo(dtype).eps
        matrix = p @ l @ u
        matrix.requires_grad_(requires_grad)
        return matrix

    for batch, size in product(((), (2,), (2, 2)), range(6)):
        shape = batch + (size, size)
        for rank in range(1, size):
            yield SampleInput(make_singular_matrix_batch_base(shape, rank))
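
# Illustrative sketch, not part of the original test definitions: the samples above are
# rank deficient by construction (an (n, rank) @ (rank, n) product with the smallest
# diagonal entries of U forced to eps), so their determinant is numerically zero.
# Hypothetical helper, never called; assumes a LAPACK-enabled build.
def _demo_det_singular_sample(device="cpu", dtype=torch.float64):
    sample = next(
        sample_inputs_linalg_det_singular(None, device, dtype, requires_grad=False)
    )
    # expected to be zero up to floating point noise
    return torch.linalg.det(sample.input.detach()).abs().max()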

def sample_inputs_linalg_matrix_power(op_info, device, dtype, requires_grad, **kwargs):
    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    make_arg_fullrank = partial(
        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
    )
    # (<matrix_size>, (<batch_sizes, ...>))
    test_sizes = [
        (1, ()),
        (2, (0,)),
        (2, (2,)),
    ]

    for matrix_size, batch_sizes in test_sizes:
        size = batch_sizes + (matrix_size, matrix_size)
        for n in (0, 3, 5):
            yield SampleInput(make_arg(size), args=(n,))
        for n in [-4, -2, -1]:
            yield SampleInput(make_arg_fullrank(*size), args=(n,))

def sample_inputs_linalg_det_logdet_slogdet(
    op_info, device, dtype, requires_grad, **kwargs
):
    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
    make_arg = partial(
        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
    )
    batches = [(), (0,), (3,)]
    ns = [0, 1, 5]
    is_logdet = op_info.name == "logdet"

    for (
        batch,
        n,
    ) in product(batches, ns):
        shape = batch + (n, n)
        A = make_arg(*shape)
        # Need to make the matrices in A have positive determinant for autograd.
        # To do so, we multiply A by the sign of its determinant; the matrix sizes used
        # here are odd (1 and 5), so this flips a negative determinant to a positive one.
        if is_logdet and not A.is_complex() and A.numel() > 0:
            s = torch.linalg.slogdet(A).sign
            A = A * s.unsqueeze(-1).unsqueeze(-1)
            A.requires_grad_(requires_grad)
        yield SampleInput(A)
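
# Illustrative sketch, not part of the original test definitions: scaling every entry of
# A by the sign of det(A) relies on det(c * A) == c**n * det(A), so with the odd sizes
# used above a negative determinant becomes positive. Hypothetical helper, never called.
def _demo_flip_determinant_sign():
    A = torch.diag(torch.tensor([-1.0, 1.0, 1.0], dtype=torch.float64))  # det == -1, n == 3
    s = torch.linalg.slogdet(A).sign
    assert torch.linalg.det(A * s) > 0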

def sample_inputs_lu_solve(op_info, device, dtype, requires_grad=False, **kwargs):
    """Samples the inputs for both linalg.lu_solve and lu_solve"""
    make_fn = make_fullrank_matrices_with_distinct_singular_values
    make_a = partial(make_fn, dtype=dtype, device=device)
    make_b = partial(make_tensor, dtype=dtype, device=device)

    def clone(X, requires_grad):
        Y = X.clone()
        Y.requires_grad_(requires_grad)
        return Y

    is_linalg_lu_solve = op_info.name == "linalg.lu_solve"

    batches = ((), (0,), (2,))
    ns = (3, 1, 0)
    nrhs = (4, 1, 0)

    for n, batch, rhs in product(ns, batches, nrhs):
        A = make_a(*(batch + (n, n)))
        LU, pivots = torch.linalg.lu_factor(A)

        B = make_b(batch + (n, rhs))

        grads = (False,) if not requires_grad else (True, False)
        # we try all possible combinations of requires_grad for each input
        for LU_grad, B_grad in product(grads, grads):
            # when requires_grad == True, at least one input has to have requires_grad enabled
            if requires_grad and not LU_grad and not B_grad:
                continue

            if is_linalg_lu_solve:
                for adjoint, left in product((True, False), repeat=2):
                    yield SampleInput(
                        clone(LU, LU_grad),
                        args=(pivots, clone(B if left else B.mT, B_grad)),
                        kwargs=dict(adjoint=adjoint, left=left),
                    )
            else:
                yield SampleInput(clone(B, B_grad), args=(clone(LU, LU_grad), pivots))

def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad, **kwargs):
    # Each test case consists of the sizes in the chain of multiplications
    # e.g. [2, 3, 4, 5] generates matrices (2, 3) @ (3, 4) @ (4, 5)
    test_cases = [
        [1, 2, 1],
        [2, 0, 2],
        [0, 2, 2],
        [2, 2, 2, 2],
        [2, 3, 4, 5],
        [5, 4, 0, 2],
        [2, 4, 3, 5, 3, 2],
    ]

    for sizes in test_cases:
        tensors = []
        for size in zip(sizes[:-1], sizes[1:]):
            t = make_tensor(
                size, dtype=dtype, device=device, requires_grad=requires_grad
            )
            tensors.append(t)
        yield SampleInput(tensors)
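
# Illustrative sketch, not part of the original test definitions: a size chain such as
# [2, 3, 4, 5] in `test_cases` above describes matrices (2, 3), (3, 4) and (4, 5), and
# multi_dot of that chain has shape (2, 5). Hypothetical helper, never called.
def _demo_multi_dot_chain():
    sizes = [2, 3, 4, 5]
    mats = [
        torch.randn(m, n, dtype=torch.float64)
        for m, n in zip(sizes[:-1], sizes[1:])
    ]
    assert torch.linalg.multi_dot(mats).shape == (sizes[0], sizes[-1])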

def sample_inputs_linalg_matrix_norm(op_info, device, dtype, requires_grad, **kwargs):
    low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32)
    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )

    sizes = ((2, 2), (2, 3, 2))
    if dtype in low_precision_dtypes:
        # svdvals not supported for low precision dtypes
        ords = ("fro", inf, -inf, 1, -1)
    else:
        ords = ("fro", "nuc", inf, -inf, 1, -1, 2, -2)
    dims = ((-2, -1), (-1, 0))

    for size, ord, dim, keepdim in product(sizes, ords, dims, [True, False]):
        yield SampleInput(make_arg(size), args=(ord, dim, keepdim))

def sample_inputs_linalg_norm(
    op_info, device, dtype, requires_grad, *, variant=None, **kwargs
):
    if variant is not None and variant not in ("subgradient_at_zero",):
        raise ValueError(
            f"Unsupported variant, expected variant to be 'subgradient_at_zero' but got: {variant}"
        )

    test_sizes = [
        (S,),
        (0,),
        (S, S),
        (0, 0),
        (S, 0),
        (0, S),
        (S, S, S),
        (0, S, S),
        (S, 0, S),
        (0, 0, 0),
    ]

    vector_ords = (None, 0, 0.5, 1, 2, 3.5, inf, -0.5, -1, -2, -3.5, -inf)
    if dtype in {torch.float16, torch.bfloat16, torch.complex32}:
        # svdvals not supported for low precision dtypes
        matrix_ords = ("fro", inf, -inf, 1, -1)
    else:
        matrix_ords = (None, "fro", "nuc", inf, -inf, 1, -1, 2, -2)

    make_arg = partial(
        make_tensor,
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
        low=None,
        high=None,
    )

    for test_size in test_sizes:
        is_vector_norm = len(test_size) == 1
        is_matrix_norm = len(test_size) == 2

        # IndexError: amax(): Expected reduction dim 0 to have non-zero size.
        is_valid_for_p2 = is_vector_norm or (test_size[-1] != 0 and test_size[-2] != 0)

        for keepdim in [False, True]:
            if variant != "subgradient_at_zero" and is_valid_for_p2:
                yield SampleInput(make_arg(test_size), keepdim=keepdim)

            if not (is_vector_norm or is_matrix_norm):
                continue

            ords = vector_ords if is_vector_norm else matrix_ords

            for ord in ords:
                if is_vector_norm and test_size[-1] == 0:
                    if ord == np.inf or (ord is not None and ord < 0):
                        # RuntimeError: linalg.vector_norm cannot compute the
                        # {ord} norm on an empty tensor because the operation
                        # does not have an identity
                        continue
                elif is_matrix_norm:
                    dims_to_check = {
                        None: (0,),
                        np.inf: (0,),
                        2: (0, 1),
                        1: (1,),
                        -1: (1,),
                        -2: (0, 1),
                        -np.inf: (0,),
                    }.get(ord, ())

                    if any(test_size[d] == 0 for d in dims_to_check):
                        # IndexError: amax(): Expected reduction dim {dim} to
                        # have non-zero size.
                        continue

                if variant == "subgradient_at_zero":
                    yield SampleInput(
                        torch.zeros(
                            test_size,
                            dtype=dtype,
                            device=device,
                            requires_grad=requires_grad,
                        ),
                        ord,
                        keepdim=keepdim,
                    )
                else:
                    yield SampleInput(make_arg(test_size), ord, keepdim=keepdim)

                    if ord in ["nuc", "fro"]:
                        yield SampleInput(
                            make_arg(test_size), ord=ord, keepdim=keepdim, dim=(0, 1)
                        )

def sample_inputs_linalg_vecdot(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    batches = ((), (0,), (1,), (5,))
    ns = (0, 1, 3, 5)
    for b, n in product(batches, ns):
        shape = b + (n,)
        yield SampleInput(make_arg(shape), args=(make_arg(shape),))
        for i in range(len(shape)):
            yield SampleInput(
                make_arg(shape), args=(make_arg(shape),), kwargs=dict(dim=i)
            )

def sample_inputs_linalg_invertible(
    op_info, device, dtype, requires_grad=False, **kwargs
):
    """
    This function generates invertible inputs for linear algebra ops
    The input is generated as the itertools.product of 'batches' and 'ns'.
    In total this function generates 8 SampleInputs
    'batches' cases include:
        () - single input,
        (0,) - zero batched dimension,
        (2,) - batch of two matrices,
        (1, 1) - 1x1 batch of matrices
    'ns' gives 0x0 and 5x5 matrices.
    Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes.
    """
    make_fn = make_fullrank_matrices_with_distinct_singular_values
    make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad)

    batches = [(), (0,), (2,), (1, 1)]
    ns = [5, 0]

    for batch, n in product(batches, ns):
        yield SampleInput(make_arg(*batch, n, n))
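
# Illustrative sketch, not part of the original test definitions: with 4 batch shapes
# and 2 matrix sizes the generator above yields exactly the 8 SampleInputs its docstring
# promises. Hypothetical helper, never called (op_info is unused by the generator).
def _demo_count_invertible_samples(device="cpu", dtype=torch.float64):
    samples = list(sample_inputs_linalg_invertible(None, device, dtype))
    assert len(samples) == 8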

def sample_inputs_matrix_rank(op_info, device, dtype, requires_grad=False, **kwargs):
    """
    This function produces inputs for matrix rank that test
    all possible combinations for atol and rtol
    """

    def make_tol_arg(kwarg_type, inp):
        if kwarg_type == "none":
            return None
        if kwarg_type == "float":
            return 1.0
        assert kwarg_type == "tensor"
        return torch.ones(inp.shape[:-2], device=device)

    for tol_type in ["float", "tensor"]:
        for atol_type, rtol_type in product(["none", tol_type], repeat=2):
            if (
                not atol_type and not rtol_type
            ):  # default behavior, so skipped here so it's not tested 2 extra times
                continue
            for sample in sample_inputs_linalg_invertible(
                op_info, device, dtype, requires_grad
            ):
                assert sample.kwargs == {}
                sample.kwargs = {
                    "atol": make_tol_arg(atol_type, sample.input),
                    "rtol": make_tol_arg(rtol_type, sample.input),
                }
                yield sample

    for sample in sample_inputs_linalg_invertible(
        op_info, device, dtype, requires_grad
    ):
        yield sample  # default kwargs

def sample_inputs_linalg_pinv_singular(
    op_info, device, dtype, requires_grad=False, **kwargs
):
    """
    This function produces factors `a` and `b` to generate inputs of the form `a @ b.t()` to
    test the backward method of `linalg_pinv`. That way we always preserve the rank of the
    input no matter the perturbations applied to it by the gradcheck.
    Note that `pinv` is Frechet-differentiable in a rank-preserving neighborhood.
    """
    batches = [(), (0,), (2,), (1, 1)]
    # the size of at least 30 is required to cause failures for the previous implicit implementation
    # of the pinv's backward method, albeit it is slow.
    size = [0, 3, 50]

    for batch, m, n in product(batches, size, size):
        for k in range(min(3, min(m, n))):
            # Note that by making the columns of `a` and `b` orthonormal we make sure that
            # the product matrix `a @ b.t()` has condition number 1 when restricted to its image
            a = (
                torch.rand(*batch, m, k, device=device, dtype=dtype)
                .qr()
                .Q.requires_grad_(requires_grad)
            )
            b = (
                torch.rand(*batch, n, k, device=device, dtype=dtype)
                .qr()
                .Q.requires_grad_(requires_grad)
            )
            yield SampleInput(a, args=(b,))
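
# Illustrative sketch, not part of the original test definitions: a product a @ b.mH of
# two factors with orthonormal columns has rank exactly k, which is why gradcheck's
# small perturbations cannot change the rank of these inputs. Hypothetical helper,
# never called; it uses torch.linalg.qr rather than the deprecated Tensor.qr.
def _demo_rank_preserving_product():
    m, n, k = 6, 5, 2
    a = torch.linalg.qr(torch.randn(m, k, dtype=torch.float64)).Q
    b = torch.linalg.qr(torch.randn(n, k, dtype=torch.float64)).Q
    assert torch.linalg.matrix_rank(a @ b.mH) == k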

def sample_inputs_linalg_cond(op_info, device, dtype, requires_grad=False, **kwargs):
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )

    # autograd is not supported for inputs with zero number of elements
    shapes = (
        (S, S),
        (2, S, S),
        (2, 1, S, S),
    )

    for shape in shapes:
        yield SampleInput(make_arg(shape))

def sample_inputs_linalg_vander(op_info, device, dtype, requires_grad=False, **kwargs):
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )

    shapes = (
        (),
        (1,),
        (S,),
        (2, S),
    )

    for shape in shapes:
        if len(shape) > 0 and shape[-1] > 1:
            yield SampleInput(make_arg(shape))
        n = shape[-1] if len(shape) > 0 else 1
        for i in range(3):
            # n-1, n, n+1
            N = n + i - 1
            if N < 2:
                continue
            yield SampleInput(make_arg(shape), kwargs=dict(N=N))

def np_vander_batched(x, N=None):
    # Wrapper around np.vander that supports batches of 1 dimension (enough for the tests)
    if x.ndim == 0:
        x = x[np.newaxis]
    if x.ndim == 1:
        y = np.vander(x, N=N, increasing=True)
        return y
    else:
        if N is None:
            N = x.shape[-1]
        y = np.vander(x.ravel(), N=N, increasing=True).reshape((*x.shape, N))
        return y
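
# Illustrative sketch, not part of the original test definitions: np_vander_batched is
# meant to mirror torch.linalg.vander (increasing powers) so it can serve as a NumPy
# reference. Hypothetical helper, never called.
def _demo_np_vander_batched_matches_torch():
    x = torch.arange(1.0, 4.0, dtype=torch.float64)  # [1., 2., 3.]
    expected = torch.linalg.vander(x, N=4)
    actual = torch.from_numpy(np_vander_batched(x.numpy(), N=4))
    assert torch.allclose(actual, expected)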

def sample_inputs_linalg_cholesky_inverse(
    op_info, device, dtype, requires_grad=False, **kwargs
):
    from torch.testing._internal.common_utils import random_well_conditioned_matrix

    # Cholesky factorization is for positive-definite matrices
    single_well_conditioned_matrix = random_well_conditioned_matrix(
        S, S, dtype=dtype, device=device
    )
    batch_well_conditioned_matrices = random_well_conditioned_matrix(
        2, S, S, dtype=dtype, device=device
    )
    single_pd = single_well_conditioned_matrix @ single_well_conditioned_matrix.mH
    batch_pd = batch_well_conditioned_matrices @ batch_well_conditioned_matrices.mH

    inputs = (
        torch.zeros(0, 0, dtype=dtype, device=device),  # 0x0 matrix
        torch.zeros(0, 2, 2, dtype=dtype, device=device),  # zero batch of matrices
        single_pd,
        batch_pd,
    )
    test_cases = (torch.linalg.cholesky(a, upper=False) for a in inputs)
    for l in test_cases:
        # generated lower-triangular samples
        l.requires_grad = requires_grad
        yield SampleInput(l)  # upper=False by default
        yield SampleInput(
            l.detach().clone().requires_grad_(requires_grad), kwargs=dict(upper=False)
        )

        # generate upper-triangular inputs
        u = l.detach().clone().mT.contiguous().requires_grad_(requires_grad)
        yield SampleInput(u, kwargs=dict(upper=True))

def sample_inputs_linalg_ldl_factor(
    op_info, device, dtype, requires_grad=False, **kwargs
):
    from torch.testing._internal.common_utils import (
        random_hermitian_pd_matrix,
        random_symmetric_pd_matrix,
    )

    device = torch.device(device)

    # Symmetric inputs
    yield SampleInput(
        random_symmetric_pd_matrix(S, dtype=dtype, device=device),
        kwargs=dict(hermitian=False),
    )  # single matrix
    yield SampleInput(
        random_symmetric_pd_matrix(S, 2, dtype=dtype, device=device),
        kwargs=dict(hermitian=False),
    )  # batch of matrices
    yield SampleInput(
        torch.zeros(0, 0, dtype=dtype, device=device), kwargs=dict(hermitian=False)
    )  # 0x0 matrix
    yield SampleInput(
        torch.zeros(0, 2, 2, dtype=dtype, device=device), kwargs=dict(hermitian=False)
    )  # zero batch of matrices

    # Hermitian inputs
    # hermitian=True for complex inputs on CUDA is supported only with MAGMA 2.5.4+
    magma_254_available = device.type == "cuda" and _get_magma_version() >= (2, 5, 4)
    if dtype.is_complex and (device.type == "cpu" or magma_254_available):
        yield SampleInput(
            random_hermitian_pd_matrix(S, dtype=dtype, device=device),
            kwargs=dict(hermitian=True),
        )  # single matrix
        yield SampleInput(
            random_hermitian_pd_matrix(S, 2, dtype=dtype, device=device),
            kwargs=dict(hermitian=True),
        )  # batch of matrices

def sample_inputs_linalg_ldl_solve(
    op_info, device, dtype, requires_grad=False, **kwargs
):
    # Generate LDL factors of symmetric (and Hermitian on CPU) matrices
    from torch.testing._internal.common_utils import (
        random_hermitian_pd_matrix,
        random_symmetric_pd_matrix,
    )

    device = torch.device(device)
    symmetric_inputs = (
        random_symmetric_pd_matrix(S, dtype=dtype, device=device),  # single matrix
        random_symmetric_pd_matrix(
            S, 2, dtype=dtype, device=device
        ),  # batch of matrices
        torch.zeros(0, 0, dtype=dtype, device=device),  # 0x0 matrix
        torch.zeros(0, 2, 2, dtype=dtype, device=device),  # zero batch of matrices
    )
    hermitian_inputs = (
        (
            random_hermitian_pd_matrix(S, dtype=dtype, device=device),
            random_hermitian_pd_matrix(S, 2, dtype=dtype, device=device),
        )
        if device.type == "cpu" and dtype.is_complex
        else ()
    )
    test_cases1 = (
        torch.linalg.ldl_factor_ex(a, hermitian=False) for a in symmetric_inputs
    )
    test_cases2 = (
        torch.linalg.ldl_factor_ex(a, hermitian=True) for a in hermitian_inputs
    )

    # Symmetric case
    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    for test_case in test_cases1:
        factors, pivots, _ = test_case
        factors.requires_grad = requires_grad
        for B_batch_shape in ((), factors.shape[:-2]):
            B = make_arg((*B_batch_shape, factors.shape[-1], S))
            yield SampleInput(factors, args=(pivots, B), kwargs=dict(hermitian=False))
            clone_factors = factors.detach().clone().requires_grad_(requires_grad)
            yield SampleInput(
                clone_factors, args=(pivots, B), kwargs=dict(hermitian=False)
            )

    # Hermitian case
    for test_case in test_cases2:
        factors, pivots, _ = test_case
        factors.requires_grad = requires_grad
        for B_batch_shape in ((), factors.shape[:-2]):
            B = make_arg((*B_batch_shape, factors.shape[-1], S))
            yield SampleInput(factors, args=(pivots, B), kwargs=dict(hermitian=True))
            clone_factors = factors.detach().clone().requires_grad_(requires_grad)
            yield SampleInput(
                clone_factors, args=(pivots, B), kwargs=dict(hermitian=True)
            )
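
# Illustrative sketch, not part of the original test definitions: the samples above pair
# ldl_factor_ex output with a right-hand side B, and ldl_solve then solves A @ X = B.
# Hypothetical helper, never called; assumes a LAPACK-enabled build.
def _demo_ldl_factor_then_solve():
    A = torch.randn(4, 4, dtype=torch.float64)
    A = A @ A.mT + 4 * torch.eye(4, dtype=torch.float64)  # symmetric positive definite
    B = torch.randn(4, 2, dtype=torch.float64)
    LD, pivots, info = torch.linalg.ldl_factor_ex(A)
    X = torch.linalg.ldl_solve(LD, pivots, B)
    assert info.item() == 0
    assert torch.allclose(A @ X, B, atol=1e-6)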

def sample_inputs_linalg_lstsq(op_info, device, dtype, requires_grad=False, **kwargs):
    from torch.testing._internal.common_utils import random_well_conditioned_matrix

    device = torch.device(device)

    drivers: Tuple[str, ...]
    if device.type == "cuda":
        drivers = ("gels",)
    else:
        drivers = ("gels", "gelsy", "gelss", "gelsd")

    # we generate matrices of shape (..., n + delta, n)
    deltas: Tuple[int, ...]
    if device.type == "cpu" or has_cusolver():
        deltas = (-1, 0, +1)
    # only square systems if cuSOLVER is not available
    # because we solve a lstsq problem with a transposed matrix in the backward
    else:
        deltas = (0,)

    for batch, driver, delta in product(((), (3,), (3, 3)), drivers, deltas):
        shape = batch + (3 + delta, 3)
        a = random_well_conditioned_matrix(*shape, dtype=dtype, device=device)
        a.requires_grad_(requires_grad)
        b = make_tensor(
            shape,
            dtype=dtype,
            device=device,
            low=None,
            high=None,
            requires_grad=requires_grad,
        )
        yield SampleInput(a, b, driver=driver)

def error_inputs_lstsq(op_info, device, **kwargs):
    zero_d = torch.randn((), device=device)
    yield ErrorInput(
        SampleInput(zero_d, args=(zero_d,)),
        error_type=RuntimeError,
        error_regex="at least 2 dimensions",
    )


def error_inputs_lstsq_grad_oriented(op_info, device, **kwargs):
    zero_d = torch.randn((), device=device)
    yield ErrorInput(
        SampleInput(zero_d, args=(zero_d, None)),
        error_type=RuntimeError,
        error_regex="at least 2 dimensions",
    )

def sample_inputs_linalg_cholesky(
    op_info, device, dtype, requires_grad=False, **kwargs
):
    """
    This function generates always positive-definite input for torch.linalg.cholesky using
    random_hermitian_pd_matrix.
    The input is generated as the itertools.product of 'batches' and 'ns'.
    In total this function generates 16 SampleInputs (8 matrices, each with upper=True and upper=False).
    'batches' cases include:
        () - single input,
        (0,) - zero batched dimension,
        (2,) - batch of two matrices,
        (1, 1) - 1x1 batch of matrices
    'ns' gives 0x0 and 5x5 matrices.
    Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes.
    """
    from torch.testing._internal.common_utils import random_hermitian_pd_matrix

    batches = [(), (0,), (2,), (1, 1)]
    ns = [5, 0]

    for batch, n, upper in product(batches, ns, [True, False]):
        a = random_hermitian_pd_matrix(n, *batch, dtype=dtype, device=device)
        a.requires_grad = requires_grad
        yield SampleInput(a, upper=upper)
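
# Illustrative sketch, not part of the original test definitions: the inputs generated
# above are Hermitian positive definite, so linalg.cholesky succeeds and L @ L.mH
# reproduces the input. Hypothetical helper, never called.
def _demo_cholesky_roundtrip():
    A = torch.randn(4, 4, dtype=torch.complex128)
    A = A @ A.mH + 4 * torch.eye(4, dtype=torch.complex128)  # Hermitian positive definite
    L = torch.linalg.cholesky(A)
    assert torch.allclose(L @ L.mH, A, atol=1e-6)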

def sample_inputs_linalg_eig(op_info, device, dtype, requires_grad=False, **kwargs):
    """
    This function generates input for torch.linalg.eig
    """

    def out_fn(output):
        return output[0], abs(output[1])

    samples = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad)
    for sample in samples:
        sample.output_process_fn_grad = out_fn
        yield sample


def sample_inputs_linalg_eigh(op_info, device, dtype, requires_grad=False, **kwargs):
    """
    This function generates input for torch.linalg.eigh/eigvalsh with UPLO="U" or "L" keyword argument.
    """

    def out_fn(output):
        if isinstance(output, tuple):
            # eigh function
            return output[0], abs(output[1])
        else:
            # eigvalsh function
            return output

    # Samples do not need to be Hermitian, as we're using gradcheck_wrapper_hermitian_input
    samples = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad)
    for sample in samples:
        sample.kwargs = {"UPLO": np.random.choice(["L", "U"])}
        sample.output_process_fn_grad = out_fn
        yield sample

def sample_inputs_linalg_pinv(op_info, device, dtype, requires_grad=False, **kwargs):
    """
    This function generates input for torch.linalg.pinv with hermitian=False keyword argument.
    """
    for o in sample_inputs_linalg_invertible(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        real_dtype = o.input.real.dtype if dtype.is_complex else dtype
        # requires_grad path for rtol tensor is not implemented
        for rtol in (None, 1.0, torch.tensor(1.0, dtype=real_dtype, device=device)):
            o = clone_sample(o)
            o.kwargs = {"rtol": rtol}
            yield o


def sample_inputs_linalg_pinv_hermitian(
    op_info, device, dtype, requires_grad=False, **kwargs
):
    """
    This function generates input for torch.linalg.pinv with hermitian=True keyword argument.
    """
    for o in sample_inputs_linalg_invertible(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        o.kwargs = {"hermitian": True}
        yield o

def sample_inputs_linalg_solve(
    op_info, device, dtype, requires_grad=False, vector_rhs_allowed=True, **kwargs
):
    """
    This function generates always solvable input for torch.linalg.solve
    We sample a fullrank square matrix (i.e. invertible) A
    The first input to torch.linalg.solve is generated as the itertools.product of 'batches' and 'ns'.
    The second input is generated as the product of 'batches', 'ns' and 'nrhs'.
    In total this function generates 18 SampleInputs
    'batches' cases include:
        () - single input,
        (0,) - zero batched dimension,
        (2,) - batch of two matrices.
    'ns' gives 0x0 and 5x5 matrices.
    and 'nrhs' controls the number of vectors to solve for:
        () - using 1 as the number of vectors implicitly
        (1,) - same as () but explicit
        (3,) - solve for 3 vectors.
    Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes.
    'vector_rhs_allowed' controls whether to include nrhs = () in the list of SampleInputs.
    torch.solve / triangular_solve / cholesky_solve (as opposed to torch.linalg.solve) do not allow
    1D tensors (vectors) as the right-hand-side.
    Once torch.solve / triangular_solve / cholesky_solve and their testing are removed,
    'vector_rhs_allowed' may be removed here as well.
    """
    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
    make_a = partial(
        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
    )
    make_b = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )

    batches = [(), (0,), (2,)]
    ns = [5, 0]
    if vector_rhs_allowed:
        nrhs = [(), (1,), (3,)]
    else:
        nrhs = [(1,), (3,)]

    for n, batch, rhs in product(ns, batches, nrhs):
        yield SampleInput(make_a(*batch, n, n), args=(make_b((batch + (n,) + rhs)),))
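
# Illustrative sketch, not part of the original test definitions: the generator above
# yields (A, b) pairs with both vector and matrix right-hand sides; in either case the
# solution returned by torch.linalg.solve satisfies A @ x = b. Hypothetical helper,
# never called.
def _demo_linalg_solve_rhs_shapes():
    A = torch.randn(5, 5, dtype=torch.float64) + 5 * torch.eye(5, dtype=torch.float64)
    b_vec = torch.randn(5, dtype=torch.float64)
    b_mat = torch.randn(5, 3, dtype=torch.float64)
    assert torch.allclose(A @ torch.linalg.solve(A, b_vec), b_vec, atol=1e-6)
    assert torch.allclose(A @ torch.linalg.solve(A, b_mat), b_mat, atol=1e-6)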

def sample_inputs_linalg_solve_triangular(
    op_info, device, dtype, requires_grad=False, **kwargs
):
    make_arg = partial(make_tensor, dtype=dtype, device=device)
    bs = (1, 2, 0)
    ns = (3, 0)
    ks = (1, 3, 0)

    for b, n, k, (left, upper, uni) in product(
        bs, ns, ks, product((True, False), repeat=3)
    ):
        if b == 1:
            A = make_arg((n, n)) if left else make_arg((k, k))
            B = make_arg((n, k))
        else:
            A = make_arg((b, n, n)) if left else make_arg((b, k, k))
            B = make_arg((b, n, k))
        if uni:
            # Not really necessary, but writing it for consistency
            A.diagonal(0, -2, -1).fill_(1.0)
        else:
            d = A.diagonal(0, -2, -1)
            d[d.abs() < 1e-6] = 1.0
        if upper:
            A.triu_()
        else:
            A.tril_()
        kwargs = {"upper": upper, "left": left, "unitriangular": uni}
        if requires_grad:
            for grad_A, grad_B in product((True, False), repeat=2):
                # Either A or B needs to have a gradient
                if not grad_A and not grad_B:
                    continue
                yield SampleInput(
                    A.clone().requires_grad_(grad_A),
                    args=(B.clone().requires_grad_(grad_B),),
                    kwargs=kwargs,
                )
        else:
            yield SampleInput(A, args=(B,), kwargs=kwargs)

def sample_inputs_legacy_solve(op_info, device, dtype, requires_grad=False, **kwargs):
    """
    This function generates always solvable input for legacy solve functions
    (the ones that are not in torch.linalg module).
    The difference from sample_inputs_linalg_solve is that here the right-hand-side of A x = b equation
    should have b.ndim >= 2, vectors are not allowed.
    Also the arguments order is swapped.
    """
    out = sample_inputs_linalg_solve(
        op_info, device, dtype, requires_grad=requires_grad, vector_rhs_allowed=False
    )

    def out_fn(output):
        return output[0]

    # Reverses tensor order
    for sample in out:
        sample.input, sample.args = sample.args[0], (sample.input,)
        if op_info.name == "solve":
            sample.output_process_fn_grad = out_fn
        yield sample

def sample_inputs_linalg_lu(op_info, device, dtype, requires_grad=False, **kwargs):
    full_rank = op_info.name == "linalg.lu_factor"
    make_fn = (
        make_tensor
        if not full_rank
        else make_fullrank_matrices_with_distinct_singular_values
    )
    make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad)

    def out_fn(output):
        if op_info.name == "linalg.lu":
            return output[1], output[2]
        else:
            return output

    batch_shapes = ((), (3,), (3, 3))
    # pivot=False only supported in CUDA
    pivots = (True, False) if torch.device(device).type == "cuda" else (True,)
    deltas = (-2, -1, 0, +1, +2)
    for batch_shape, pivot, delta in product(batch_shapes, pivots, deltas):
        shape = batch_shape + (S + delta, S)
        # Insanely annoying that make_fullrank_blablabla accepts a *shape and not a tuple!
        A = make_arg(shape) if not full_rank else make_arg(*shape)
        yield SampleInput(A, kwargs={"pivot": pivot}, output_process_fn_grad=out_fn)
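
# Illustrative sketch, not part of the original test definitions: the LU samples above
# cover wide, square and tall matrices (S + delta rows); in every case the factorization
# satisfies P @ L @ U == A. Hypothetical helper, never called.
def _demo_lu_rectangular():
    A = torch.randn(6, 4, dtype=torch.float64)  # tall, like delta = +2 above
    P, L, U = torch.linalg.lu(A)
    assert torch.allclose(P @ L @ U, A, atol=1e-8)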

def sample_inputs_linalg_svdvals(op_info, device, dtype, requires_grad=False, **kwargs):
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    batches = [(), (0,), (2,), (1, 1)]
    ns = [5, 2, 0]

    for batch, m, n in product(batches, ns, ns):
        yield SampleInput(make_arg(batch + (m, n)))


def sample_inputs_linalg_qr_geqrf(
    op_info, device, dtype, requires_grad=False, **kwargs
):
    # QR is only well defined when the matrix is full rank
    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
    make_arg = partial(
        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
    )

    batches = [(), (0,), (2,), (1, 1)]
    ns = [5, 2, 0]

    for batch, (m, n) in product(batches, product(ns, ns)):
        shape = batch + (m, n)
        yield SampleInput(make_arg(*shape))

def sample_inputs_tensorsolve(op_info, device, dtype, requires_grad, **kwargs):
    a_shapes = [(2, 3, 6), (3, 4, 4, 3)]
    # Zero-dim tensors are not supported in NumPy, so we skip them for now.
    # NumPy is used in reference check tests.
    # See https://github.com/numpy/numpy/pull/20482 for tracking NumPy bugfix.
    # a_shapes += [(0, 0, 1, 2, 3, 0)]
    dimss = [None, (0, 2)]

    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    for a_shape, dims in itertools.product(a_shapes, dimss):
        a = make_arg(a_shape)
        b = make_arg(a_shape[:2])
        yield SampleInput(a, b, dims=dims)
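
# Illustrative sketch, not part of the original test definitions: for the (2, 3, 6)
# sample shape above, tensorsolve returns an x with the trailing shape of A such that
# tensordot(A, x, dims=x.ndim) reproduces b. Hypothetical helper, never called.
def _demo_tensorsolve_contract():
    A = torch.randn(2, 3, 6, dtype=torch.float64)
    b = torch.randn(2, 3, dtype=torch.float64)
    x = torch.linalg.tensorsolve(A, b)
    assert x.shape == (6,)
    assert torch.allclose(torch.tensordot(A, x, dims=x.ndim), b, atol=1e-6)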

def sample_inputs_tensorinv(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = make_fullrank_matrices_with_distinct_singular_values

    def make_input():
        return make_arg(12, 12, device=device, dtype=dtype, requires_grad=requires_grad)

    # lhs / rhs shape can have any number of dimensions as long as their product equals 12
    shapes = [
        ((2, 2, 3), (12, 1)),
        ((4, 3), (6, 1, 2)),
    ]

    for shape_lhs, shape_rhs in shapes:
        inp = make_input().reshape(*shape_lhs, *shape_rhs).detach()
        inp.requires_grad_(requires_grad)

        yield SampleInput(inp, ind=len(shape_lhs))

op_db: List[OpInfo] = [
    OpInfo(
        "linalg.cross",
        ref=lambda x, y, dim=-1: np.cross(x, y, axis=dim),
        op=torch.linalg.cross,
        dtypes=all_types_and_complex_and(torch.bfloat16),
        dtypesIfCUDA=all_types_and_complex_and(torch.half),
        aten_name="linalg_cross",
        sample_inputs_func=sample_inputs_cross,
        error_inputs_func=error_inputs_cross,
        supports_out=True,
        supports_fwgrad_bwgrad=True,
        supports_forward_ad=True,
        skips=(
            DecorateInfo(
                unittest.skip("Unsupported on MPS for now"),
                "TestCommon",
                "test_numpy_ref_mps",
            ),
        ),
    ),
    OpInfo(
        "linalg.det",
        aten_name="linalg_det",
        op=torch.linalg.det,
        aliases=("det",),
        dtypes=floating_and_complex_types(),
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet,
        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
        check_batched_gradgrad=False,
    ),
    OpInfo(
        "linalg.det",
        aten_name="linalg_det",
        op=torch.linalg.det,
        variant_test_name="singular",
        aliases=("det",),
        dtypes=floating_and_complex_types(),
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_gradgrad=False,
        sample_inputs_func=sample_inputs_linalg_det_singular,
        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
        skips=(
            DecorateInfo(
                unittest.skip("The backward may give different results"),
                "TestCommon",
                "test_noncontiguous_samples",
            ),
            DecorateInfo(
                unittest.skip("Gradients are incorrect on macos"),
                "TestBwdGradients",
                "test_fn_grad",
                device_type="cpu",
                dtypes=(torch.float64,),
                active_if=IS_MACOS,
            ),
            DecorateInfo(
                unittest.skip("Gradients are incorrect on macos"),
                "TestFwdGradients",
                "test_forward_mode_AD",
                device_type="cpu",
                dtypes=(torch.float64,),
                active_if=IS_MACOS,
            ),
            # Both Hessians are incorrect on complex inputs??
            DecorateInfo(
                unittest.expectedFailure,
                "TestBwdGradients",
                "test_fn_gradgrad",
                dtypes=(torch.complex128,),
            ),
            DecorateInfo(
                unittest.expectedFailure,
                "TestFwdGradients",
                "test_fn_fwgrad_bwgrad",
                dtypes=(torch.complex128,),
            ),
            DecorateInfo(
                unittest.skip("Skipped, see https://github.com//issues/84192"),
                "TestBwdGradients",
                "test_fn_gradgrad",
                device_type="cuda",
            ),
            DecorateInfo(
                unittest.skip("Skipped, see https://github.com//issues/84192"),
                "TestFwdGradients",
                "test_fn_fwgrad_bwgrad",
                device_type="cuda",
            ),
        ),
    ),
    OpInfo(
        "linalg.cholesky",
        aten_name="linalg_cholesky",
        dtypes=floating_and_complex_types(),
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        sample_inputs_func=sample_inputs_linalg_cholesky,
        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
    ),
    OpInfo(
        "linalg.cholesky_ex",
        aten_name="linalg_cholesky_ex",
        dtypes=floating_and_complex_types(),
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        sample_inputs_func=sample_inputs_linalg_cholesky,
        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
    ),
    OpInfo(
        "linalg.vecdot",
        aten_name="linalg_vecdot",
        ref=lambda x, y, *, dim=-1: (x.conj() * y).sum(dim),
        dtypes=floating_and_complex_types_and(torch.bfloat16),
        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
        sample_inputs_func=sample_inputs_linalg_vecdot,
        check_batched_forward_grad=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        skips=(
            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestSchemaCheckModeOpInfo",
                "test_schema_correctness",
                dtypes=(torch.complex64, torch.complex128),
            ),
            DecorateInfo(
                unittest.skip("Unsupported on MPS for now"),
                "TestCommon",
                "test_numpy_ref_mps",
            ),
        ),
    ),
    OpInfo(
        "linalg.cond",
        aten_name="linalg_cond",
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_linalg_cond,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
    ),
    OpInfo(
        "linalg.eig",
        aten_name="linalg_eig",
        op=torch.linalg.eig,
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_linalg_eig,
        check_batched_forward_grad=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        skips=(
            # AssertionError: Scalars are not equal!
            DecorateInfo(
                unittest.expectedFailure, "TestCommon", "test_out", device_type="cpu"
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="mps",
                dtypes=[torch.float32],
            ),
        ),
        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off],
    ),
    OpInfo(
        "linalg.eigvals",
        aten_name="linalg_eigvals",
        op=torch.linalg.eigvals,
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_linalg_invertible,
        check_batched_forward_grad=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
        skips=(
            # exits early on eager extremal value test
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCudaFuserOpInfo",
                "test_nvfuser_extremal_values",
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="mps",
                dtypes=[torch.float32],
            ),
        ),
    ),
    OpInfo(
        "linalg.eigh",
        aten_name="linalg_eigh",
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_linalg_eigh,
        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
        check_batched_forward_grad=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off],
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="mps",
                dtypes=[torch.float32],
            ),
        ),
    ),
    OpInfo(
        "linalg.eigvalsh",
        aten_name="linalg_eigvalsh",
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_linalg_eigh,
        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
        check_batched_forward_grad=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
        skips=(
            # Pre-existing condition; Needs to be fixed
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="mps",
                dtypes=[torch.float32],
            ),
        ),
    ),
    OpInfo(
        "linalg.householder_product",
        aten_name="linalg_householder_product",
        op=torch.linalg.householder_product,
        aliases=("orgqr",),
        dtypes=floating_and_complex_types(),
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        # TODO: backward uses in-place operations that vmap doesn't like
        check_batched_grad=False,
        check_batched_gradgrad=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_forward_grad=False,
        sample_inputs_func=sample_inputs_householder_product,
        decorators=[
            skipCUDAIfNoCusolver,
            skipCPUIfNoLapack,
            DecorateInfo(
                toleranceOverride({torch.complex64: tol(atol=1e-3, rtol=1e-3)})
            ),
            DecorateInfo(
                unittest.skip("Skipped! Flaky"),
                "TestFwdGradients",
                "test_fn_fwgrad_bwgrad",
                device_type="cpu",
                dtypes=(torch.complex128,),
            ),
        ],
    ),
    OpInfo(
        "linalg.ldl_factor",
        aten_name="linalg_ldl_factor",
        dtypes=floating_and_complex_types(),
        supports_autograd=False,
        sample_inputs_func=sample_inputs_linalg_ldl_factor,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, skipCUDAIfRocm],
    ),
    OpInfo(
        "linalg.ldl_factor_ex",
        aten_name="linalg_ldl_factor_ex",
        dtypes=floating_and_complex_types(),
        supports_autograd=False,
        sample_inputs_func=sample_inputs_linalg_ldl_factor,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, skipCUDAIfRocm],
    ),
    OpInfo(
        "linalg.ldl_solve",
        aten_name="linalg_ldl_solve",
        dtypes=floating_and_complex_types(),
        supports_autograd=False,
        sample_inputs_func=sample_inputs_linalg_ldl_solve,
        decorators=[
            skipCUDAIf(
                _get_torch_cuda_version() < (11, 4), "not available before CUDA 11.3.1"
            ),
            skipCUDAIfNoCusolver,
            skipCUDAIfRocm,
            skipCPUIfNoLapack,
        ],
    ),
    OpInfo(
        "linalg.lstsq",
        aten_name="linalg_lstsq",
        dtypes=floating_and_complex_types(),
        supports_out=True,
        sample_inputs_func=sample_inputs_linalg_lstsq,
        error_inputs_func=error_inputs_lstsq,
        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
        skips=(
            # we skip gradient checks for this suite as they are tested in
            # variant_test_name='grad_oriented'
            DecorateInfo(unittest.skip("Skipped!"), "TestFwdGradients"),
            DecorateInfo(unittest.skip("Skipped!"), "TestBwdGradients"),
            # The values for attribute 'shape' do not match
            DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_out"),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="mps",
                dtypes=[torch.float32],
            ),
        ),
    ),
    OpInfo(
        "linalg.lstsq",
        aten_name="linalg_lstsq",
        variant_test_name="grad_oriented",
        # gradcheck for forward AD fails with multi-Tensor outputs
        op=lambda a, b, driver: torch.linalg.lstsq(a, b, driver=driver)[0],
        supports_out=False,
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_linalg_lstsq,
        error_inputs_func=error_inputs_lstsq_grad_oriented,
        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
        gradcheck_fast_mode=True,
        supports_autograd=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
        skips=(
            # these tests do not work when a lambda is passed as op
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
            DecorateInfo(
                unittest.expectedFailure,
                "TestOperatorSignatures",
                "test_get_torch_func_signature_exhaustive",
            ),
        ),
    ),
    OpInfo(
        "linalg.matrix_power",
        aliases=("matrix_power",),
        aten_name="linalg_matrix_power",
        dtypes=floating_and_complex_types(),
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_inplace_autograd=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_grad=False,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
        sample_inputs_func=sample_inputs_linalg_matrix_power,
    ),
    OpInfo(
        "linalg.multi_dot",
        # Need this lambda because gradcheck does not work with TensorList inputs
        aten_name="linalg_multi_dot",
        dtypes=all_types_and_complex_and(torch.bfloat16),
        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
        supports_inplace_autograd=False,
        # Batched grad checks fail for empty input tensors (see https://github.com/pytorch/pytorch/issues/53407)
        check_batched_grad=False,
        check_batched_gradgrad=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # https://github.com/pytorch/pytorch/issues/66357
        check_batched_forward_grad=False,
        sample_inputs_func=sample_inputs_linalg_multi_dot,
        gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
        skips=(
            # https://github.com/pytorch/pytorch/issues/67470
            DecorateInfo(
                unittest.skip("67470!"), "TestCommon", "test_noncontiguous_samples"
            ),
            # Fails on XLA.
            # AssertionError: False is not true : Tensors failed to compare as equal!
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestOpInfo",
                device_type="xla",
                dtypes=(torch.long,),
            ),
            # https://github.com/pytorch/pytorch/issues/71774
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestNNCOpInfo",
                "test_nnc_correctness",
                device_type="cpu",
                dtypes=(torch.long,),
            ),
        ),
    ),
    # NB: linalg.norm has two variants so that different skips can be used for different sample inputs
    OpInfo(
        "linalg.norm",
        aten_name="linalg_norm",
        op=torch.linalg.norm,
        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
        sample_inputs_func=sample_inputs_linalg_norm,
        supports_forward_ad=True,
        check_batched_forward_grad=False,
        supports_fwgrad_bwgrad=True,
        skips=(
            DecorateInfo(
                unittest.expectedFailure, "TestBwdGradients", "test_fn_gradgrad"
            ),
        ),
    ),
    OpInfo(
        "linalg.norm",
        op=torch.linalg.norm,
        variant_test_name="subgradients_at_zero",
        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
        sample_inputs_func=partial(
            sample_inputs_linalg_norm, variant="subgradient_at_zero"
        ),
        aten_name="linalg_norm",
        supports_forward_ad=True,
        # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got:
        # Could not allocate memory to change Tensor SizesAndStrides!
        check_batched_forward_grad=False,
        supports_fwgrad_bwgrad=True,
        skips=(
            # [NEW] Skips specifically for sample inputs at zero
            # norm's vjp/jvp are not well-conditioned near zero
            DecorateInfo(
                unittest.expectedFailure, "TestBwdGradients", "test_fn_gradgrad"
            ),
            DecorateInfo(
                unittest.expectedFailure, "TestFwdGradients", "test_fn_fwgrad_bwgrad"
            ),
            DecorateInfo(
                unittest.expectedFailure, "TestFwdGradients", "test_forward_mode_AD"
            ),
            DecorateInfo(unittest.expectedFailure, "TestBwdGradients", "test_fn_grad"),
        ),
    ),
    OpInfo(
        "linalg.matrix_norm",
        aten_name="linalg_matrix_norm",
        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
        supports_forward_ad=True,
        check_batched_forward_grad=False,
        check_batched_gradgrad=False,
        supports_fwgrad_bwgrad=True,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
        sample_inputs_func=sample_inputs_linalg_matrix_norm,
    ),
    OpInfo(
        "linalg.qr",
        aten_name="linalg_qr",
        op=torch.linalg.qr,
        dtypes=floating_and_complex_types(),
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # In-place ops
        check_batched_gradgrad=False,
        sample_inputs_func=sample_inputs_linalg_qr_geqrf,
        decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack],
    ),
    OpInfo(
        "linalg.slogdet",
        aten_name="linalg_slogdet",
        op=torch.linalg.slogdet,
        dtypes=floating_and_complex_types(),
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
    ),
    OpInfo(
        "linalg.vander",
        aten_name="linalg_vander",
        ref=np_vander_batched,
        op=torch.linalg.vander,
        dtypes=all_types_and_complex(),
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        supports_out=False,
        sample_inputs_func=sample_inputs_linalg_vander,
        skips=(
            DecorateInfo(
                unittest.skip("Unsupported on MPS for now"),
                "TestCommon",
                "test_numpy_ref_mps",
            ),
        ),
    ),
    ReductionOpInfo(
        "linalg.vector_norm",
        op=torch.linalg.vector_norm,
        identity=0,
        nan_policy="propagate",
        supports_multiple_dims=True,
        complex_to_real=True,
        supports_forward_ad=True,
        # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
        # got: Could not allocate memory to change Tensor SizesAndStrides!
        check_batched_forward_grad=False,
        supports_fwgrad_bwgrad=True,
        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
        generate_args_kwargs=sample_kwargs_vector_norm,
        aten_name="linalg_vector_norm",
        skips=(
            # FIXME: sum reduces all dimensions when dim=[]
            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
            DecorateInfo(
                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
            ),
        ),
    ),
    OpInfo(
        "linalg.lu_factor",
        aten_name="linalg_lu_factor",
        op=torch.linalg.lu_factor,
        dtypes=floating_and_complex_types(),
        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_linalg_lu,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
        skips=(
            # linalg.lu_factor: LU without pivoting is not implemented on the CPU
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
    ),
    OpInfo(
        "linalg.lu_factor_ex",
        aten_name="linalg_lu_factor_ex",
        op=torch.linalg.lu_factor_ex,
        dtypes=floating_and_complex_types(),
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_linalg_lu,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
        skips=(
            # linalg.lu_factor: LU without pivoting is not implemented on the CPU
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
    ),
    OpInfo(
        "linalg.lu",
        aten_name="linalg_lu",
        op=torch.linalg.lu,
        dtypes=floating_and_complex_types(),
        # https://github.com/pytorch/pytorch/issues/80411
        # Runs very slowly on slow-gradcheck - alternatively reduce input sizes
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_linalg_lu,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
        skips=(
            # linalg.lu_factor: LU without pivoting is not implemented on the CPU
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
    ),
    OpInfo(
        "linalg.lu_solve",
        op=torch.linalg.lu_solve,
        aten_name="linalg_lu_solve",
        dtypes=floating_and_complex_types(),
        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        check_batched_forward_grad=False,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_lu_solve,
        skips=(
            DecorateInfo(
                unittest.skip("Tests different backward paths"),
                "TestCommon",
                "test_floating_inputs_are_differentiable",
            ),
        ),
        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
    ),
    OpInfo(
        "linalg.inv",
        aten_name="linalg_inv",
        op=torch.linalg.inv,
        aliases=("inverse",),
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_linalg_invertible,
        check_batched_gradgrad=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="mps",
                dtypes=[torch.float32],
            ),
        ),
    ),
    OpInfo(
        "linalg.inv_ex",
        aten_name="linalg_inv_ex",
        op=torch.linalg.inv_ex,
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_linalg_invertible,
        check_batched_gradgrad=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="mps",
                dtypes=[torch.float32],
            ),
        ),
    ),
    OpInfo(
        "linalg.solve",
        aten_name="linalg_solve",
        op=torch.linalg.solve,
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_linalg_solve,
        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="mps",
                dtypes=[torch.float32],
            ),
        ),
    ),
    OpInfo(
        "linalg.solve_ex",
        aten_name="linalg_solve_ex",
        op=torch.linalg.solve_ex,
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_linalg_solve,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="mps",
                dtypes=[torch.float32],
            ),
        ),
    ),
    OpInfo(
        "linalg.solve_triangular",
        aten_name="linalg_solve_triangular",
        op=torch.linalg.solve_triangular,
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_linalg_solve_triangular,
        supports_fwgrad_bwgrad=True,
        skips=(skipCPUIfNoLapack,),
        # linalg.solve_triangular cannot be batched over because of a call to out.copy_(result);
        supports_forward_ad=True,
    ),
    OpInfo(
        "linalg.matrix_rank",
        aten_name="linalg_matrix_rank",
        dtypes=floating_and_complex_types(),
        supports_autograd=False,
        sample_inputs_func=sample_inputs_matrix_rank,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            # jit doesn't accept tensor inputs for matrix rank
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                dtypes=[torch.complex64, torch.float32],
            ),
        ),
    ),
    OpInfo(
        "linalg.matrix_rank",
        aten_name="linalg_matrix_rank",
        variant_test_name="hermitian",
        dtypes=floating_and_complex_types(),
        supports_autograd=False,
        sample_inputs_func=sample_inputs_linalg_pinv_hermitian,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="mps",
                dtypes=[torch.float32],
            ),
        ),
    ),
    OpInfo(
        "linalg.pinv",
        aten_name="linalg_pinv",
        op=torch.linalg.pinv,
        dtypes=floating_and_complex_types(),
        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
        gradcheck_fast_mode=True,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_linalg_pinv,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
        skips=(
            # errors with "leaked XXXX bytes CUDA memory on device 0"
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="cuda",
            ),
        ),
    ),
    OpInfo(
        "linalg.pinv",
        aten_name="linalg_pinv",
        variant_test_name="singular",
        # pinv is Frechet-differentiable in a rank-preserving neighborhood,
        # so we feed inputs that are the products of two full-rank factors,
        # to avoid any rank changes caused by the perturbations in the gradcheck
        op=lambda a, b: torch.linalg.pinv(a @ b.mT),
        dtypes=floating_and_complex_types(),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_linalg_pinv_singular,
        # Only large tensors show issues with implicit backward used prior to
        # explicit backward implementation.
        decorators=[slowTest, skipCUDAIfNoCusolver, skipCPUIfNoLapack],
        skips=(
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
            # CUDA runs out of memory
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestFwdGradients",
                "test_fn_fwgrad_bwgrad",
                device_type="cuda",
                dtypes=[torch.cdouble],
            ),
            # This test takes almost 2 hours to run!
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestBwdGradients",
                "test_fn_gradgrad",
                device_type="cuda",
                dtypes=[torch.cdouble],
            ),
        ),
    ),
    OpInfo(
        "linalg.pinv",
        aten_name="linalg_pinv",
        variant_test_name="hermitian",
        dtypes=floating_and_complex_types(),
        check_batched_grad=False,
        check_batched_gradgrad=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        sample_inputs_func=sample_inputs_linalg_pinv_hermitian,
        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
                "TestCommon",
                "test_noncontiguous_samples",
                device_type="cuda",
            ),
            # This test is flaky under slow gradcheck, likely due to rounding issues
            DecorateInfo(
                skipIfSlowGradcheckEnv,
                "TestFwdGradients",
                "test_fn_fwgrad_bwgrad",
                device_type="cuda",
            ),
        ),
    ),
    OpInfo(
        "linalg.svd",
        op=torch.linalg.svd,
        aten_name="linalg_svd",
        decomp_aten_name="_linalg_svd",
        dtypes=floating_and_complex_types(),
        # Runs very slowly on slow-gradcheck - alternatively reduce input sizes
        gradcheck_fast_mode=True,
        supports_fwgrad_bwgrad=True,
        supports_forward_ad=True,
        check_batched_forward_grad=False,
        # We're using at::allclose, which does not have a batching rule
        check_batched_grad=False,
        check_batched_gradgrad=False,
        sample_inputs_func=sample_inputs_svd,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="mps",
                dtypes=[torch.float32],
            ),
        ),
    ),
    OpInfo(
        "linalg.svdvals",
        op=torch.linalg.svdvals,
        aten_name="linalg_svdvals",
        decomp_aten_name="_linalg_svd",
        dtypes=floating_and_complex_types(),
        check_batched_forward_grad=False,
        supports_fwgrad_bwgrad=True,
        supports_forward_ad=True,
        # We're using at::allclose, which does not have a batching rule
        check_batched_gradgrad=False,
        sample_inputs_func=sample_inputs_linalg_svdvals,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
    ),
    OpInfo(
        "linalg.tensorinv",
        ref=np.linalg.tensorinv,
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_tensorinv,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
        skips=(
            DecorateInfo(
                unittest.skip("Unsupported on MPS for now"),
                "TestCommon",
                "test_numpy_ref_mps",
            ),
        ),
    ),
    OpInfo(
        "linalg.tensorsolve",
        ref=lambda a, b, dims=None: np.linalg.tensorsolve(a, b, axes=dims),
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_tensorsolve,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        decorators=[
            skipCUDAIfNoMagmaAndNoCusolver,
            skipCPUIfNoLapack,
            DecorateInfo(
                toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03)}),
                "TestCommon",
                "test_noncontiguous_samples",
                device_type="cuda",
            ),
        ],
        skips=(
            DecorateInfo(
                unittest.skip("Unsupported on MPS for now"),
                "TestCommon",
                "test_numpy_ref_mps",
            ),
        ),
    ),
]
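

# A minimal, hedged sketch (not part of the upstream database): the OpInfo
# entries above are keyed by `name` plus an optional `variant_test_name`,
# which is how e.g. the two "linalg.lstsq" and the three "linalg.pinv"
# entries stay distinct. The helper below is hypothetical and assumes the
# list above is bound to the module-level `op_db` that the reference
# entries further down pass as `op_db=op_db`.
def _find_linalg_opinfo(name, variant_test_name=""):
    """Return the matching OpInfo from op_db, or None if there is no match."""
    for opinfo in op_db:
        if opinfo.name == name and opinfo.variant_test_name == variant_test_name:
            return opinfo
    return None
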
python_ref_db: List[OpInfo] = [
    #
    # torch.linalg
    #
    ReductionPythonRefInfo(
        "_refs.linalg.vector_norm",
        torch_opinfo_name="linalg.vector_norm",
        supports_out=True,
        supports_nvfuser=False,  # clone_default
        op_db=op_db,
    ),
    PythonRefInfo(
        "_refs.linalg.matrix_norm",
        torch_opinfo_name="linalg.matrix_norm",
        supports_out=True,
        # Uses svdvals, which does not support nvfuser
        supports_nvfuser=False,
        # Uses vector_norm inside, and vector_norm is affected by
        # https://github.com/pytorch/pytorch/issues/77216
        validate_view_consistency=False,
        op_db=op_db,
    ),
    PythonRefInfo(
        "_refs.linalg.norm",
        torch_opinfo_name="linalg.norm",
        supports_out=True,
        # Uses svdvals, which does not support nvfuser
        supports_nvfuser=False,
        # Uses vector_norm inside, and vector_norm is affected by
        # https://github.com/pytorch/pytorch/issues/77216
        validate_view_consistency=False,
        op_db=op_db,
    ),
    PythonRefInfo(
        "_refs.linalg.svd",
        torch_opinfo_name="linalg.svd",
        supports_out=True,
        supports_nvfuser=False,
        op_db=op_db,
    ),
    PythonRefInfo(
        "_refs.linalg.svdvals",
        torch_opinfo_name="linalg.svdvals",
        supports_out=True,
        supports_nvfuser=False,
        op_db=op_db,
    ),
]
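

# Another hedged sketch (illustrative only, not required by the test suite):
# every PythonRefInfo / ReductionPythonRefInfo above points back at an eager
# torch.linalg OpInfo through `torch_opinfo_name`, so the `_refs` variant is
# exercised against the same sample inputs and skips as the op it mirrors.
# The helper name below is hypothetical; it only reads attributes that the
# entries are constructed with in this file.
def _linalg_ref_to_torch_names():
    """Map each _refs.linalg entry name to the torch.linalg op it references."""
    # e.g. {"_refs.linalg.vector_norm": "linalg.vector_norm", ...}
    return {ref.name: ref.torch_opinfo_name for ref in python_ref_db}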