Exceptions.h 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. #pragma once
  2. #include <cublas_v2.h>
  3. #include <cusparse.h>
  4. #include <c10/macros/Export.h>
  5. #ifdef CUDART_VERSION
  6. #include <cusolver_common.h>
  7. #endif
  8. #include <ATen/Context.h>
  9. #include <c10/util/Exception.h>
  10. #include <c10/cuda/CUDAException.h>
  11. namespace c10 {
  12. class CuDNNError : public c10::Error {
  13. using Error::Error;
  14. };
  15. } // namespace c10
  16. #define AT_CUDNN_CHECK_WITH_SHAPES(EXPR, ...) AT_CUDNN_CHECK(EXPR, "\n", ##__VA_ARGS__)
  17. // See Note [CHECK macro]
  18. #define AT_CUDNN_CHECK(EXPR, ...) \
  19. do { \
  20. cudnnStatus_t status = EXPR; \
  21. if (status != CUDNN_STATUS_SUCCESS) { \
  22. if (status == CUDNN_STATUS_NOT_SUPPORTED) { \
  23. TORCH_CHECK_WITH(CuDNNError, false, \
  24. "cuDNN error: ", \
  25. cudnnGetErrorString(status), \
  26. ". This error may appear if you passed in a non-contiguous input.", ##__VA_ARGS__); \
  27. } else { \
  28. TORCH_CHECK_WITH(CuDNNError, false, \
  29. "cuDNN error: ", cudnnGetErrorString(status), ##__VA_ARGS__); \
  30. } \
  31. } \
  32. } while (0)
  33. namespace at { namespace cuda { namespace blas {
  34. C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error);
  35. }}} // namespace at::cuda::blas
  36. #define TORCH_CUDABLAS_CHECK(EXPR) \
  37. do { \
  38. cublasStatus_t __err = EXPR; \
  39. TORCH_CHECK(__err == CUBLAS_STATUS_SUCCESS, \
  40. "CUDA error: ", \
  41. at::cuda::blas::_cublasGetErrorEnum(__err), \
  42. " when calling `" #EXPR "`"); \
  43. } while (0)
  44. const char *cusparseGetErrorString(cusparseStatus_t status);
  45. #define TORCH_CUDASPARSE_CHECK(EXPR) \
  46. do { \
  47. cusparseStatus_t __err = EXPR; \
  48. TORCH_CHECK(__err == CUSPARSE_STATUS_SUCCESS, \
  49. "CUDA error: ", \
  50. cusparseGetErrorString(__err), \
  51. " when calling `" #EXPR "`"); \
  52. } while (0)
  53. // cusolver related headers are only supported on cuda now
  54. #ifdef CUDART_VERSION
  55. namespace at { namespace cuda { namespace solver {
  56. C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);
  57. }}} // namespace at::cuda::solver
  58. // When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan.
  59. // When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue.
  60. #define TORCH_CUSOLVER_CHECK(EXPR) \
  61. do { \
  62. cusolverStatus_t __err = EXPR; \
  63. if ((CUDA_VERSION < 11500 && \
  64. __err == CUSOLVER_STATUS_EXECUTION_FAILED) || \
  65. (CUDA_VERSION >= 11500 && \
  66. __err == CUSOLVER_STATUS_INVALID_VALUE)) { \
  67. TORCH_CHECK_LINALG( \
  68. false, \
  69. "cusolver error: ", \
  70. at::cuda::solver::cusolverGetErrorMessage(__err), \
  71. ", when calling `" #EXPR "`", \
  72. ". This error may appear if the input matrix contains NaN."); \
  73. } else { \
  74. TORCH_CHECK( \
  75. __err == CUSOLVER_STATUS_SUCCESS, \
  76. "cusolver error: ", \
  77. at::cuda::solver::cusolverGetErrorMessage(__err), \
  78. ", when calling `" #EXPR "`"); \
  79. } \
  80. } while (0)
  81. #else
// Fallback used when CUDART_VERSION is not defined (the cuSOLVER checks above
// are only compiled for CUDA builds): the macro expands to EXPR itself, so EXPR
// is still evaluated but its returned status is NOT checked on this path.
#define TORCH_CUSOLVER_CHECK(EXPR) EXPR
  83. #endif
// ATen spelling of c10's CUDA status check: forwards EXPR unchanged to
// C10_CUDA_CHECK (presumably from the included c10/cuda/CUDAException.h; see
// that header for the exact checking and error-reporting behavior).
#define AT_CUDA_CHECK(EXPR) C10_CUDA_CHECK(EXPR)
  85. // For CUDA Driver API
  86. //
  87. // This is here instead of in c10 because NVRTC is loaded dynamically via a stub
  88. // in ATen, and we need to use its nvrtcGetErrorString.
  89. // See NOTE [ USE OF NVRTC AND DRIVER API ].
  90. #if !defined(USE_ROCM)
  91. #define AT_CUDA_DRIVER_CHECK(EXPR) \
  92. do { \
  93. CUresult __err = EXPR; \
  94. if (__err != CUDA_SUCCESS) { \
  95. const char* err_str; \
  96. CUresult get_error_str_err C10_UNUSED = at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str); \
  97. if (get_error_str_err != CUDA_SUCCESS) { \
  98. AT_ERROR("CUDA driver error: unknown error"); \
  99. } else { \
  100. AT_ERROR("CUDA driver error: ", err_str); \
  101. } \
  102. } \
  103. } while (0)
  104. #else
  105. #define AT_CUDA_DRIVER_CHECK(EXPR) \
  106. do { \
  107. CUresult __err = EXPR; \
  108. if (__err != CUDA_SUCCESS) { \
  109. AT_ERROR("CUDA driver error: ", static_cast<int>(__err)); \
  110. } \
  111. } while (0)
  112. #endif
  113. // For CUDA NVRTC
  114. //
  115. // Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE,
  116. // incorrectly produces the error string "NVRTC unknown error."
  117. // The following maps it correctly.
  118. //
  119. // This is here instead of in c10 because NVRTC is loaded dynamically via a stub
  120. // in ATen, and we need to use its nvrtcGetErrorString.
  121. // See NOTE [ USE OF NVRTC AND DRIVER API ].
  122. #define AT_CUDA_NVRTC_CHECK(EXPR) \
  123. do { \
  124. nvrtcResult __err = EXPR; \
  125. if (__err != NVRTC_SUCCESS) { \
  126. if (static_cast<int>(__err) != 7) { \
  127. AT_ERROR("CUDA NVRTC error: ", at::globalContext().getNVRTC().nvrtcGetErrorString(__err)); \
  128. } else { \
  129. AT_ERROR("CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE"); \
  130. } \
  131. } \
  132. } while (0)