CUDASparse.h 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. #pragma once
  2. #include <ATen/cuda/CUDAContext.h>
  // cuSPARSE Generic API availability switch.
  // - The Generic API was added in CUDA 10.1; it is gated here on
  //   CUSPARSE_VERSION >= 10300.
  // - Windows support was only added in CUDA 11.0; it is gated here on
  //   CUSPARSE_VERSION >= 11000.
  // The first clause excludes _WIN32 on purpose. Without that exclusion,
  // `>= 10300` already accepts every Windows version, so the Windows-specific
  // clause would be dead and the 11.0 requirement would never be applied.
  // Evaluates to 0 when CUDART_VERSION or CUSPARSE_VERSION is not defined.
  #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \
      ((CUSPARSE_VERSION >= 10300 && !defined(_WIN32)) ||     \
       (CUSPARSE_VERSION >= 11000 && defined(_WIN32)))
  #define AT_USE_CUSPARSE_GENERIC_API() 1
  #else
  #define AT_USE_CUSPARSE_GENERIC_API() 0
  #endif
  // Generic API descriptor pointers became const in CUDA 12.0, i.e. from
  // CUSPARSE_VERSION 12000 onward. This switch is 1 for older cuSPARSE
  // versions (non-const descriptors). It is 0 otherwise, including builds
  // where CUDART_VERSION or CUSPARSE_VERSION is not defined.
  #if !defined(CUDART_VERSION) || !defined(CUSPARSE_VERSION)
  #define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 0
  #elif CUSPARSE_VERSION < 12000
  #define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 1
  #else
  #define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 0
  #endif
  // 1 when CUSPARSE_VERSION >= 12000 (CUDA 12.0+), where Generic API
  // descriptor pointers are const. 0 otherwise, including builds where
  // CUDART_VERSION or CUSPARSE_VERSION is not defined.
  #if !defined(CUDART_VERSION) || !defined(CUSPARSE_VERSION) || \
      (CUSPARSE_VERSION < 12000)
  #define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 0
  #else
  #define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 1
  #endif
  // hipSPARSE Generic API as of ROCm 5.2 (ROCM_VERSION >= 50200).
  // Always 0 unless USE_ROCM is defined.
  #if !defined(USE_ROCM) || ROCM_VERSION < 50200
  #define AT_USE_HIPSPARSE_GENERIC_52_API() 0
  #else
  #define AT_USE_HIPSPARSE_GENERIC_52_API() 1
  #endif
  // hipSPARSE Generic API as of ROCm 5.1 (ROCM_VERSION >= 50100).
  // Non-ROCm builds (USE_ROCM undefined) always get 0.
  #if !defined(USE_ROCM)
  #define AT_USE_HIPSPARSE_GENERIC_API() 0
  #elif ROCM_VERSION >= 50100
  #define AT_USE_HIPSPARSE_GENERIC_API() 1
  #else
  #define AT_USE_HIPSPARSE_GENERIC_API() 0
  #endif
  // Generic API SpSV (sparse triangular solve, vector RHS) arrived in
  // CUDA 11.3.0; it is gated here on CUSPARSE_VERSION >= 11500. 0 when
  // CUDART_VERSION or CUSPARSE_VERSION is not defined.
  #if !defined(CUDART_VERSION) || !defined(CUSPARSE_VERSION) || \
      (CUSPARSE_VERSION < 11500)
  #define AT_USE_CUSPARSE_GENERIC_SPSV() 0
  #else
  #define AT_USE_CUSPARSE_GENERIC_SPSV() 1
  #endif
  // Generic API SpSM arrived in CUDA 11.3.1; it is gated here on
  // CUSPARSE_VERSION >= 11600. Builds without CUDART_VERSION or
  // CUSPARSE_VERSION defined get 0.
  #if !defined(CUDART_VERSION) || !defined(CUSPARSE_VERSION)
  #define AT_USE_CUSPARSE_GENERIC_SPSM() 0
  #elif CUSPARSE_VERSION >= 11600
  #define AT_USE_CUSPARSE_GENERIC_SPSM() 1
  #else
  #define AT_USE_CUSPARSE_GENERIC_SPSM() 0
  #endif
  // Generic API SDDMM arrived in CUDA 11.2.1, which ships cuSPARSE 11400;
  // it is gated on CUSPARSE_VERSION >= 11400. 0 when CUDART_VERSION or
  // CUSPARSE_VERSION is not defined.
  #if !defined(CUDART_VERSION) || !defined(CUSPARSE_VERSION) || \
      (CUSPARSE_VERSION < 11400)
  #define AT_USE_CUSPARSE_GENERIC_SDDMM() 0
  #else
  #define AT_USE_CUSPARSE_GENERIC_SDDMM() 1
  #endif
  // BSR triangular solve switch. It is 1 on any build where CUDART_VERSION
  // is defined. On ROCm builds it needs hipSPARSE 1.11.2, shipped with
  // ROCm 4.5.0 (ROCM_VERSION >= 40500). Otherwise it is 0.
  #if defined(CUDART_VERSION)
  #define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 1
  #elif defined(USE_ROCM) && ROCM_VERSION >= 40500
  #define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 1
  #else
  #define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 0
  #endif