123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322 |
- #pragma once
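/*
Provides a subset of cuSPARSE functions as templates:

  csrgeam2_bufferSizeExt<scalar_t>(...), csrgeam2Nnz<scalar_t>(...),
  csrgeam2<scalar_t>(...), bsrmm<scalar_t>(...), bsrmv<scalar_t>(...),
  and, when AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() is true,
  bsrsv2_{bufferSize,analysis,solve}<scalar_t>(...) and
  bsrsm2_{bufferSize,analysis,solve}<scalar_t>(...)

where, for the value-carrying wrappers, scalar_t is double, float,
c10::complex<double> or c10::complex<float>.
The functions are available in the at::cuda::sparse namespace.
*/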
- #include <ATen/cuda/CUDAContext.h>
- #include <ATen/cuda/CUDASparse.h>
- namespace at {
- namespace cuda {
- namespace sparse {
- #define CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t) \
- cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \
- const cusparseMatDescr_t descrA, int nnzA, \
- const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \
- const int *csrSortedColIndA, const scalar_t *beta, \
- const cusparseMatDescr_t descrB, int nnzB, \
- const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \
- const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
- const scalar_t *csrSortedValC, const int *csrSortedRowPtrC, \
- const int *csrSortedColIndC, size_t *pBufferSizeInBytes
- template <typename scalar_t>
- inline void csrgeam2_bufferSizeExt(
- CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t)) {
- TORCH_INTERNAL_ASSERT(
- false,
- "at::cuda::sparse::csrgeam2_bufferSizeExt: not implemented for ",
- typeid(scalar_t).name());
- }
- template <>
- void csrgeam2_bufferSizeExt<float>(
- CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(float));
- template <>
- void csrgeam2_bufferSizeExt<double>(
- CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(double));
- template <>
- void csrgeam2_bufferSizeExt<c10::complex<float>>(
- CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<float>));
- template <>
- void csrgeam2_bufferSizeExt<c10::complex<double>>(
- CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<double>));
- #define CUSPARSE_CSRGEAM2_NNZ_ARGTYPES() \
- cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, \
- int nnzA, const int *csrSortedRowPtrA, const int *csrSortedColIndA, \
- const cusparseMatDescr_t descrB, int nnzB, const int *csrSortedRowPtrB, \
- const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
- int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *workspace
- template <typename scalar_t>
- inline void csrgeam2Nnz(CUSPARSE_CSRGEAM2_NNZ_ARGTYPES()) {
- TORCH_CUDASPARSE_CHECK(cusparseXcsrgeam2Nnz(
- handle,
- m,
- n,
- descrA,
- nnzA,
- csrSortedRowPtrA,
- csrSortedColIndA,
- descrB,
- nnzB,
- csrSortedRowPtrB,
- csrSortedColIndB,
- descrC,
- csrSortedRowPtrC,
- nnzTotalDevHostPtr,
- workspace));
- }
- #define CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t) \
- cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \
- const cusparseMatDescr_t descrA, int nnzA, \
- const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \
- const int *csrSortedColIndA, const scalar_t *beta, \
- const cusparseMatDescr_t descrB, int nnzB, \
- const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \
- const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
- scalar_t *csrSortedValC, int *csrSortedRowPtrC, int *csrSortedColIndC, \
- void *pBuffer
- template <typename scalar_t>
- inline void csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t)) {
- TORCH_INTERNAL_ASSERT(
- false,
- "at::cuda::sparse::csrgeam2: not implemented for ",
- typeid(scalar_t).name());
- }
- template <>
- void csrgeam2<float>(CUSPARSE_CSRGEAM2_ARGTYPES(float));
- template <>
- void csrgeam2<double>(CUSPARSE_CSRGEAM2_ARGTYPES(double));
- template <>
- void csrgeam2<c10::complex<float>>(
- CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<float>));
- template <>
- void csrgeam2<c10::complex<double>>(
- CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<double>));
- #define CUSPARSE_BSRMM_ARGTYPES(scalar_t) \
- cusparseHandle_t handle, cusparseDirection_t dirA, \
- cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, \
- int kb, int nnzb, const scalar_t *alpha, \
- const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
- const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
- const scalar_t *B, int ldb, const scalar_t *beta, scalar_t *C, int ldc
- template <typename scalar_t>
- inline void bsrmm(CUSPARSE_BSRMM_ARGTYPES(scalar_t)) {
- TORCH_INTERNAL_ASSERT(
- false,
- "at::cuda::sparse::bsrmm: not implemented for ",
- typeid(scalar_t).name());
- }
- template <>
- void bsrmm<float>(CUSPARSE_BSRMM_ARGTYPES(float));
- template <>
- void bsrmm<double>(CUSPARSE_BSRMM_ARGTYPES(double));
- template <>
- void bsrmm<c10::complex<float>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<float>));
- template <>
- void bsrmm<c10::complex<double>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<double>));
- #define CUSPARSE_BSRMV_ARGTYPES(scalar_t) \
- cusparseHandle_t handle, cusparseDirection_t dirA, \
- cusparseOperation_t transA, int mb, int nb, int nnzb, \
- const scalar_t *alpha, const cusparseMatDescr_t descrA, \
- const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
- int blockDim, const scalar_t *x, const scalar_t *beta, scalar_t *y
- template <typename scalar_t>
- inline void bsrmv(CUSPARSE_BSRMV_ARGTYPES(scalar_t)) {
- TORCH_INTERNAL_ASSERT(
- false,
- "at::cuda::sparse::bsrmv: not implemented for ",
- typeid(scalar_t).name());
- }
- template <>
- void bsrmv<float>(CUSPARSE_BSRMV_ARGTYPES(float));
- template <>
- void bsrmv<double>(CUSPARSE_BSRMV_ARGTYPES(double));
- template <>
- void bsrmv<c10::complex<float>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<float>));
- template <>
- void bsrmv<c10::complex<double>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<double>));
- #if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
- #define CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t) \
- cusparseHandle_t handle, cusparseDirection_t dirA, \
- cusparseOperation_t transA, int mb, int nnzb, \
- const cusparseMatDescr_t descrA, scalar_t *bsrValA, \
- const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
- bsrsv2Info_t info, int *pBufferSizeInBytes
- template <typename scalar_t>
- inline void bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t)) {
- TORCH_INTERNAL_ASSERT(
- false,
- "at::cuda::sparse::bsrsv2_bufferSize: not implemented for ",
- typeid(scalar_t).name());
- }
- template <>
- void bsrsv2_bufferSize<float>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(float));
- template <>
- void bsrsv2_bufferSize<double>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(double));
- template <>
- void bsrsv2_bufferSize<c10::complex<float>>(
- CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<float>));
- template <>
- void bsrsv2_bufferSize<c10::complex<double>>(
- CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<double>));
- #define CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t) \
- cusparseHandle_t handle, cusparseDirection_t dirA, \
- cusparseOperation_t transA, int mb, int nnzb, \
- const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
- const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
- bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
- template <typename scalar_t>
- inline void bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t)) {
- TORCH_INTERNAL_ASSERT(
- false,
- "at::cuda::sparse::bsrsv2_analysis: not implemented for ",
- typeid(scalar_t).name());
- }
- template <>
- void bsrsv2_analysis<float>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(float));
- template <>
- void bsrsv2_analysis<double>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(double));
- template <>
- void bsrsv2_analysis<c10::complex<float>>(
- CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<float>));
- template <>
- void bsrsv2_analysis<c10::complex<double>>(
- CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<double>));
- #define CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t) \
- cusparseHandle_t handle, cusparseDirection_t dirA, \
- cusparseOperation_t transA, int mb, int nnzb, const scalar_t *alpha, \
- const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
- const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
- bsrsv2Info_t info, const scalar_t *x, scalar_t *y, \
- cusparseSolvePolicy_t policy, void *pBuffer
- template <typename scalar_t>
- inline void bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t)) {
- TORCH_INTERNAL_ASSERT(
- false,
- "at::cuda::sparse::bsrsv2_solve: not implemented for ",
- typeid(scalar_t).name());
- }
- template <>
- void bsrsv2_solve<float>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(float));
- template <>
- void bsrsv2_solve<double>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(double));
- template <>
- void bsrsv2_solve<c10::complex<float>>(
- CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<float>));
- template <>
- void bsrsv2_solve<c10::complex<double>>(
- CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<double>));
- #define CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t) \
- cusparseHandle_t handle, cusparseDirection_t dirA, \
- cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
- int nnzb, const cusparseMatDescr_t descrA, scalar_t *bsrValA, \
- const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
- bsrsm2Info_t info, int *pBufferSizeInBytes
- template <typename scalar_t>
- inline void bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t)) {
- TORCH_INTERNAL_ASSERT(
- false,
- "at::cuda::sparse::bsrsm2_bufferSize: not implemented for ",
- typeid(scalar_t).name());
- }
- template <>
- void bsrsm2_bufferSize<float>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(float));
- template <>
- void bsrsm2_bufferSize<double>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(double));
- template <>
- void bsrsm2_bufferSize<c10::complex<float>>(
- CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<float>));
- template <>
- void bsrsm2_bufferSize<c10::complex<double>>(
- CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<double>));
- #define CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t) \
- cusparseHandle_t handle, cusparseDirection_t dirA, \
- cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
- int nnzb, const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
- const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
- bsrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
- template <typename scalar_t>
- inline void bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t)) {
- TORCH_INTERNAL_ASSERT(
- false,
- "at::cuda::sparse::bsrsm2_analysis: not implemented for ",
- typeid(scalar_t).name());
- }
- template <>
- void bsrsm2_analysis<float>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(float));
- template <>
- void bsrsm2_analysis<double>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(double));
- template <>
- void bsrsm2_analysis<c10::complex<float>>(
- CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<float>));
- template <>
- void bsrsm2_analysis<c10::complex<double>>(
- CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<double>));
- #define CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t) \
- cusparseHandle_t handle, cusparseDirection_t dirA, \
- cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
- int nnzb, const scalar_t *alpha, const cusparseMatDescr_t descrA, \
- const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
- int blockDim, bsrsm2Info_t info, const scalar_t *B, int ldb, \
- scalar_t *X, int ldx, cusparseSolvePolicy_t policy, void *pBuffer
- template <typename scalar_t>
- inline void bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t)) {
- TORCH_INTERNAL_ASSERT(
- false,
- "at::cuda::sparse::bsrsm2_solve: not implemented for ",
- typeid(scalar_t).name());
- }
- template <>
- void bsrsm2_solve<float>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(float));
- template <>
- void bsrsm2_solve<double>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(double));
- template <>
- void bsrsm2_solve<c10::complex<float>>(
- CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<float>));
- template <>
- void bsrsm2_solve<c10::complex<double>>(
- CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<double>));
- #endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
- } // namespace sparse
- } // namespace cuda
- } // namespace at
|