CUDASparseBlas.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. #pragma once
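/*
  Provides a subset of cuSPARSE functions as templates:

    csrgeam2_bufferSizeExt<scalar_t>(...)
    csrgeam2Nnz<scalar_t>(...)   (scalar_t is unused; pattern-only step)
    csrgeam2<scalar_t>(...)
    bsrmm<scalar_t>(...)
    bsrmv<scalar_t>(...)
    bsrsv2_bufferSize<scalar_t>(...), bsrsv2_analysis<scalar_t>(...),
    bsrsv2_solve<scalar_t>(...)
    bsrsm2_bufferSize<scalar_t>(...), bsrsm2_analysis<scalar_t>(...),
    bsrsm2_solve<scalar_t>(...)

  where scalar_t is double, float, c10::complex<double> or c10::complex<float>.
  The bsrsv2_* and bsrsm2_* functions are only declared when
  AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() is true.
  The functions are available in the at::cuda::sparse namespace.
*/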
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDASparse.h>

#include <typeinfo>
  10. namespace at {
  11. namespace cuda {
  12. namespace sparse {
  13. #define CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t) \
  14. cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \
  15. const cusparseMatDescr_t descrA, int nnzA, \
  16. const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \
  17. const int *csrSortedColIndA, const scalar_t *beta, \
  18. const cusparseMatDescr_t descrB, int nnzB, \
  19. const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \
  20. const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
  21. const scalar_t *csrSortedValC, const int *csrSortedRowPtrC, \
  22. const int *csrSortedColIndC, size_t *pBufferSizeInBytes
  23. template <typename scalar_t>
  24. inline void csrgeam2_bufferSizeExt(
  25. CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t)) {
  26. TORCH_INTERNAL_ASSERT(
  27. false,
  28. "at::cuda::sparse::csrgeam2_bufferSizeExt: not implemented for ",
  29. typeid(scalar_t).name());
  30. }
  31. template <>
  32. void csrgeam2_bufferSizeExt<float>(
  33. CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(float));
  34. template <>
  35. void csrgeam2_bufferSizeExt<double>(
  36. CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(double));
  37. template <>
  38. void csrgeam2_bufferSizeExt<c10::complex<float>>(
  39. CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<float>));
  40. template <>
  41. void csrgeam2_bufferSizeExt<c10::complex<double>>(
  42. CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<double>));
  43. #define CUSPARSE_CSRGEAM2_NNZ_ARGTYPES() \
  44. cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, \
  45. int nnzA, const int *csrSortedRowPtrA, const int *csrSortedColIndA, \
  46. const cusparseMatDescr_t descrB, int nnzB, const int *csrSortedRowPtrB, \
  47. const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
  48. int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *workspace
  49. template <typename scalar_t>
  50. inline void csrgeam2Nnz(CUSPARSE_CSRGEAM2_NNZ_ARGTYPES()) {
  51. TORCH_CUDASPARSE_CHECK(cusparseXcsrgeam2Nnz(
  52. handle,
  53. m,
  54. n,
  55. descrA,
  56. nnzA,
  57. csrSortedRowPtrA,
  58. csrSortedColIndA,
  59. descrB,
  60. nnzB,
  61. csrSortedRowPtrB,
  62. csrSortedColIndB,
  63. descrC,
  64. csrSortedRowPtrC,
  65. nnzTotalDevHostPtr,
  66. workspace));
  67. }
  68. #define CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t) \
  69. cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \
  70. const cusparseMatDescr_t descrA, int nnzA, \
  71. const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \
  72. const int *csrSortedColIndA, const scalar_t *beta, \
  73. const cusparseMatDescr_t descrB, int nnzB, \
  74. const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \
  75. const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
  76. scalar_t *csrSortedValC, int *csrSortedRowPtrC, int *csrSortedColIndC, \
  77. void *pBuffer
  78. template <typename scalar_t>
  79. inline void csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t)) {
  80. TORCH_INTERNAL_ASSERT(
  81. false,
  82. "at::cuda::sparse::csrgeam2: not implemented for ",
  83. typeid(scalar_t).name());
  84. }
  85. template <>
  86. void csrgeam2<float>(CUSPARSE_CSRGEAM2_ARGTYPES(float));
  87. template <>
  88. void csrgeam2<double>(CUSPARSE_CSRGEAM2_ARGTYPES(double));
  89. template <>
  90. void csrgeam2<c10::complex<float>>(
  91. CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<float>));
  92. template <>
  93. void csrgeam2<c10::complex<double>>(
  94. CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<double>));
  95. #define CUSPARSE_BSRMM_ARGTYPES(scalar_t) \
  96. cusparseHandle_t handle, cusparseDirection_t dirA, \
  97. cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, \
  98. int kb, int nnzb, const scalar_t *alpha, \
  99. const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
  100. const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
  101. const scalar_t *B, int ldb, const scalar_t *beta, scalar_t *C, int ldc
  102. template <typename scalar_t>
  103. inline void bsrmm(CUSPARSE_BSRMM_ARGTYPES(scalar_t)) {
  104. TORCH_INTERNAL_ASSERT(
  105. false,
  106. "at::cuda::sparse::bsrmm: not implemented for ",
  107. typeid(scalar_t).name());
  108. }
  109. template <>
  110. void bsrmm<float>(CUSPARSE_BSRMM_ARGTYPES(float));
  111. template <>
  112. void bsrmm<double>(CUSPARSE_BSRMM_ARGTYPES(double));
  113. template <>
  114. void bsrmm<c10::complex<float>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<float>));
  115. template <>
  116. void bsrmm<c10::complex<double>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<double>));
  117. #define CUSPARSE_BSRMV_ARGTYPES(scalar_t) \
  118. cusparseHandle_t handle, cusparseDirection_t dirA, \
  119. cusparseOperation_t transA, int mb, int nb, int nnzb, \
  120. const scalar_t *alpha, const cusparseMatDescr_t descrA, \
  121. const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
  122. int blockDim, const scalar_t *x, const scalar_t *beta, scalar_t *y
  123. template <typename scalar_t>
  124. inline void bsrmv(CUSPARSE_BSRMV_ARGTYPES(scalar_t)) {
  125. TORCH_INTERNAL_ASSERT(
  126. false,
  127. "at::cuda::sparse::bsrmv: not implemented for ",
  128. typeid(scalar_t).name());
  129. }
  130. template <>
  131. void bsrmv<float>(CUSPARSE_BSRMV_ARGTYPES(float));
  132. template <>
  133. void bsrmv<double>(CUSPARSE_BSRMV_ARGTYPES(double));
  134. template <>
  135. void bsrmv<c10::complex<float>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<float>));
  136. template <>
  137. void bsrmv<c10::complex<double>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<double>));
  138. #if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
  139. #define CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t) \
  140. cusparseHandle_t handle, cusparseDirection_t dirA, \
  141. cusparseOperation_t transA, int mb, int nnzb, \
  142. const cusparseMatDescr_t descrA, scalar_t *bsrValA, \
  143. const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
  144. bsrsv2Info_t info, int *pBufferSizeInBytes
  145. template <typename scalar_t>
  146. inline void bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t)) {
  147. TORCH_INTERNAL_ASSERT(
  148. false,
  149. "at::cuda::sparse::bsrsv2_bufferSize: not implemented for ",
  150. typeid(scalar_t).name());
  151. }
  152. template <>
  153. void bsrsv2_bufferSize<float>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(float));
  154. template <>
  155. void bsrsv2_bufferSize<double>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(double));
  156. template <>
  157. void bsrsv2_bufferSize<c10::complex<float>>(
  158. CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<float>));
  159. template <>
  160. void bsrsv2_bufferSize<c10::complex<double>>(
  161. CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<double>));
  162. #define CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t) \
  163. cusparseHandle_t handle, cusparseDirection_t dirA, \
  164. cusparseOperation_t transA, int mb, int nnzb, \
  165. const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
  166. const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
  167. bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
  168. template <typename scalar_t>
  169. inline void bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t)) {
  170. TORCH_INTERNAL_ASSERT(
  171. false,
  172. "at::cuda::sparse::bsrsv2_analysis: not implemented for ",
  173. typeid(scalar_t).name());
  174. }
  175. template <>
  176. void bsrsv2_analysis<float>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(float));
  177. template <>
  178. void bsrsv2_analysis<double>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(double));
  179. template <>
  180. void bsrsv2_analysis<c10::complex<float>>(
  181. CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<float>));
  182. template <>
  183. void bsrsv2_analysis<c10::complex<double>>(
  184. CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<double>));
  185. #define CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t) \
  186. cusparseHandle_t handle, cusparseDirection_t dirA, \
  187. cusparseOperation_t transA, int mb, int nnzb, const scalar_t *alpha, \
  188. const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
  189. const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
  190. bsrsv2Info_t info, const scalar_t *x, scalar_t *y, \
  191. cusparseSolvePolicy_t policy, void *pBuffer
  192. template <typename scalar_t>
  193. inline void bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t)) {
  194. TORCH_INTERNAL_ASSERT(
  195. false,
  196. "at::cuda::sparse::bsrsv2_solve: not implemented for ",
  197. typeid(scalar_t).name());
  198. }
  199. template <>
  200. void bsrsv2_solve<float>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(float));
  201. template <>
  202. void bsrsv2_solve<double>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(double));
  203. template <>
  204. void bsrsv2_solve<c10::complex<float>>(
  205. CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<float>));
  206. template <>
  207. void bsrsv2_solve<c10::complex<double>>(
  208. CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<double>));
  209. #define CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t) \
  210. cusparseHandle_t handle, cusparseDirection_t dirA, \
  211. cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
  212. int nnzb, const cusparseMatDescr_t descrA, scalar_t *bsrValA, \
  213. const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
  214. bsrsm2Info_t info, int *pBufferSizeInBytes
  215. template <typename scalar_t>
  216. inline void bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t)) {
  217. TORCH_INTERNAL_ASSERT(
  218. false,
  219. "at::cuda::sparse::bsrsm2_bufferSize: not implemented for ",
  220. typeid(scalar_t).name());
  221. }
  222. template <>
  223. void bsrsm2_bufferSize<float>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(float));
  224. template <>
  225. void bsrsm2_bufferSize<double>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(double));
  226. template <>
  227. void bsrsm2_bufferSize<c10::complex<float>>(
  228. CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<float>));
  229. template <>
  230. void bsrsm2_bufferSize<c10::complex<double>>(
  231. CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<double>));
  232. #define CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t) \
  233. cusparseHandle_t handle, cusparseDirection_t dirA, \
  234. cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
  235. int nnzb, const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
  236. const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
  237. bsrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
  238. template <typename scalar_t>
  239. inline void bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t)) {
  240. TORCH_INTERNAL_ASSERT(
  241. false,
  242. "at::cuda::sparse::bsrsm2_analysis: not implemented for ",
  243. typeid(scalar_t).name());
  244. }
  245. template <>
  246. void bsrsm2_analysis<float>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(float));
  247. template <>
  248. void bsrsm2_analysis<double>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(double));
  249. template <>
  250. void bsrsm2_analysis<c10::complex<float>>(
  251. CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<float>));
  252. template <>
  253. void bsrsm2_analysis<c10::complex<double>>(
  254. CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<double>));
  255. #define CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t) \
  256. cusparseHandle_t handle, cusparseDirection_t dirA, \
  257. cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
  258. int nnzb, const scalar_t *alpha, const cusparseMatDescr_t descrA, \
  259. const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
  260. int blockDim, bsrsm2Info_t info, const scalar_t *B, int ldb, \
  261. scalar_t *X, int ldx, cusparseSolvePolicy_t policy, void *pBuffer
  262. template <typename scalar_t>
  263. inline void bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t)) {
  264. TORCH_INTERNAL_ASSERT(
  265. false,
  266. "at::cuda::sparse::bsrsm2_solve: not implemented for ",
  267. typeid(scalar_t).name());
  268. }
  269. template <>
  270. void bsrsm2_solve<float>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(float));
  271. template <>
  272. void bsrsm2_solve<double>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(double));
  273. template <>
  274. void bsrsm2_solve<c10::complex<float>>(
  275. CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<float>));
  276. template <>
  277. void bsrsm2_solve<c10::complex<double>>(
  278. CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<double>));
  279. #endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
  280. } // namespace sparse
  281. } // namespace cuda
  282. } // namespace at