CuFFTUtils.h 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. #pragma once
  2. #include <ATen/Config.h>
  3. #include <string>
  4. #include <stdexcept>
  5. #include <sstream>
  6. #include <cufft.h>
  7. #include <cufftXt.h>
  8. namespace at { namespace native {
  9. // This means that max dim is 3 + 2 = 5 with batch dimension and possible
  10. // complex dimension
  11. constexpr int max_rank = 3;
  12. static inline std::string _cudaGetErrorEnum(cufftResult error)
  13. {
  14. switch (error)
  15. {
  16. case CUFFT_SUCCESS:
  17. return "CUFFT_SUCCESS";
  18. case CUFFT_INVALID_PLAN:
  19. return "CUFFT_INVALID_PLAN";
  20. case CUFFT_ALLOC_FAILED:
  21. return "CUFFT_ALLOC_FAILED";
  22. case CUFFT_INVALID_TYPE:
  23. return "CUFFT_INVALID_TYPE";
  24. case CUFFT_INVALID_VALUE:
  25. return "CUFFT_INVALID_VALUE";
  26. case CUFFT_INTERNAL_ERROR:
  27. return "CUFFT_INTERNAL_ERROR";
  28. case CUFFT_EXEC_FAILED:
  29. return "CUFFT_EXEC_FAILED";
  30. case CUFFT_SETUP_FAILED:
  31. return "CUFFT_SETUP_FAILED";
  32. case CUFFT_INVALID_SIZE:
  33. return "CUFFT_INVALID_SIZE";
  34. case CUFFT_UNALIGNED_DATA:
  35. return "CUFFT_UNALIGNED_DATA";
  36. case CUFFT_INCOMPLETE_PARAMETER_LIST:
  37. return "CUFFT_INCOMPLETE_PARAMETER_LIST";
  38. case CUFFT_INVALID_DEVICE:
  39. return "CUFFT_INVALID_DEVICE";
  40. case CUFFT_PARSE_ERROR:
  41. return "CUFFT_PARSE_ERROR";
  42. case CUFFT_NO_WORKSPACE:
  43. return "CUFFT_NO_WORKSPACE";
  44. case CUFFT_NOT_IMPLEMENTED:
  45. return "CUFFT_NOT_IMPLEMENTED";
  46. #if !defined(USE_ROCM)
  47. case CUFFT_LICENSE_ERROR:
  48. return "CUFFT_LICENSE_ERROR";
  49. #endif
  50. case CUFFT_NOT_SUPPORTED:
  51. return "CUFFT_NOT_SUPPORTED";
  52. default:
  53. std::ostringstream ss;
  54. ss << "unknown error " << error;
  55. return ss.str();
  56. }
  57. }
  58. static inline void CUFFT_CHECK(cufftResult error)
  59. {
  60. if (error != CUFFT_SUCCESS) {
  61. std::ostringstream ss;
  62. ss << "cuFFT error: " << _cudaGetErrorEnum(error);
  63. AT_ERROR(ss.str());
  64. }
  65. }
  66. }} // at::native