import math
from typing import Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union

import torch
import torch._prims as prims
import torch._prims_common as utils
from torch._decomp import register_decomposition
from torch._prims_common import check, DimsType, ShapeType, TensorLikeType
from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper

__all__ = [
    # Transforms
    "fft",
    "fft2",
    "fftn",
    "hfft",
    "hfft2",
    "hfftn",
    "rfft",
    "rfft2",
    "rfftn",
    "ifft",
    "ifft2",
    "ifftn",
    "ihfft",
    "ihfft2",
    "ihfftn",
    "irfft",
    "irfft2",
    "irfftn",
    # Helpers
    "fftshift",
    "ifftshift",
]
NormType = Union[None, Literal["forward"], Literal["backward"], Literal["ortho"]]
_NORM_VALUES = {None, "forward", "backward", "ortho"}

aten = torch._ops.ops.aten


def _apply_norm(
    x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool
) -> TensorLikeType:
    """Apply normalization to the un-normalized FFT result"""
    check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")

    if norm == "ortho":
        return x * (1 / math.sqrt(signal_numel))

    normalize = (not forward and (norm is None or norm == "backward")) or (
        forward and norm == "forward"
    )
    return x * (1 / signal_numel) if normalize else x
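# A note on the conventions implemented above: norm="backward" (the default)
# leaves the forward transform un-normalized and scales the inverse by 1/n;
# norm="forward" does the opposite; norm="ortho" scales both directions by
# 1/sqrt(n), making the transform unitary. Illustrative round trip:
#
#   >>> x = torch.randn(8, dtype=torch.complex64)
#   >>> y = torch.fft.ifft(torch.fft.fft(x, norm="ortho"), norm="ortho")
#   >>> torch.allclose(x, y, atol=1e-6)
#   True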
def _promote_type_fft(dtype: torch.dtype, require_complex: bool) -> torch.dtype:
    """Helper to promote a dtype to one supported by the FFT primitives"""
    if dtype.is_complex:
        return dtype

    # Promote integral to default float type
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()

    if require_complex:
        dtype = utils.corresponding_complex_dtype(dtype)

    return dtype
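# For example, with the default dtype left at float32, integral inputs are
# promoted as int64 -> float32, and then to complex64 when require_complex
# is set:
#
#   >>> _promote_type_fft(torch.int64, require_complex=True)
#   torch.complex64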
def _maybe_promote_tensor_fft(
    t: TensorLikeType, require_complex: bool = False
) -> TensorLikeType:
    """Helper to promote a tensor to a dtype supported by the FFT primitives"""
    cur_type = t.dtype
    new_type = _promote_type_fft(cur_type, require_complex)
    return _maybe_convert_to_dtype(t, new_type)  # type: ignore[return-value]


def _resize_fft_input(
    x: TensorLikeType, dims: Tuple[int, ...], sizes: Tuple[int, ...]
) -> TensorLikeType:
    """
    Fixes the shape of x such that x.size(dims[i]) == sizes[i],
    either by zero-padding, or by slicing x starting from 0.
    """
    assert len(dims) == len(sizes)
    must_copy = False
    x_sizes = x.shape
    pad_amount = [0] * len(x_sizes) * 2
    for i in range(len(dims)):
        if sizes[i] == -1:
            continue

        if x_sizes[dims[i]] < sizes[i]:
            must_copy = True
            pad_idx = len(pad_amount) - 2 * dims[i] - 1
            pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]]

        if x_sizes[dims[i]] > sizes[i]:
            x = x.narrow(dims[i], 0, sizes[i])

    return torch.constant_pad_nd(x, pad_amount) if must_copy else x
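# torch.constant_pad_nd expects padding as (before, after) pairs starting
# from the *last* dimension, which is why pad_idx above counts backwards from
# the end of pad_amount. Illustrative sketch: growing dim 0 of a (2, 3)
# tensor to 4 sets pad_amount = [0, 0, 0, 2], i.e. dim 1 pads (0, 0) and
# dim 0 pads (0, 2):
#
#   >>> _resize_fft_input(torch.ones(2, 3), dims=(0,), sizes=(4,)).shape
#   torch.Size([4, 3])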
def _fft_c2r(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for performing any complex to real FFT (irfft or hfft)"""
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
    check(last_dim_size >= 1, lambda: f"Invalid number of data points ({n}) specified")

    if n is not None:
        input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,))

    if forward:
        input = torch.conj(input)

    output = prims.fft_c2r(input, dim=dims, last_dim_size=last_dim_size)
    return _apply_norm(output, norm=norm, signal_numel=last_dim_size, forward=forward)
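# hfft and irfft share this c2r kernel because the forward transform of a
# Hermitian signal is the conjugate of an inverse transform: hfft conjugates
# its input first, then applies forward normalization. Illustrative round
# trip through the public API (mirroring the torch.fft.hfft docs):
#
#   >>> t = torch.fft.ifft(torch.arange(5.0))  # Hermitian-symmetric spectrum
#   >>> torch.fft.hfft(t[:3], n=5)             # recovers the real signal
#   tensor([0., 1., 2., 3., 4.])               # up to rounding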
def _fft_r2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
    onesided: bool,
) -> TensorLikeType:
    """Common code for performing any real to complex FFT (rfft or ihfft)"""
    check(
        not input.dtype.is_complex,
        lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
    )
    input = _maybe_promote_tensor_fft(input)
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)

    if n is not None:
        input = _resize_fft_input(input, dims, (n,))

    ret = prims.fft_r2c(input, dim=dims, onesided=onesided)
    ret = _apply_norm(ret, norm, input.shape[dim], forward)
    return ret if forward else torch.conj(ret)
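# The inverse direction (ihfft) reuses the same r2c kernel: for real input,
# ifft(x) == conj(fft(x)) / n, so the kernel output is normalized and then
# conjugated. Illustrative check with the default ("backward") norm:
#
#   >>> x = torch.randn(6)
#   >>> torch.allclose(torch.fft.ihfft(x), torch.fft.rfft(x).conj() / 6)
#   True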
def _fft_c2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for performing any complex to complex FFT (fft or ifft)"""
    check(
        input.dtype.is_complex,
        lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
    )
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)

    if n is not None:
        input = _resize_fft_input(input, dims, (n,))

    ret = prims.fft_c2c(input, dim=dims, forward=forward)
    return _apply_norm(ret, norm, input.shape[dim], forward)


@register_decomposition(aten.fft_fft)
@out_wrapper()
def fft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    if input.dtype.is_complex:
        return _fft_c2c("fft", input, n, dim, norm, forward=True)
    else:
        return _fft_r2c("fft", input, n, dim, norm, forward=True, onesided=False)


@register_decomposition(aten.fft_ifft)
@out_wrapper()
def ifft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    if input.dtype.is_complex:
        return _fft_c2c("ifft", input, n, dim, norm, forward=False)
    else:
        return _fft_r2c("ifft", input, n, dim, norm, forward=False, onesided=False)
@register_decomposition(aten.fft_rfft)
@out_wrapper()
def rfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    return _fft_r2c("rfft", input, n, dim, norm, forward=True, onesided=True)


@register_decomposition(aten.fft_irfft)
@out_wrapper()
def irfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    return _fft_c2r("irfft", input, n, dim, norm, forward=False)


@register_decomposition(aten.fft_hfft)
@out_wrapper()
def hfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    return _fft_c2r("hfft", input, n, dim, norm, forward=True)


@register_decomposition(aten.fft_ihfft)
@out_wrapper()
def ihfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    return _fft_r2c("ihfft", input, n, dim, norm, forward=False, onesided=True)


class _ShapeAndDims(NamedTuple):
    shape: Tuple[int, ...]
    dims: Tuple[int, ...]


def _canonicalize_fft_shape_and_dim_args(
    input: TensorLikeType, shape: Optional[ShapeType], dim: Optional[DimsType]
) -> _ShapeAndDims:
    """Convert the shape and dim arguments into a canonical form where neither is optional"""
    input_dim = input.ndim
    input_sizes = input.shape

    if dim is not None:
        if not isinstance(dim, Sequence):
            dim = (dim,)
        ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False)

        # Check dims are unique (on the canonicalized dims, so that e.g. -1
        # and ndim - 1 are recognized as the same axis)
        check(len(set(ret_dims)) == len(ret_dims), lambda: "FFT dims must be unique")

    if shape is not None:
        if not isinstance(shape, Sequence):
            shape = (shape,)

        # Has shape, might have dim
        check(
            dim is None or len(dim) == len(shape),
            lambda: "When given, dim and shape arguments must have the same length",
        )
        transform_ndim = len(shape)

        check(
            transform_ndim <= input_dim,
            lambda: f"Got shape with {transform_ndim} values but input tensor "
            f"only has {input_dim} dimensions.",
        )

        # If shape is given, dims defaults to the last len(shape) dimensions
        if dim is None:
            ret_dims = tuple(range(input_dim - transform_ndim, input_dim))

        # Translate any -1 values in shape to the default length
        ret_shape = tuple(
            s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims)
        )
    elif dim is None:
        # No shape, no dim
        ret_dims = tuple(range(input_dim))
        ret_shape = tuple(input_sizes)
    else:
        # No shape, has dim
        ret_shape = tuple(input_sizes[d] for d in ret_dims)

    for n in ret_shape:
        check(n > 0, lambda: f"Invalid number of data points ({n}) specified")

    return _ShapeAndDims(shape=ret_shape, dims=ret_dims)
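# Illustrative canonicalizations for a tensor of shape (4, 5, 6):
#
#   >>> x = torch.empty(4, 5, 6)
#   >>> _canonicalize_fft_shape_and_dim_args(x, shape=None, dim=None)
#   _ShapeAndDims(shape=(4, 5, 6), dims=(0, 1, 2))
#   >>> _canonicalize_fft_shape_and_dim_args(x, shape=(8, -1), dim=None)
#   _ShapeAndDims(shape=(8, 6), dims=(1, 2))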
def _prod(xs: Iterable[int]) -> int:
    """Compute the product of an iterable of ints"""
    prod = 1
    for x in xs:
        prod *= x
    return prod


def _fftn_c2c(
    function_name: str,
    input: TensorLikeType,
    shape: Tuple[int, ...],
    dim: Tuple[int, ...],
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for n-dimensional complex to complex FFTs (fftn or ifftn)"""
    check(
        input.dtype.is_complex,
        lambda: f"{function_name} expects a complex input tensor, "
        f"but got {input.dtype}",
    )
    x = _resize_fft_input(input, dim, shape)
    output = prims.fft_c2c(x, dim=dim, forward=forward)
    return _apply_norm(output, norm=norm, signal_numel=_prod(shape), forward=forward)


@register_decomposition(aten.fft_fftn)
@out_wrapper()
def fftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c("fftn", x, shape, dim, norm, forward=True)


@register_decomposition(aten.fft_ifftn)
@out_wrapper()
def ifftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c("ifftn", x, shape, dim, norm, forward=False)


@register_decomposition(aten.fft_rfftn)
@out_wrapper()
def rfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    check(
        not input.dtype.is_complex,
        lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    input = _maybe_promote_tensor_fft(input, require_complex=False)
    input = _resize_fft_input(input, dim, shape)
    out = prims.fft_r2c(input, dim=dim, onesided=True)
    return _apply_norm(out, norm=norm, signal_numel=_prod(shape), forward=True)


@register_decomposition(aten.fft_ihfftn)
@out_wrapper()
def ihfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    check(
        not input.dtype.is_complex,
        lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
    input = _maybe_promote_tensor_fft(input, require_complex=False)
    input = _resize_fft_input(input, dim, shape)

    tmp = prims.fft_r2c(input, dim=dim[-1:], onesided=True)

    if len(dim) == 1:
        tmp = _apply_norm(tmp, norm=norm, signal_numel=shape[0], forward=False)
        return prims.conj(tmp)

    tmp = prims.conj_physical(tmp)
    tmp = prims.fft_c2c(tmp, dim=dim[:-1], forward=False)
    return _apply_norm(tmp, norm=norm, signal_numel=_prod(shape), forward=False)
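# ihfftn decomposes into a onesided r2c over the last transform dim followed
# by an inverse c2c over the remaining dims; the conjugations bracket the c2c
# step because ifft(x) == conj(fft(conj(x))) / n. Illustrative equivalence
# (mirroring the torch.fft.ihfftn docs):
#
#   >>> x = torch.randn(4, 6)
#   >>> torch.allclose(torch.fft.ihfftn(x), torch.fft.ifftn(x)[..., :4])
#   True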
class _CanonicalizeC2rReturn(NamedTuple):
    shape: Tuple[int, ...]
    dim: Tuple[int, ...]
    last_dim_size: int


def _canonicalize_fft_c2r_shape_and_dim_args(
    fname: str,
    input: TensorLikeType,
    s: Optional[ShapeType],
    dim: Optional[DimsType],
) -> _CanonicalizeC2rReturn:
    """Canonicalize shape and dim arguments for n-dimensional c2r transforms,
    and calculate last_dim_size, the output size along dim[-1]"""
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")

    if s is None or s[-1] == -1:
        last_dim_size = 2 * (input.shape[dim[-1]] - 1)
    else:
        last_dim_size = shape[-1]

    check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    shape_list = list(shape)
    shape_list[-1] = last_dim_size // 2 + 1
    return _CanonicalizeC2rReturn(
        shape=tuple(shape_list), dim=dim, last_dim_size=last_dim_size
    )
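# For a complex input whose last transform dim has length m, the default real
# output length is 2 * (m - 1), while the *input* is resized to
# last_dim_size // 2 + 1 frequency bins. Illustrative sketch:
#
#   >>> x = torch.empty(4, 5, dtype=torch.complex64)
#   >>> _canonicalize_fft_c2r_shape_and_dim_args("irfftn", x, s=None, dim=None)
#   _CanonicalizeC2rReturn(shape=(4, 5), dim=(0, 1), last_dim_size=8)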
@register_decomposition(aten.fft_irfftn)
@out_wrapper()
def irfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
        "irfftn", input, s, dim
    )
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    input = _resize_fft_input(input, dim, shape)
    out = prims.fft_c2r(input, dim=dim, last_dim_size=last_dim_size)
    return _apply_norm(out, norm, _prod(out.shape[d] for d in dim), forward=False)


@register_decomposition(aten.fft_hfftn)
@out_wrapper()
def hfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
        "hfftn", input, s, dim
    )
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    input = _resize_fft_input(input, dim, shape)

    tmp = prims.fft_c2c(input, dim=dim[:-1], forward=True) if len(dim) > 1 else input
    tmp = _apply_norm(tmp, norm, _prod(shape[:-1]), forward=True)
    tmp = prims.conj_physical(tmp)
    out = prims.fft_c2r(tmp, dim=dim[-1:], last_dim_size=last_dim_size)
    return _apply_norm(out, norm, last_dim_size, forward=True)
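# hfftn splits into a forward c2c over all but the last dim and a c2r over
# the last dim; normalization is applied in two matching pieces,
# _prod(shape[:-1]) and last_dim_size, which together cover the full signal
# numel. Illustrative shape check:
#
#   >>> T = torch.fft.ifftn(torch.randn(4, 6))
#   >>> torch.fft.hfftn(T[..., :4], s=(4, 6)).shape
#   torch.Size([4, 6])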
@register_decomposition(aten.fft_fft2)
@out_wrapper()
def fft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.fftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_ifft2)
@out_wrapper()
def ifft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.ifftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_rfft2)
@out_wrapper()
def rfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.rfftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_irfft2)
@out_wrapper()
def irfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.irfftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_hfft2)
@out_wrapper()
def hfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.hfftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_ihfft2)
@out_wrapper()
def ihfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm)


def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> List[int]:
    """Convert Optional[DimsType] to a simple list, defaulting to all dimensions"""
    if dim is None:
        return list(range(x.ndim))
    elif not isinstance(dim, Sequence):
        return [dim]
    else:
        return list(dim)


@register_decomposition(aten.fft_fftshift)
def fftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    dims = _default_alldims(dim, input)
    shift = [input.shape[d] // 2 for d in dims]
    return torch.roll(input, shift, dims)


@register_decomposition(aten.fft_ifftshift)
def ifftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    dims = _default_alldims(dim, input)
    shift = [(input.shape[d] + 1) // 2 for d in dims]
    return torch.roll(input, shift, dims)
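# fftshift rolls each dim by n // 2 to move the zero-frequency term to the
# center; ifftshift rolls by (n + 1) // 2, which undoes fftshift for both
# even and odd lengths. Illustrative sketch for an odd length:
#
#   >>> f = torch.fft.fftfreq(5)
#   >>> torch.fft.fftshift(f)
#   tensor([-0.4000, -0.2000,  0.0000,  0.2000,  0.4000])
#   >>> torch.equal(torch.fft.ifftshift(torch.fft.fftshift(f)), f)
#   True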