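"""OpInfo entries for the torch.fft spectral operators, together with their
Python reference (torch._refs.fft) variants."""
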
import unittest
from functools import partial
from typing import List

import numpy as np

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import SM53OrLater
from torch.testing._internal.common_device_type import precisionOverride
from torch.testing._internal.common_dtype import (
    all_types_and,
    all_types_and_complex_and,
)
from torch.testing._internal.common_utils import TEST_SCIPY, TEST_WITH_ROCM
from torch.testing._internal.opinfo.core import (
    DecorateInfo,
    ErrorInput,
    OpInfo,
    SampleInput,
    SpectralFuncInfo,
    SpectralFuncType,
)
from torch.testing._internal.opinfo.refs import (
    _find_referenced_opinfo,
    _inherit_constructor_args,
    PythonRefInfo,
)

has_scipy_fft = False
if TEST_SCIPY:
    try:
        import scipy.fft

        has_scipy_fft = True
    except ModuleNotFoundError:
        pass


class SpectralFuncPythonRefInfo(SpectralFuncInfo):
    """
    An OpInfo for a Python reference of a spectral function.
    """

    def __init__(
        self,
        name,  # the string name of the callable Python reference
        *,
        op=None,  # the function variant of the operation, populated as torch.<name> if None
        torch_opinfo_name,  # the string name of the corresponding torch opinfo
        torch_opinfo_variant="",
        supports_nvfuser=True,
        **kwargs,
    ):  # additional kwargs override kwargs inherited from the torch opinfo
        self.torch_opinfo_name = torch_opinfo_name
        self.torch_opinfo = _find_referenced_opinfo(
            torch_opinfo_name, torch_opinfo_variant, op_db=op_db
        )
        self.supports_nvfuser = supports_nvfuser
        assert isinstance(self.torch_opinfo, SpectralFuncInfo)

        inherited = self.torch_opinfo._original_spectral_func_args
        ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)

        super().__init__(**ukwargs)


def error_inputs_fft(op_info, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
    # Zero-dimensional tensor has no dimension to take FFT of
    yield ErrorInput(
        SampleInput(make_arg()),
        error_type=IndexError,
        error_regex="Dimension specified as -1 but tensor has no dimensions",
    )


def error_inputs_fftn(op_info, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
    # Specifying a dimension on a zero-dimensional tensor
    yield ErrorInput(
        SampleInput(make_arg(), dim=(0,)),
        error_type=IndexError,
        error_regex="Dimension specified as 0 but tensor has no dimensions",
    )


def sample_inputs_fftshift(op_info, device, dtype, requires_grad, **kwargs):
    def mt(shape, **kwargs):
        return make_tensor(
            shape, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs
        )

    yield SampleInput(mt((9, 10)))
    yield SampleInput(mt((50,)), kwargs=dict(dim=0))
    yield SampleInput(mt((5, 11)), kwargs=dict(dim=(1,)))
    yield SampleInput(mt((5, 6)), kwargs=dict(dim=(0, 1)))
    yield SampleInput(mt((5, 6, 2)), kwargs=dict(dim=(0, 2)))


# Operator database
op_db: List[OpInfo] = [
    SpectralFuncInfo(
        "fft.fft",
        aten_name="fft_fft",
        decomp_aten_name="_fft_c2c",
        ref=np.fft.fft,
        ndimensional=SpectralFuncType.OneD,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        error_inputs_func=error_inputs_fft,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
    ),
    SpectralFuncInfo(
        "fft.fft2",
        aten_name="fft_fft2",
        ref=np.fft.fft2,
        decomp_aten_name="_fft_c2c",
        ndimensional=SpectralFuncType.TwoD,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        decorators=[precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})],
    ),
    SpectralFuncInfo(
        "fft.fftn",
        aten_name="fft_fftn",
        decomp_aten_name="_fft_c2c",
        ref=np.fft.fftn,
        ndimensional=SpectralFuncType.ND,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        decorators=[precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})],
    ),
    SpectralFuncInfo(
        "fft.hfft",
        aten_name="fft_hfft",
        decomp_aten_name="_fft_c2r",
        ref=np.fft.hfft,
        ndimensional=SpectralFuncType.OneD,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        error_inputs_func=error_inputs_fft,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        check_batched_gradgrad=False,
        skips=(
            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestSchemaCheckModeOpInfo",
                "test_schema_correctness",
                dtypes=(torch.complex64, torch.complex128),
            ),
        ),
    ),
    SpectralFuncInfo(
        "fft.hfft2",
        aten_name="fft_hfft2",
        decomp_aten_name="_fft_c2r",
        ref=scipy.fft.hfft2 if has_scipy_fft else None,
        ndimensional=SpectralFuncType.TwoD,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_gradgrad=False,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
        skips=(
            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestSchemaCheckModeOpInfo",
                "test_schema_correctness",
            ),
        ),
    ),
    SpectralFuncInfo(
        "fft.hfftn",
        aten_name="fft_hfftn",
        decomp_aten_name="_fft_c2r",
        ref=scipy.fft.hfftn if has_scipy_fft else None,
        ndimensional=SpectralFuncType.ND,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_gradgrad=False,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
                "TestFFT",
                "test_reference_nd",
            ),
        ],
        skips=(
            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestSchemaCheckModeOpInfo",
                "test_schema_correctness",
            ),
        ),
    ),
    SpectralFuncInfo(
        "fft.rfft",
        aten_name="fft_rfft",
        decomp_aten_name="_fft_r2c",
        ref=np.fft.rfft,
        ndimensional=SpectralFuncType.OneD,
        dtypes=all_types_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,))
        ),
        error_inputs_func=error_inputs_fft,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_grad=False,
        skips=(),
        check_batched_gradgrad=False,
    ),
    SpectralFuncInfo(
        "fft.rfft2",
        aten_name="fft_rfft2",
        decomp_aten_name="_fft_r2c",
        ref=np.fft.rfft2,
        ndimensional=SpectralFuncType.TwoD,
        dtypes=all_types_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,))
        ),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        decorators=[
            precisionOverride({torch.float: 1e-4}),
        ],
    ),
    SpectralFuncInfo(
        "fft.rfftn",
        aten_name="fft_rfftn",
        decomp_aten_name="_fft_r2c",
        ref=np.fft.rfftn,
        ndimensional=SpectralFuncType.ND,
        dtypes=all_types_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,))
        ),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        decorators=[
            precisionOverride({torch.float: 1e-4}),
        ],
    ),
    SpectralFuncInfo(
        "fft.ifft",
        aten_name="fft_ifft",
        decomp_aten_name="_fft_c2c",
        ref=np.fft.ifft,
        ndimensional=SpectralFuncType.OneD,
        error_inputs_func=error_inputs_fft,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
    ),
    SpectralFuncInfo(
        "fft.ifft2",
        aten_name="fft_ifft2",
        decomp_aten_name="_fft_c2c",
        ref=np.fft.ifft2,
        ndimensional=SpectralFuncType.TwoD,
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    SpectralFuncInfo(
        "fft.ifftn",
        aten_name="fft_ifftn",
        decomp_aten_name="_fft_c2c",
        ref=np.fft.ifftn,
        ndimensional=SpectralFuncType.ND,
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    SpectralFuncInfo(
        "fft.ihfft",
        aten_name="fft_ihfft",
        decomp_aten_name="_fft_r2c",
        ref=np.fft.ihfft,
        ndimensional=SpectralFuncType.OneD,
        error_inputs_func=error_inputs_fft,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,))
        ),
        skips=(),
        check_batched_grad=False,
    ),
    SpectralFuncInfo(
        "fft.ihfft2",
        aten_name="fft_ihfft2",
        decomp_aten_name="_fft_r2c",
        ref=scipy.fft.ihfftn if has_scipy_fft else None,
        ndimensional=SpectralFuncType.TwoD,
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,))
        ),
        check_batched_grad=False,
        check_batched_gradgrad=False,
        decorators=(
            # The values for attribute 'shape' do not match: torch.Size([5, 6, 5]) != torch.Size([5, 6, 6]).
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"),
            DecorateInfo(
                precisionOverride({torch.float: 2e-4}), "TestFFT", "test_reference_nd"
            ),
            # Mismatched elements!
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out"),
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"),
        ),
    ),
    SpectralFuncInfo(
        "fft.ihfftn",
        aten_name="fft_ihfftn",
        decomp_aten_name="_fft_r2c",
        ref=scipy.fft.ihfftn if has_scipy_fft else None,
        ndimensional=SpectralFuncType.ND,
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,))
        ),
        check_batched_grad=False,
        check_batched_gradgrad=False,
        decorators=[
            # The values for attribute 'shape' do not match: torch.Size([5, 6, 5]) != torch.Size([5, 6, 6]).
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"),
            # Mismatched elements!
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out"),
            DecorateInfo(
                precisionOverride({torch.float: 2e-4}), "TestFFT", "test_reference_nd"
            ),
        ],
    ),
    SpectralFuncInfo(
        "fft.irfft",
        aten_name="fft_irfft",
        decomp_aten_name="_fft_c2r",
        ref=np.fft.irfft,
        ndimensional=SpectralFuncType.OneD,
        error_inputs_func=error_inputs_fft,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        check_batched_gradgrad=False,
    ),
    SpectralFuncInfo(
        "fft.irfft2",
        aten_name="fft_irfft2",
        decomp_aten_name="_fft_c2r",
        ref=np.fft.irfft2,
        ndimensional=SpectralFuncType.TwoD,
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        check_batched_gradgrad=False,
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    SpectralFuncInfo(
        "fft.irfftn",
        aten_name="fft_irfftn",
        decomp_aten_name="_fft_c2r",
        ref=np.fft.irfftn,
        ndimensional=SpectralFuncType.ND,
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        check_batched_gradgrad=False,
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    OpInfo(
        "fft.fftshift",
        dtypes=all_types_and_complex_and(
            torch.bool, torch.bfloat16, torch.half, torch.chalf
        ),
        sample_inputs_func=sample_inputs_fftshift,
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
    ),
    OpInfo(
        "fft.ifftshift",
        dtypes=all_types_and_complex_and(
            torch.bool, torch.bfloat16, torch.half, torch.chalf
        ),
        sample_inputs_func=sample_inputs_fftshift,
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
    ),
]

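# Database of Python reference (torch._refs.fft) entries; each inherits its
# constructor arguments from the corresponding torch opinfo defined above.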
python_ref_db: List[OpInfo] = [
    SpectralFuncPythonRefInfo(
        "_refs.fft.fft",
        torch_opinfo_name="fft.fft",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.ifft",
        torch_opinfo_name="fft.ifft",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.rfft",
        torch_opinfo_name="fft.rfft",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.irfft",
        torch_opinfo_name="fft.irfft",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.hfft",
        torch_opinfo_name="fft.hfft",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.ihfft",
        torch_opinfo_name="fft.ihfft",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.fftn",
        torch_opinfo_name="fft.fftn",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.ifftn",
        torch_opinfo_name="fft.ifftn",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.rfftn",
        torch_opinfo_name="fft.rfftn",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.irfftn",
        torch_opinfo_name="fft.irfftn",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.hfftn",
        torch_opinfo_name="fft.hfftn",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.ihfftn",
        torch_opinfo_name="fft.ihfftn",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.fft2",
        torch_opinfo_name="fft.fft2",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.ifft2",
        torch_opinfo_name="fft.ifft2",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.rfft2",
        torch_opinfo_name="fft.rfft2",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.irfft2",
        torch_opinfo_name="fft.irfft2",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.hfft2",
        torch_opinfo_name="fft.hfft2",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.ihfft2",
        torch_opinfo_name="fft.ihfft2",
        supports_nvfuser=False,
    ),
    PythonRefInfo(
        "_refs.fft.fftshift",
        op_db=op_db,
        torch_opinfo_name="fft.fftshift",
        supports_nvfuser=False,
    ),
    PythonRefInfo(
        "_refs.fft.ifftshift",
        op_db=op_db,
        torch_opinfo_name="fft.ifftshift",
        supports_nvfuser=False,
    ),
]