// CUDABlas.h
#pragma once
/*
Provides a subset of CUDA BLAS functions as templates:

  gemm<Dtype>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
  gemv<Dtype>(transa, m, n, alpha, a, lda, x, incx, beta, y, incy)
  dot<Dtype>(n, x, incx, y, incy, result)

where Dtype is double, float, at::Half or at::BFloat16 (ROCm, NOT for dot).
The functions are available in the at::cuda::blas namespace.
*/

#include <ATen/cuda/CUDAContext.h>
#include <ATen/OpMathType.h>

namespace at {
namespace cuda {
namespace blas {
  16. // RAII guard that sets the CuBLAS pointer mode and restores it to
  17. // its previous value when the guard is destroyed
  18. class PointerModeGuard {
  19. public:
  20. PointerModeGuard(cublasHandle_t handle, cublasPointerMode_t mode) :
  21. handle(handle) {
  22. TORCH_CUDABLAS_CHECK(cublasGetPointerMode(handle, &previous_mode));
  23. TORCH_CUDABLAS_CHECK(cublasSetPointerMode(handle, mode));
  24. }
  25. ~PointerModeGuard() {
  26. cublasSetPointerMode(handle, previous_mode);
  27. }
  28. private:
  29. cublasHandle_t handle;
  30. cublasPointerMode_t previous_mode;
  31. };
  32. /* LEVEL 3 BLAS FUNCTIONS */
  33. #define CUDABLAS_GEMM_ARGTYPES(Dtype) \
  34. char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \
  35. const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, at::opmath_type<Dtype> beta,\
  36. Dtype *c, int64_t ldc
  37. template <typename Dtype>
  38. inline void gemm(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
  39. AT_ERROR("at::cuda::blas::gemm: not implemented for ", typeid(Dtype).name());
  40. }
  41. template <>
  42. void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double));
  43. template <>
  44. void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float));
  45. #if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
  46. template <>
  47. void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
  48. #endif
  49. #if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
  50. template <>
  51. void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
  52. #endif
  53. template <>
  54. void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
  55. template <>
  56. void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
#if !defined(USE_ROCM) && !defined(_MSC_VER)
// Activation epilogue applied by gemm_and_bias after the matmul + bias add.
enum GEMMAndBiasActivationEpilogue {
  None,
  RELU,
  GELU,
};

// Fused GEMM + bias (+ optional activation) — declaration only; the
// implementation lives in the corresponding .cpp/.cu translation unit.
// m/n/k are the GEMM dimensions; *_ld parameters are leading dimensions
// (presumably column-major, matching cuBLAS conventions — confirm against
// the implementation).
// NOTE: GELU activation is not supported prior to CUDA 11.4 and will
// do nothing if passed in that case.
template <typename Dtype>
void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<Dtype> alpha_val,
    const Dtype* mat1_ptr,
    int64_t mat1_ld,
    const Dtype* mat2_ptr,
    int64_t mat2_ld,
    const Dtype* bias,
    Dtype* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation = GEMMAndBiasActivationEpilogue::None);
#endif
  82. #define CUDABLAS_BGEMM_ARGTYPES(Dtype) \
  83. char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \
  84. const Dtype *a, int64_t lda, int64_t stridea, \
  85. const Dtype *b, int64_t ldb, int64_t strideb, \
  86. at::opmath_type<Dtype> beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches
  87. template <typename Dtype>
  88. inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
  89. AT_ERROR("at::cuda::blas::bgemm: not implemented for ", typeid(Dtype).name());
  90. }
  91. template <>
  92. void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double));
  93. template <>
  94. void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float));
  95. template <>
  96. void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
  97. template <>
  98. void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
  99. template <>
  100. void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
  101. template <>
  102. void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
// Argument list mirrors cublas<t>trsm (cuBLAS triangular solve).
#define CUDABLAS_TRSM_ARGTYPES(Dtype)                                    \
  cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo,   \
      cublasOperation_t trans, cublasDiagType_t diag, int m, int n,      \
      const Dtype *alpha, const Dtype *A, int lda, Dtype *B, int ldb

// Fallback for dtypes without a trsm specialization below; always fails.
template <typename Dtype>
inline void trsm(CUDABLAS_TRSM_ARGTYPES(Dtype)) {
  TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::trsm: not implemented for ", typeid(Dtype).name());
}

template <>
TORCH_CUDA_CU_API void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float));
template <>
TORCH_CUDA_CU_API void trsm<double>(CUDABLAS_TRSM_ARGTYPES(double));
template <>
TORCH_CUDA_CU_API void trsm<c10::complex<float>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<float>));
template <>
TORCH_CUDA_CU_API void trsm<c10::complex<double>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<double>));
// Argument list mirrors cublas<t>trsmBatched; A and B are arrays of
// device pointers, one matrix per batch entry.
#define CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype)                            \
  cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo,   \
      cublasOperation_t trans, cublasDiagType_t diag, int m, int n,      \
      const Dtype *alpha, Dtype *A[], int lda, Dtype *B[], int ldb,      \
      int batchCount

// Fallback for dtypes without a trsmBatched specialization below.
template <typename Dtype>
inline void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype)) {
  TORCH_INTERNAL_ASSERT(
      false,
      "at::cuda::blas::trsmBatched: not implemented for ",
      typeid(Dtype).name());
}

template <>
TORCH_CUDA_CU_API void trsmBatched<float>(CUDABLAS_TRSM_BATCHED_ARGTYPES(float));
template <>
TORCH_CUDA_CU_API void trsmBatched<double>(CUDABLAS_TRSM_BATCHED_ARGTYPES(double));
template <>
TORCH_CUDA_CU_API void trsmBatched<c10::complex<float>>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<float>));
template <>
TORCH_CUDA_CU_API void trsmBatched<c10::complex<double>>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<double>));
  139. /* LEVEL 2 BLAS FUNCTIONS */
  140. #define CUDABLAS_GEMV_ARGTYPES(Dtype) \
  141. char trans, int64_t m, int64_t n, Dtype alpha, const Dtype *a, int64_t lda, \
  142. const Dtype *x, int64_t incx, Dtype beta, Dtype *y, int64_t incy
  143. template <typename Dtype>
  144. inline void gemv(CUDABLAS_GEMV_ARGTYPES(Dtype)) {
  145. AT_ERROR("at::cuda::blas::gemv: not implemented for ", typeid(Dtype).name());
  146. }
  147. template <>
  148. void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double));
  149. template <>
  150. void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float));
  151. #if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
  152. template <>
  153. void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>));
  154. template <>
  155. void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>));
  156. #endif
  157. template <>
  158. void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half));
  159. template <>
  160. void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16));
  161. /* LEVEL 1 BLAS FUNCTIONS */
  162. #define CUDABLAS_DOT_ARGTYPES(Dtype) \
  163. cublasHandle_t handle, int n, const Dtype *x, int incx, const Dtype *y, \
  164. int incy, Dtype *result
  165. template <typename Dtype>
  166. inline void dot(CUDABLAS_DOT_ARGTYPES(Dtype)) {
  167. AT_ERROR("at::cuda::blas::dot: not implemented for ", typeid(Dtype).name());
  168. }
  169. template <>
  170. void dot<double>(CUDABLAS_DOT_ARGTYPES(double));
  171. template <>
  172. void dot<float>(CUDABLAS_DOT_ARGTYPES(float));
  173. template <>
  174. void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half));
  175. template <>
  176. void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16));
  177. template <>
  178. void dot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
  179. template <>
  180. void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
  181. template <typename Dtype>
  182. inline void vdot(CUDABLAS_DOT_ARGTYPES(Dtype)) {
  183. AT_ERROR("at::cuda::blas::vdot: not implemented for ", typeid(Dtype).name());
  184. }
  185. template <>
  186. void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
  187. template <>
  188. void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
// This guard blocks use of getrsBatched, geqrfBatched, getrfBatched and
// gelsBatched on platforms other than CUDA (CUDART_VERSION is only defined
// by the CUDA runtime headers).
#ifdef CUDART_VERSION

// Argument list mirrors cublas<t>getrsBatched; dA_array/dB_array are arrays
// of device pointers, one matrix per batch entry.
#define CUDABLAS_GETRS_ARGTYPES(Dtype)                             \
  cublasHandle_t handle, cublasOperation_t trans,                  \
      int n, int nrhs, Dtype** dA_array, int lda, int* ipiv_array, \
      Dtype** dB_array, int ldb, int* info_array, int batchsize

// Fallback for dtypes without a getrsBatched specialization below.
template<class Dtype>
void getrsBatched(CUDABLAS_GETRS_ARGTYPES(Dtype)) {
  TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::getrsBatched: not implemented for ",
    typeid(Dtype).name());
}

template<>
TORCH_CUDA_CU_API void getrsBatched<float>(CUDABLAS_GETRS_ARGTYPES(float));
template<>
TORCH_CUDA_CU_API void getrsBatched<double>(CUDABLAS_GETRS_ARGTYPES(double));
template<>
TORCH_CUDA_CU_API void getrsBatched<c10::complex<float>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<float>));
template<>
TORCH_CUDA_CU_API void getrsBatched<c10::complex<double>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<double>));
// Argument list mirrors cublas<t>geqrfBatched (batched QR factorization);
// A_array/tau_array are arrays of device pointers, one entry per batch.
#define CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)                   \
  cublasHandle_t handle, int m, int n, Dtype **A_array, int lda, \
      Dtype **tau_array, int *info, int batchsize

// Fallback for dtypes without a geqrfBatched specialization below.
template <class Dtype>
void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) {
  TORCH_INTERNAL_ASSERT(
      false,
      "at::cuda::blas::geqrfBatched: not implemented for ",
      typeid(Dtype).name());
}

template <>
TORCH_CUDA_CU_API void geqrfBatched<float>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float));
template <>
TORCH_CUDA_CU_API void geqrfBatched<double>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double));
template <>
TORCH_CUDA_CU_API void geqrfBatched<c10::complex<double>>(
    CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<double>));
template <>
TORCH_CUDA_CU_API void geqrfBatched<c10::complex<float>>(
    CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>));
  228. #define CUDABLAS_GETRF_ARGTYPES(Dtype) \
  229. int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize
  230. template<class Dtype>
  231. void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) {
  232. TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not implemented for ", typeid(Dtype).name());
  233. }
  234. template<>
  235. TORCH_CUDA_CU_API void getrfBatched<float>(CUDABLAS_GETRF_ARGTYPES(float));
  236. template<>
  237. TORCH_CUDA_CU_API void getrfBatched<double>(CUDABLAS_GETRF_ARGTYPES(double));
  238. template<>
  239. TORCH_CUDA_CU_API void getrfBatched<c10::complex<double>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<double>));
  240. template<>
  241. TORCH_CUDA_CU_API void getrfBatched<c10::complex<float>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<float>));
// Argument list mirrors cublas<t>gelsBatched (batched least squares);
// dA_array/dC_array are arrays of device pointers, one entry per batch.
#define CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)                              \
  cublasHandle_t handle, cublasOperation_t trans, int m, int n, int nrhs,  \
      Dtype** dA_array, int ldda, Dtype** dC_array, int lddc, int* info,   \
      int *devInfoArray, int batchSize

// Fallback for dtypes without a gelsBatched specialization below.
template <class Dtype>
void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)) {
  TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::gelsBatched: not implemented for ", typeid(Dtype).name());
}

template<>
TORCH_CUDA_CU_API void gelsBatched<double>(CUDABLAS_GELS_BATCHED_ARGTYPES(double));
template<>
TORCH_CUDA_CU_API void gelsBatched<float>(CUDABLAS_GELS_BATCHED_ARGTYPES(float));
template<>
TORCH_CUDA_CU_API void gelsBatched<c10::complex<double>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<double>));
template<>
TORCH_CUDA_CU_API void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>));

#endif // CUDART_VERSION
} // namespace blas
} // namespace cuda
} // namespace at