import unittest
from functools import partial
from typing import List

import numpy as np

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import SM53OrLater
from torch.testing._internal.common_device_type import precisionOverride
from torch.testing._internal.common_dtype import (
    all_types_and,
    all_types_and_complex_and,
)
from torch.testing._internal.common_utils import TEST_SCIPY, TEST_WITH_ROCM
from torch.testing._internal.opinfo.core import (
    DecorateInfo,
    ErrorInput,
    OpInfo,
    SampleInput,
    SpectralFuncInfo,
    SpectralFuncType,
)
from torch.testing._internal.opinfo.refs import (
    _find_referenced_opinfo,
    _inherit_constructor_args,
    PythonRefInfo,
)

has_scipy_fft = False
if TEST_SCIPY:
    try:
        import scipy.fft

        has_scipy_fft = True
    except ModuleNotFoundError:
        pass


class SpectralFuncPythonRefInfo(SpectralFuncInfo):
    """
    An OpInfo for a Python reference of a spectral function.
    """

    def __init__(
        self,
        name,  # the string name of the callable Python reference
        *,
        op=None,  # the function variant of the operation, populated as torch.<name> if None
        torch_opinfo_name,  # the string name of the corresponding torch opinfo
        torch_opinfo_variant="",
        supports_nvfuser=True,
        **kwargs,
    ):  # additional kwargs override kwargs inherited from the torch opinfo
        self.torch_opinfo_name = torch_opinfo_name
        self.torch_opinfo = _find_referenced_opinfo(
            torch_opinfo_name, torch_opinfo_variant, op_db=op_db
        )
        self.supports_nvfuser = supports_nvfuser
        assert isinstance(self.torch_opinfo, SpectralFuncInfo)

        inherited = self.torch_opinfo._original_spectral_func_args
        ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)

        super().__init__(**ukwargs)
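

# For example, the python_ref_db entry
#
#     SpectralFuncPythonRefInfo("_refs.fft.fft", torch_opinfo_name="fft.fft")
#
# looks up the "fft.fft" SpectralFuncInfo in op_db and inherits its
# constructor arguments; any kwargs passed explicitly take precedence.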


def error_inputs_fft(op_info, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
    # Zero-dimensional tensor has no dimension to take FFT of
    yield ErrorInput(
        SampleInput(make_arg()),
        error_type=IndexError,
        error_regex="Dimension specified as -1 but tensor has no dimensions",
    )
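

# Illustrative repro of the error above (assumes a recent torch build):
#
#   >>> torch.fft.fft(torch.tensor(1.0))
#   IndexError: Dimension specified as -1 but tensor has no dimensions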


def error_inputs_fftn(op_info, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
    # Specifying a dimension on a zero-dimensional tensor
    yield ErrorInput(
        SampleInput(make_arg(), dim=(0,)),
        error_type=IndexError,
        error_regex="Dimension specified as 0 but tensor has no dimensions",
    )
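

# Illustrative repro (assumes a recent torch build):
#
#   >>> torch.fft.fftn(torch.tensor(1.0), dim=(0,))
#   IndexError: Dimension specified as 0 but tensor has no dimensions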


def sample_inputs_fftshift(op_info, device, dtype, requires_grad, **kwargs):
    def mt(shape, **kwargs):
        return make_tensor(
            shape, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs
        )

    yield SampleInput(mt((9, 10)))
    yield SampleInput(mt((50,)), kwargs=dict(dim=0))
    yield SampleInput(mt((5, 11)), kwargs=dict(dim=(1,)))
    yield SampleInput(mt((5, 6)), kwargs=dict(dim=(0, 1)))
    yield SampleInput(mt((5, 6, 2)), kwargs=dict(dim=(0, 2)))
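

# For reference, fftshift rolls the zero-frequency term to the center of the
# chosen dimensions, e.g. (illustrative):
#
#   >>> torch.fft.fftshift(torch.arange(4))
#   tensor([2, 3, 0, 1])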


# Operator database
op_db: List[OpInfo] = [
    SpectralFuncInfo(
        "fft.fft",
        aten_name="fft_fft",
        decomp_aten_name="_fft_c2c",
        ref=np.fft.fft,
        ndimensional=SpectralFuncType.OneD,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        error_inputs_func=error_inputs_fft,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
    ),
    SpectralFuncInfo(
        "fft.fft2",
        aten_name="fft_fft2",
        ref=np.fft.fft2,
        decomp_aten_name="_fft_c2c",
        ndimensional=SpectralFuncType.TwoD,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        decorators=[precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})],
    ),
    SpectralFuncInfo(
        "fft.fftn",
        aten_name="fft_fftn",
        decomp_aten_name="_fft_c2c",
        ref=np.fft.fftn,
        ndimensional=SpectralFuncType.ND,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        decorators=[precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})],
    ),
    SpectralFuncInfo(
        "fft.hfft",
        aten_name="fft_hfft",
        decomp_aten_name="_fft_c2r",
        ref=np.fft.hfft,
        ndimensional=SpectralFuncType.OneD,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        error_inputs_func=error_inputs_fft,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        check_batched_gradgrad=False,
        skips=(
            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestSchemaCheckModeOpInfo",
                "test_schema_correctness",
                dtypes=(torch.complex64, torch.complex128),
            ),
        ),
    ),
    SpectralFuncInfo(
        "fft.hfft2",
        aten_name="fft_hfft2",
        decomp_aten_name="_fft_c2r",
        ref=scipy.fft.hfft2 if has_scipy_fft else None,
        ndimensional=SpectralFuncType.TwoD,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_gradgrad=False,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
        skips=(
            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestSchemaCheckModeOpInfo",
                "test_schema_correctness",
            ),
        ),
    ),
    SpectralFuncInfo(
        "fft.hfftn",
        aten_name="fft_hfftn",
        decomp_aten_name="_fft_c2r",
        ref=scipy.fft.hfftn if has_scipy_fft else None,
        ndimensional=SpectralFuncType.ND,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_gradgrad=False,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
                "TestFFT",
                "test_reference_nd",
            ),
        ],
        skips=(
            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestSchemaCheckModeOpInfo",
                "test_schema_correctness",
            ),
        ),
    ),
    SpectralFuncInfo(
        "fft.rfft",
        aten_name="fft_rfft",
        decomp_aten_name="_fft_r2c",
        ref=np.fft.rfft,
        ndimensional=SpectralFuncType.OneD,
        dtypes=all_types_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,))
        ),
        error_inputs_func=error_inputs_fft,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_grad=False,
        skips=(),
        check_batched_gradgrad=False,
    ),
    SpectralFuncInfo(
        "fft.rfft2",
        aten_name="fft_rfft2",
        decomp_aten_name="_fft_r2c",
        ref=np.fft.rfft2,
        ndimensional=SpectralFuncType.TwoD,
        dtypes=all_types_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,))
        ),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        decorators=[
            precisionOverride({torch.float: 1e-4}),
        ],
    ),
    SpectralFuncInfo(
        "fft.rfftn",
        aten_name="fft_rfftn",
        decomp_aten_name="_fft_r2c",
        ref=np.fft.rfftn,
        ndimensional=SpectralFuncType.ND,
        dtypes=all_types_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,))
        ),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        decorators=[
            precisionOverride({torch.float: 1e-4}),
        ],
    ),
    SpectralFuncInfo(
        "fft.ifft",
        aten_name="fft_ifft",
        decomp_aten_name="_fft_c2c",
        ref=np.fft.ifft,
        ndimensional=SpectralFuncType.OneD,
        error_inputs_func=error_inputs_fft,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
    ),
    SpectralFuncInfo(
        "fft.ifft2",
        aten_name="fft_ifft2",
        decomp_aten_name="_fft_c2c",
        ref=np.fft.ifft2,
        ndimensional=SpectralFuncType.TwoD,
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    SpectralFuncInfo(
        "fft.ifftn",
        aten_name="fft_ifftn",
        decomp_aten_name="_fft_c2c",
        ref=np.fft.ifftn,
        ndimensional=SpectralFuncType.ND,
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    SpectralFuncInfo(
        "fft.ihfft",
        aten_name="fft_ihfft",
        decomp_aten_name="_fft_r2c",
        ref=np.fft.ihfft,
        ndimensional=SpectralFuncType.OneD,
        error_inputs_func=error_inputs_fft,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,))
        ),
        skips=(),
        check_batched_grad=False,
    ),
    SpectralFuncInfo(
        "fft.ihfft2",
        aten_name="fft_ihfft2",
        decomp_aten_name="_fft_r2c",
        ref=scipy.fft.ihfftn if has_scipy_fft else None,
        ndimensional=SpectralFuncType.TwoD,
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,))
        ),
        check_batched_grad=False,
        check_batched_gradgrad=False,
        decorators=(
            # The values for attribute 'shape' do not match: torch.Size([5, 6, 5]) != torch.Size([5, 6, 6]).
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"),
            DecorateInfo(
                precisionOverride({torch.float: 2e-4}), "TestFFT", "test_reference_nd"
            ),
            # Mismatched elements!
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out"),
        ),
    ),
    SpectralFuncInfo(
        "fft.ihfftn",
        aten_name="fft_ihfftn",
        decomp_aten_name="_fft_r2c",
        ref=scipy.fft.ihfftn if has_scipy_fft else None,
        ndimensional=SpectralFuncType.ND,
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,))
        ),
        check_batched_grad=False,
        check_batched_gradgrad=False,
        decorators=[
            # The values for attribute 'shape' do not match: torch.Size([5, 6, 5]) != torch.Size([5, 6, 6]).
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"),
            # Mismatched elements!
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out"),
            DecorateInfo(
                precisionOverride({torch.float: 2e-4}), "TestFFT", "test_reference_nd"
            ),
        ],
    ),
    SpectralFuncInfo(
        "fft.irfft",
        aten_name="fft_irfft",
        decomp_aten_name="_fft_c2r",
        ref=np.fft.irfft,
        ndimensional=SpectralFuncType.OneD,
        error_inputs_func=error_inputs_fft,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        check_batched_gradgrad=False,
    ),
    SpectralFuncInfo(
        "fft.irfft2",
        aten_name="fft_irfft2",
        decomp_aten_name="_fft_c2r",
        ref=np.fft.irfft2,
        ndimensional=SpectralFuncType.TwoD,
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        check_batched_gradgrad=False,
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    SpectralFuncInfo(
        "fft.irfftn",
        aten_name="fft_irfftn",
        decomp_aten_name="_fft_c2r",
        ref=np.fft.irfftn,
        ndimensional=SpectralFuncType.ND,
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and_complex_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(
                ()
                if (TEST_WITH_ROCM or not SM53OrLater)
                else (torch.half, torch.complex32)
            ),
        ),
        check_batched_gradgrad=False,
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    OpInfo(
        "fft.fftshift",
        dtypes=all_types_and_complex_and(
            torch.bool, torch.bfloat16, torch.half, torch.chalf
        ),
        sample_inputs_func=sample_inputs_fftshift,
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
    ),
    OpInfo(
        "fft.ifftshift",
        dtypes=all_types_and_complex_and(
            torch.bool, torch.bfloat16, torch.half, torch.chalf
        ),
        sample_inputs_func=sample_inputs_fftshift,
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
    ),
]
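

# Illustrative consumer (a hedged sketch, not code from this module): the
# device-type test templates parametrize over op_db with the `ops` decorator
# from torch.testing._internal.common_device_type, roughly:
#
#     from torch.testing._internal.common_device_type import ops
#
#     class TestFFT(TestCase):
#         @ops(op_db)
#         def test_something(self, device, dtype, op):
#             for sample in op.sample_inputs(device, dtype):
#                 op(sample.input, *sample.args, **sample.kwargs)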


python_ref_db: List[OpInfo] = [
    SpectralFuncPythonRefInfo(
        "_refs.fft.fft",
        torch_opinfo_name="fft.fft",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.ifft",
        torch_opinfo_name="fft.ifft",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.rfft",
        torch_opinfo_name="fft.rfft",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.irfft",
        torch_opinfo_name="fft.irfft",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.hfft",
        torch_opinfo_name="fft.hfft",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.ihfft",
        torch_opinfo_name="fft.ihfft",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.fftn",
        torch_opinfo_name="fft.fftn",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.ifftn",
        torch_opinfo_name="fft.ifftn",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.rfftn",
        torch_opinfo_name="fft.rfftn",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.irfftn",
        torch_opinfo_name="fft.irfftn",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.hfftn",
        torch_opinfo_name="fft.hfftn",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.ihfftn",
        torch_opinfo_name="fft.ihfftn",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.fft2",
        torch_opinfo_name="fft.fft2",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.ifft2",
        torch_opinfo_name="fft.ifft2",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.rfft2",
        torch_opinfo_name="fft.rfft2",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.irfft2",
        torch_opinfo_name="fft.irfft2",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.hfft2",
        torch_opinfo_name="fft.hfft2",
        supports_nvfuser=False,
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.ihfft2",
        torch_opinfo_name="fft.ihfft2",
        supports_nvfuser=False,
    ),
    PythonRefInfo(
        "_refs.fft.fftshift",
        op_db=op_db,
        torch_opinfo_name="fft.fftshift",
        supports_nvfuser=False,
    ),
    PythonRefInfo(
        "_refs.fft.ifftshift",
        op_db=op_db,
        torch_opinfo_name="fft.ifftshift",
        supports_nvfuser=False,
    ),
]
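

# Minimal sanity check (an illustrative sketch, guarded so importing this
# module stays side-effect free): every Python-reference entry should resolve
# back to the torch OpInfo it mirrors.
if __name__ == "__main__":
    for ref_info in python_ref_db:
        assert ref_info.torch_opinfo.name == ref_info.torch_opinfo_name
    print(f"{len(op_db)} OpInfos, {len(python_ref_db)} Python-ref OpInfos")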