fft.py

import math
from typing import Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union

import torch
import torch._prims as prims
import torch._prims_common as utils
from torch._decomp import register_decomposition
from torch._prims_common import check, DimsType, ShapeType, TensorLikeType
from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper

__all__ = [
    # Transforms
    "fft",
    "fft2",
    "fftn",
    "hfft",
    "hfft2",
    "hfftn",
    "rfft",
    "rfft2",
    "rfftn",
    "ifft",
    "ifft2",
    "ifftn",
    "ihfft",
    "ihfft2",
    "ihfftn",
    "irfft",
    "irfft2",
    "irfftn",
    # Helpers
    "fftshift",
    "ifftshift",
]

NormType = Union[None, Literal["forward"], Literal["backward"], Literal["ortho"]]
_NORM_VALUES = {None, "forward", "backward", "ortho"}

aten = torch._ops.ops.aten


def _apply_norm(
    x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool
) -> TensorLikeType:
    """Apply normalization to the un-normalized FFT result"""
    check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")

    if norm == "ortho":
        return x * (1 / math.sqrt(signal_numel))

    normalize = (not forward and (norm is None or norm == "backward")) or (
        forward and norm == "forward"
    )
    return x * (1 / signal_numel) if normalize else x
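
# Illustrative sketch of how the norm modes split the overall 1/n factor
# between the forward and inverse transforms (assuming a length-4 signal):
#
#   x = torch.ones(4, dtype=torch.complex64)
#   _apply_norm(x, norm="ortho", signal_numel=4, forward=True)   # x / 2, i.e. 1/sqrt(4)
#   _apply_norm(x, norm=None, signal_numel=4, forward=True)      # x unchanged ("backward" default)
#   _apply_norm(x, norm=None, signal_numel=4, forward=False)     # x / 4, the inverse pays the 1/n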


def _promote_type_fft(dtype: torch.dtype, require_complex: bool) -> torch.dtype:
    """Helper to promote a dtype to one supported by the FFT primitives"""
    if dtype.is_complex:
        return dtype

    # Promote integral to default float type
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()

    if require_complex:
        dtype = utils.corresponding_complex_dtype(dtype)

    return dtype


def _maybe_promote_tensor_fft(
    t: TensorLikeType, require_complex: bool = False
) -> TensorLikeType:
    """Helper to promote a tensor to a dtype supported by the FFT primitives"""
    cur_type = t.dtype
    new_type = _promote_type_fft(cur_type, require_complex)
    return _maybe_convert_to_dtype(t, new_type)  # type: ignore[return-value]
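
# Illustrative promotions (assuming the default dtype is torch.float32):
#
#   _promote_type_fft(torch.int64, require_complex=False)       # -> torch.float32
#   _promote_type_fft(torch.int64, require_complex=True)        # -> torch.complex64
#   _promote_type_fft(torch.float64, require_complex=True)      # -> torch.complex128
#   _promote_type_fft(torch.complex128, require_complex=False)  # -> torch.complex128, unchanged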


def _resize_fft_input(
    x: TensorLikeType, dims: Tuple[int, ...], sizes: Tuple[int, ...]
) -> TensorLikeType:
    """
    Fixes the shape of x such that x.size(dims[i]) == sizes[i],
    either by zero-padding, or by slicing x starting from 0.
    """
    assert len(dims) == len(sizes)
    must_copy = False
    x_sizes = x.shape
    pad_amount = [0] * len(x_sizes) * 2
    for i in range(len(dims)):
        if sizes[i] == -1:
            continue

        if x_sizes[dims[i]] < sizes[i]:
            must_copy = True
            pad_idx = len(pad_amount) - 2 * dims[i] - 1
            pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]]

        if x_sizes[dims[i]] > sizes[i]:
            x = x.narrow(dims[i], 0, sizes[i])

    return torch.constant_pad_nd(x, pad_amount) if must_copy else x
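
# Illustrative behaviour, e.g. for x = torch.arange(6.).reshape(2, 3):
#
#   _resize_fft_input(x, dims=(1,), sizes=(5,)).shape   # (2, 5): zero-padded on the right
#   _resize_fft_input(x, dims=(1,), sizes=(2,)).shape   # (2, 2): sliced from index 0
#   _resize_fft_input(x, dims=(1,), sizes=(-1,)).shape  # (2, 3): -1 leaves the dim alone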


def _fft_c2r(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for performing any complex to real FFT (irfft or hfft)"""
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
    # Report the computed size rather than n, which may be None here
    check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    if n is not None:
        input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,))

    if forward:
        input = torch.conj(input)

    output = prims.fft_c2r(input, dim=dims, last_dim_size=last_dim_size)
    return _apply_norm(output, norm=norm, signal_numel=last_dim_size, forward=forward)
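
# Sketch of the c2r convention (illustrative, default "backward" norm): a real
# signal of length n is represented by its n // 2 + 1 Hermitian coefficients,
# so for n = 8 the transform consumes 5 complex values. The forward branch
# conjugates first, which is what makes hfft(x, n) agree with n * irfft(conj(x), n).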


def _fft_r2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
    onesided: bool,
) -> TensorLikeType:
    """Common code for performing any real to complex FFT (rfft or ihfft)"""
    check(
        not input.dtype.is_complex,
        lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
    )
    input = _maybe_promote_tensor_fft(input)
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)

    if n is not None:
        input = _resize_fft_input(input, dims, (n,))

    ret = prims.fft_r2c(input, dim=dims, onesided=onesided)
    ret = _apply_norm(ret, norm, input.shape[dim], forward)
    return ret if forward else torch.conj(ret)
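
# Illustrative identity (default "backward" norm): the forward=False path above
# normalizes and then conjugates, so for a real 1-D tensor x,
#
#   torch.fft.ihfft(x) == torch.conj(torch.fft.rfft(x)) / x.numel()
#
# up to floating-point rounding.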


def _fft_c2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for performing any complex to complex FFT (fft or ifft)"""
    check(
        input.dtype.is_complex,
        lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
    )
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)

    if n is not None:
        input = _resize_fft_input(input, dims, (n,))

    ret = prims.fft_c2c(input, dim=dims, forward=forward)
    return _apply_norm(ret, norm, input.shape[dim], forward)


@register_decomposition(aten.fft_fft)
@out_wrapper()
def fft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    if input.dtype.is_complex:
        return _fft_c2c("fft", input, n, dim, norm, forward=True)
    else:
        return _fft_r2c("fft", input, n, dim, norm, forward=True, onesided=False)


@register_decomposition(aten.fft_ifft)
@out_wrapper()
def ifft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    if input.dtype.is_complex:
        return _fft_c2c("ifft", input, n, dim, norm, forward=False)
    else:
        return _fft_r2c("ifft", input, n, dim, norm, forward=False, onesided=False)


@register_decomposition(aten.fft_rfft)
@out_wrapper()
def rfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    return _fft_r2c("rfft", input, n, dim, norm, forward=True, onesided=True)


@register_decomposition(aten.fft_irfft)
@out_wrapper()
def irfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    return _fft_c2r("irfft", input, n, dim, norm, forward=False)


@register_decomposition(aten.fft_hfft)
@out_wrapper()
def hfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    return _fft_c2r("hfft", input, n, dim, norm, forward=True)


@register_decomposition(aten.fft_ihfft)
@out_wrapper()
def ihfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    return _fft_r2c("ihfft", input, n, dim, norm, forward=False, onesided=True)
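
# Quick sanity sketch for the 1-D entry points (illustrative; these refs aim to
# match eager torch.fft semantics):
#
#   x = torch.randn(8)
#   torch.fft.fft(x).shape    # (8,): a real input still yields a full spectrum
#   torch.fft.rfft(x).shape   # (5,): onesided, 8 // 2 + 1 coefficients
#   torch.allclose(torch.fft.ifft(torch.fft.fft(x)).real, x)  # True: "backward" norm round-trip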


class _ShapeAndDims(NamedTuple):
    shape: Tuple[int, ...]
    dims: Tuple[int, ...]


def _canonicalize_fft_shape_and_dim_args(
    input: TensorLikeType, shape: Optional[ShapeType], dim: Optional[DimsType]
) -> _ShapeAndDims:
    """Convert the shape and dim arguments into a canonical form where neither are optional"""
    input_dim = input.ndim
    input_sizes = input.shape

    if dim is not None:
        if not isinstance(dim, Sequence):
            dim = (dim,)
        ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False)

        # Check dims are unique
        check(len(set(dim)) == len(dim), lambda: "FFT dims must be unique")

    if shape is not None:
        if not isinstance(shape, Sequence):
            shape = (shape,)

        # Has shape, might have dim
        check(
            dim is None or len(dim) == len(shape),
            lambda: "When given, dim and shape arguments must have the same length",
        )
        transform_ndim = len(shape)
        check(
            transform_ndim <= input_dim,
            lambda: f"Got shape with {transform_ndim} values but input tensor "
            f"only has {input_dim} dimensions.",
        )

        # If shape is given, dims defaults to the last len(shape) dimensions
        if dim is None:
            ret_dims = tuple(range(input_dim - transform_ndim, input_dim))

        # Translate any -1 values in shape to the default length
        ret_shape = tuple(
            s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims)
        )
    elif dim is None:
        # No shape, no dim
        ret_dims = tuple(range(input_dim))
        ret_shape = tuple(input_sizes)
    else:
        # No shape, has dim
        ret_shape = tuple(input_sizes[d] for d in ret_dims)

    for n in ret_shape:
        check(n > 0, lambda: f"Invalid number of data points ({n}) specified")

    return _ShapeAndDims(shape=ret_shape, dims=ret_dims)
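
# Illustrative canonicalizations for an input x of shape (2, 3, 4):
#
#   _canonicalize_fft_shape_and_dim_args(x, None, None)       # shape=(2, 3, 4), dims=(0, 1, 2)
#   _canonicalize_fft_shape_and_dim_args(x, (5, 5), None)     # shape=(5, 5), dims=(1, 2)
#   _canonicalize_fft_shape_and_dim_args(x, (-1, 6), (0, 2))  # shape=(2, 6), dims=(0, 2)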


def _prod(xs: Iterable[int]) -> int:
    """Compute product of a list"""
    prod = 1
    for x in xs:
        prod *= x
    return prod


def _fftn_c2c(
    function_name: str,
    input: TensorLikeType,
    shape: Tuple[int, ...],
    dim: Tuple[int, ...],
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for n-dimensional complex to complex FFTs (fftn or ifftn)"""
    check(
        input.dtype.is_complex,
        lambda: f"{function_name} expects a complex input tensor, "
        f"but got {input.dtype}",
    )
    x = _resize_fft_input(input, dim, shape)
    output = prims.fft_c2c(x, dim=dim, forward=forward)
    return _apply_norm(output, norm=norm, signal_numel=_prod(shape), forward=forward)


@register_decomposition(aten.fft_fftn)
@out_wrapper()
def fftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c("fftn", x, shape, dim, norm, forward=True)


@register_decomposition(aten.fft_ifftn)
@out_wrapper()
def ifftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c("ifftn", x, shape, dim, norm, forward=False)


@register_decomposition(aten.fft_rfftn)
@out_wrapper()
def rfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    check(
        not input.dtype.is_complex,
        lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    input = _maybe_promote_tensor_fft(input, require_complex=False)
    input = _resize_fft_input(input, dim, shape)
    out = prims.fft_r2c(input, dim=dim, onesided=True)
    return _apply_norm(out, norm=norm, signal_numel=_prod(shape), forward=True)


@register_decomposition(aten.fft_ihfftn)
@out_wrapper()
def ihfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    check(
        not input.dtype.is_complex,
        lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
    input = _maybe_promote_tensor_fft(input, require_complex=False)
    input = _resize_fft_input(input, dim, shape)

    tmp = prims.fft_r2c(input, dim=dim[-1:], onesided=True)

    if len(dim) == 1:
        tmp = _apply_norm(tmp, norm=norm, signal_numel=shape[0], forward=False)
        return prims.conj(tmp)

    tmp = prims.conj_physical(tmp)
    tmp = prims.fft_c2c(tmp, dim=dim[:-1], forward=False)
    return _apply_norm(tmp, norm=norm, signal_numel=_prod(shape), forward=False)
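
# Illustrative shapes for the decomposition above: an n-dimensional ihfft is a
# onesided r2c along the innermost axis plus an inverse c2c over the rest, with
# the conjugation folded in between. For input shape (4, 8) and dim=(0, 1):
#
#   prims.fft_r2c over dim 1  -> (4, 5)   # 8 // 2 + 1 coefficients
#   prims.fft_c2c over dim 0  -> (4, 5)   # final ihfftn output shape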


class _CanonicalizeC2rReturn(NamedTuple):
    shape: Tuple[int, ...]
    dim: Tuple[int, ...]
    last_dim_size: int


def _canonicalize_fft_c2r_shape_and_dim_args(
    fname: str,
    input: TensorLikeType,
    s: Optional[ShapeType],
    dim: Optional[DimsType],
) -> _CanonicalizeC2rReturn:
    """Canonicalize shape and dim arguments for n-dimensional c2r transforms,
    and also calculate the last_dim_size, which is shape[dim[-1]] for the output"""
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")

    if s is None or s[-1] == -1:
        last_dim_size = 2 * (input.shape[dim[-1]] - 1)
    else:
        last_dim_size = shape[-1]

    check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    shape_list = list(shape)
    shape_list[-1] = last_dim_size // 2 + 1
    return _CanonicalizeC2rReturn(
        shape=tuple(shape_list), dim=dim, last_dim_size=last_dim_size
    )
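
# Illustrative: the c2r output length along the last transformed axis defaults
# to 2 * (m - 1) for an input of (complex) size m there, and the expected input
# size is recovered as last_dim_size // 2 + 1. For x of shape (4, 5), complex:
#
#   _canonicalize_fft_c2r_shape_and_dim_args("irfftn", x, None, None)
#   # -> shape=(4, 5), dim=(0, 1), last_dim_size=8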


@register_decomposition(aten.fft_irfftn)
@out_wrapper()
def irfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
        "irfftn", input, s, dim
    )
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    input = _resize_fft_input(input, dim, shape)
    out = prims.fft_c2r(input, dim=dim, last_dim_size=last_dim_size)
    return _apply_norm(out, norm, _prod(out.shape[d] for d in dim), forward=False)


@register_decomposition(aten.fft_hfftn)
@out_wrapper()
def hfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
        "hfftn", input, s, dim
    )
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    input = _resize_fft_input(input, dim, shape)

    tmp = prims.fft_c2c(input, dim=dim[:-1], forward=True) if len(dim) > 1 else input
    tmp = _apply_norm(tmp, norm, _prod(shape[:-1]), forward=True)
    tmp = prims.conj_physical(tmp)
    out = prims.fft_c2r(tmp, dim=dim[-1:], last_dim_size=last_dim_size)
    return _apply_norm(out, norm, last_dim_size, forward=True)
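
# hfftn mirrors ihfftn (illustrative): a forward c2c over all but the last axis,
# a conjugation, then c2r along the last axis. For complex input of shape (4, 5)
# with default sizes, last_dim_size = 2 * (5 - 1) = 8 and the real output has
# shape (4, 8).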


@register_decomposition(aten.fft_fft2)
@out_wrapper()
def fft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.fftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_ifft2)
@out_wrapper()
def ifft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.ifftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_rfft2)
@out_wrapper()
def rfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.rfftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_irfft2)
@out_wrapper()
def irfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.irfftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_hfft2)
@out_wrapper()
def hfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.hfftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_ihfft2)
@out_wrapper()
def ihfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm)
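
# The 2-D entry points are thin wrappers over their n-dimensional counterparts,
# e.g. (illustrative):
#
#   x = torch.randn(3, 4, 5, dtype=torch.complex64)
#   torch.allclose(torch.fft.fft2(x), torch.fft.fftn(x, dim=(-2, -1)))  # True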


def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> List[int]:
    """Convert Optional[DimsType] to a simple list, defaulting to all dimensions"""
    if dim is None:
        return list(range(x.ndim))
    elif not isinstance(dim, Sequence):
        return [dim]
    else:
        return list(dim)


@register_decomposition(aten.fft_fftshift)
def fftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    dims = _default_alldims(dim, input)
    shift = [input.shape[d] // 2 for d in dims]
    return torch.roll(input, shift, dims)


@register_decomposition(aten.fft_ifftshift)
def ifftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    dims = _default_alldims(dim, input)
    shift = [(input.shape[d] + 1) // 2 for d in dims]
    return torch.roll(input, shift, dims)
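
# Illustrative: fftshift rotates the zero-frequency bin to the center of the
# spectrum and ifftshift undoes it, including for odd lengths:
#
#   freqs = torch.fft.fftfreq(5)                    # [0.0, 0.2, 0.4, -0.4, -0.2]
#   torch.fft.fftshift(freqs)                       # [-0.4, -0.2, 0.0, 0.2, 0.4]
#   torch.fft.ifftshift(torch.fft.fftshift(freqs))  # original order restored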