// CUDASparseDescriptors.h
  1. #pragma once
  2. #include <ATen/Tensor.h>
  3. #include <ATen/cuda/CUDAContext.h>
  4. #include <ATen/cuda/CUDASparse.h>
  5. #include <c10/core/ScalarType.h>
  6. #if defined(USE_ROCM)
  7. #include <type_traits>
  8. #endif
  9. namespace at {
  10. namespace cuda {
  11. namespace sparse {
  12. template <typename T, cusparseStatus_t (*destructor)(T*)>
  13. struct CuSparseDescriptorDeleter {
  14. void operator()(T* x) {
  15. if (x != nullptr) {
  16. TORCH_CUDASPARSE_CHECK(destructor(x));
  17. }
  18. }
  19. };
  20. template <typename T, cusparseStatus_t (*destructor)(T*)>
  21. class CuSparseDescriptor {
  22. public:
  23. T* descriptor() const {
  24. return descriptor_.get();
  25. }
  26. T* descriptor() {
  27. return descriptor_.get();
  28. }
  29. protected:
  30. std::unique_ptr<T, CuSparseDescriptorDeleter<T, destructor>> descriptor_;
  31. };
  32. #if AT_USE_CUSPARSE_CONST_DESCRIPTORS()
  33. template <typename T, cusparseStatus_t (*destructor)(const T*)>
  34. struct ConstCuSparseDescriptorDeleter {
  35. void operator()(T* x) {
  36. if (x != nullptr) {
  37. TORCH_CUDASPARSE_CHECK(destructor(x));
  38. }
  39. }
  40. };
  41. template <typename T, cusparseStatus_t (*destructor)(const T*)>
  42. class ConstCuSparseDescriptor {
  43. public:
  44. T* descriptor() const {
  45. return descriptor_.get();
  46. }
  47. T* descriptor() {
  48. return descriptor_.get();
  49. }
  50. protected:
  51. std::unique_ptr<T, ConstCuSparseDescriptorDeleter<T, destructor>> descriptor_;
  52. };
  53. #endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS
  54. #if defined(USE_ROCM)
  55. // hipSPARSE doesn't define this
  56. using cusparseMatDescr = std::remove_pointer<cusparseMatDescr_t>::type;
  57. using cusparseDnMatDescr = std::remove_pointer<cusparseDnMatDescr_t>::type;
  58. using cusparseDnVecDescr = std::remove_pointer<cusparseDnVecDescr_t>::type;
  59. using cusparseSpMatDescr = std::remove_pointer<cusparseSpMatDescr_t>::type;
  60. using cusparseSpMatDescr = std::remove_pointer<cusparseSpMatDescr_t>::type;
  61. using cusparseSpGEMMDescr = std::remove_pointer<cusparseSpGEMMDescr_t>::type;
  62. #if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
  63. using bsrsv2Info = std::remove_pointer<bsrsv2Info_t>::type;
  64. using bsrsm2Info = std::remove_pointer<bsrsm2Info_t>::type;
  65. #endif
  66. #endif
  67. class TORCH_CUDA_CPP_API CuSparseMatDescriptor
  68. : public CuSparseDescriptor<cusparseMatDescr, &cusparseDestroyMatDescr> {
  69. public:
  70. CuSparseMatDescriptor() {
  71. cusparseMatDescr_t raw_descriptor;
  72. TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor));
  73. descriptor_.reset(raw_descriptor);
  74. }
  75. CuSparseMatDescriptor(bool upper, bool unit) {
  76. cusparseFillMode_t fill_mode =
  77. upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER;
  78. cusparseDiagType_t diag_type =
  79. unit ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT;
  80. cusparseMatDescr_t raw_descriptor;
  81. TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor));
  82. TORCH_CUDASPARSE_CHECK(cusparseSetMatFillMode(raw_descriptor, fill_mode));
  83. TORCH_CUDASPARSE_CHECK(cusparseSetMatDiagType(raw_descriptor, diag_type));
  84. descriptor_.reset(raw_descriptor);
  85. }
  86. };
  87. #if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
  88. class TORCH_CUDA_CPP_API CuSparseBsrsv2Info
  89. : public CuSparseDescriptor<bsrsv2Info, &cusparseDestroyBsrsv2Info> {
  90. public:
  91. CuSparseBsrsv2Info() {
  92. bsrsv2Info_t raw_descriptor;
  93. TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsv2Info(&raw_descriptor));
  94. descriptor_.reset(raw_descriptor);
  95. }
  96. };
  97. class TORCH_CUDA_CPP_API CuSparseBsrsm2Info
  98. : public CuSparseDescriptor<bsrsm2Info, &cusparseDestroyBsrsm2Info> {
  99. public:
  100. CuSparseBsrsm2Info() {
  101. bsrsm2Info_t raw_descriptor;
  102. TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsm2Info(&raw_descriptor));
  103. descriptor_.reset(raw_descriptor);
  104. }
  105. };
  106. #endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
  107. #if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
  108. cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type);
  109. #if AT_USE_HIPSPARSE_GENERIC_52_API() || \
  110. (AT_USE_CUSPARSE_GENERIC_API() && AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS())
  111. class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
  112. : public CuSparseDescriptor<cusparseDnMatDescr, &cusparseDestroyDnMat> {
  113. public:
  114. explicit CuSparseDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1);
  115. };
  116. class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor
  117. : public CuSparseDescriptor<cusparseDnVecDescr, &cusparseDestroyDnVec> {
  118. public:
  119. explicit CuSparseDnVecDescriptor(const Tensor& input);
  120. };
  121. class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor
  122. : public CuSparseDescriptor<cusparseSpMatDescr, &cusparseDestroySpMat> {};
  // ^ End of the branch guarded by: AT_USE_HIPSPARSE_GENERIC_52_API() || (AT_USE_CUSPARSE_GENERIC_API() && AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS())
  124. #elif AT_USE_CUSPARSE_CONST_DESCRIPTORS()
  125. class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
  126. : public ConstCuSparseDescriptor<
  127. cusparseDnMatDescr,
  128. &cusparseDestroyDnMat> {
  129. public:
  130. explicit CuSparseDnMatDescriptor(
  131. const Tensor& input,
  132. int64_t batch_offset = -1);
  133. };
  134. class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor
  135. : public ConstCuSparseDescriptor<
  136. cusparseDnVecDescr,
  137. &cusparseDestroyDnVec> {
  138. public:
  139. explicit CuSparseDnVecDescriptor(const Tensor& input);
  140. };
  141. class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor
  142. : public ConstCuSparseDescriptor<
  143. cusparseSpMatDescr,
  144. &cusparseDestroySpMat> {};
  145. #endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS()
  146. class TORCH_CUDA_CPP_API CuSparseSpMatCsrDescriptor
  147. : public CuSparseSpMatDescriptor {
  148. public:
  149. explicit CuSparseSpMatCsrDescriptor(const Tensor& input, int64_t batch_offset = -1);
  150. std::tuple<int64_t, int64_t, int64_t> get_size() {
  151. int64_t rows, cols, nnz;
  152. TORCH_CUDASPARSE_CHECK(cusparseSpMatGetSize(
  153. this->descriptor(),
  154. &rows,
  155. &cols,
  156. &nnz));
  157. return std::make_tuple(rows, cols, nnz);
  158. }
  159. void set_tensor(const Tensor& input) {
  160. auto crow_indices = input.crow_indices();
  161. auto col_indices = input.col_indices();
  162. auto values = input.values();
  163. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(crow_indices.is_contiguous());
  164. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(col_indices.is_contiguous());
  165. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_contiguous());
  166. TORCH_CUDASPARSE_CHECK(cusparseCsrSetPointers(
  167. this->descriptor(),
  168. crow_indices.data_ptr(),
  169. col_indices.data_ptr(),
  170. values.data_ptr()));
  171. }
  172. #if AT_USE_CUSPARSE_GENERIC_SPSV()
  173. void set_mat_fill_mode(bool upper) {
  174. cusparseFillMode_t fill_mode =
  175. upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER;
  176. TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute(
  177. this->descriptor(),
  178. CUSPARSE_SPMAT_FILL_MODE,
  179. &fill_mode,
  180. sizeof(fill_mode)));
  181. }
  182. void set_mat_diag_type(bool unit) {
  183. cusparseDiagType_t diag_type =
  184. unit ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT;
  185. TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute(
  186. this->descriptor(),
  187. CUSPARSE_SPMAT_DIAG_TYPE,
  188. &diag_type,
  189. sizeof(diag_type)));
  190. }
  191. #endif
  192. };
  193. #if AT_USE_CUSPARSE_GENERIC_SPSV()
  194. class TORCH_CUDA_CPP_API CuSparseSpSVDescriptor
  195. : public CuSparseDescriptor<cusparseSpSVDescr, &cusparseSpSV_destroyDescr> {
  196. public:
  197. CuSparseSpSVDescriptor() {
  198. cusparseSpSVDescr_t raw_descriptor;
  199. TORCH_CUDASPARSE_CHECK(cusparseSpSV_createDescr(&raw_descriptor));
  200. descriptor_.reset(raw_descriptor);
  201. }
  202. };
  203. #endif
  204. #if AT_USE_CUSPARSE_GENERIC_SPSM()
  205. class TORCH_CUDA_CPP_API CuSparseSpSMDescriptor
  206. : public CuSparseDescriptor<cusparseSpSMDescr, &cusparseSpSM_destroyDescr> {
  207. public:
  208. CuSparseSpSMDescriptor() {
  209. cusparseSpSMDescr_t raw_descriptor;
  210. TORCH_CUDASPARSE_CHECK(cusparseSpSM_createDescr(&raw_descriptor));
  211. descriptor_.reset(raw_descriptor);
  212. }
  213. };
  214. #endif
  215. #if (defined(USE_ROCM) && ROCM_VERSION >= 50200) || !defined(USE_ROCM)
  216. class TORCH_CUDA_CPP_API CuSparseSpGEMMDescriptor
  217. : public CuSparseDescriptor<cusparseSpGEMMDescr, &cusparseSpGEMM_destroyDescr> {
  218. public:
  219. CuSparseSpGEMMDescriptor() {
  220. cusparseSpGEMMDescr_t raw_descriptor;
  221. TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&raw_descriptor));
  222. descriptor_.reset(raw_descriptor);
  223. }
  224. };
  225. #endif
  226. #endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
  227. } // namespace sparse
  228. } // namespace cuda
  229. } // namespace at