123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266 |
- #pragma once
- #include <ATen/Tensor.h>
- #include <ATen/cuda/CUDAContext.h>
- #include <ATen/cuda/CUDASparse.h>
- #include <c10/core/ScalarType.h>
- #if defined(USE_ROCM)
- #include <type_traits>
- #endif
- namespace at {
- namespace cuda {
- namespace sparse {
- template <typename T, cusparseStatus_t (*destructor)(T*)>
- struct CuSparseDescriptorDeleter {
- void operator()(T* x) {
- if (x != nullptr) {
- TORCH_CUDASPARSE_CHECK(destructor(x));
- }
- }
- };
- template <typename T, cusparseStatus_t (*destructor)(T*)>
- class CuSparseDescriptor {
- public:
- T* descriptor() const {
- return descriptor_.get();
- }
- T* descriptor() {
- return descriptor_.get();
- }
- protected:
- std::unique_ptr<T, CuSparseDescriptorDeleter<T, destructor>> descriptor_;
- };
- #if AT_USE_CUSPARSE_CONST_DESCRIPTORS()
- template <typename T, cusparseStatus_t (*destructor)(const T*)>
- struct ConstCuSparseDescriptorDeleter {
- void operator()(T* x) {
- if (x != nullptr) {
- TORCH_CUDASPARSE_CHECK(destructor(x));
- }
- }
- };
- template <typename T, cusparseStatus_t (*destructor)(const T*)>
- class ConstCuSparseDescriptor {
- public:
- T* descriptor() const {
- return descriptor_.get();
- }
- T* descriptor() {
- return descriptor_.get();
- }
- protected:
- std::unique_ptr<T, ConstCuSparseDescriptorDeleter<T, destructor>> descriptor_;
- };
- #endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS
#if defined(USE_ROCM)
// hipSPARSE doesn't define these pointee type names, only the `*_t` pointer
// typedefs; recover them so they can be used as the `T` template argument of
// CuSparseDescriptor<T, ...>.
using cusparseMatDescr = std::remove_pointer<cusparseMatDescr_t>::type;
using cusparseDnMatDescr = std::remove_pointer<cusparseDnMatDescr_t>::type;
using cusparseDnVecDescr = std::remove_pointer<cusparseDnVecDescr_t>::type;
using cusparseSpMatDescr = std::remove_pointer<cusparseSpMatDescr_t>::type;
using cusparseSpGEMMDescr = std::remove_pointer<cusparseSpGEMMDescr_t>::type;
#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
using bsrsv2Info = std::remove_pointer<bsrsv2Info_t>::type;
using bsrsm2Info = std::remove_pointer<bsrsm2Info_t>::type;
#endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
#endif // defined(USE_ROCM)
- class TORCH_CUDA_CPP_API CuSparseMatDescriptor
- : public CuSparseDescriptor<cusparseMatDescr, &cusparseDestroyMatDescr> {
- public:
- CuSparseMatDescriptor() {
- cusparseMatDescr_t raw_descriptor;
- TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor));
- descriptor_.reset(raw_descriptor);
- }
- CuSparseMatDescriptor(bool upper, bool unit) {
- cusparseFillMode_t fill_mode =
- upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER;
- cusparseDiagType_t diag_type =
- unit ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT;
- cusparseMatDescr_t raw_descriptor;
- TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor));
- TORCH_CUDASPARSE_CHECK(cusparseSetMatFillMode(raw_descriptor, fill_mode));
- TORCH_CUDASPARSE_CHECK(cusparseSetMatDiagType(raw_descriptor, diag_type));
- descriptor_.reset(raw_descriptor);
- }
- };
- #if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
- class TORCH_CUDA_CPP_API CuSparseBsrsv2Info
- : public CuSparseDescriptor<bsrsv2Info, &cusparseDestroyBsrsv2Info> {
- public:
- CuSparseBsrsv2Info() {
- bsrsv2Info_t raw_descriptor;
- TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsv2Info(&raw_descriptor));
- descriptor_.reset(raw_descriptor);
- }
- };
- class TORCH_CUDA_CPP_API CuSparseBsrsm2Info
- : public CuSparseDescriptor<bsrsm2Info, &cusparseDestroyBsrsm2Info> {
- public:
- CuSparseBsrsm2Info() {
- bsrsm2Info_t raw_descriptor;
- TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsm2Info(&raw_descriptor));
- descriptor_.reset(raw_descriptor);
- }
- };
- #endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
- #if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
- cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type);
- #if AT_USE_HIPSPARSE_GENERIC_52_API() || \
- (AT_USE_CUSPARSE_GENERIC_API() && AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS())
- class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
- : public CuSparseDescriptor<cusparseDnMatDescr, &cusparseDestroyDnMat> {
- public:
- explicit CuSparseDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1);
- };
- class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor
- : public CuSparseDescriptor<cusparseDnVecDescr, &cusparseDestroyDnVec> {
- public:
- explicit CuSparseDnVecDescriptor(const Tensor& input);
- };
- class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor
- : public CuSparseDescriptor<cusparseSpMatDescr, &cusparseDestroySpMat> {};
- //AT_USE_HIPSPARSE_GENERIC_52_API() || (AT_USE_CUSPARSE_GENERIC_API() && AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS())
- #elif AT_USE_CUSPARSE_CONST_DESCRIPTORS()
- class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
- : public ConstCuSparseDescriptor<
- cusparseDnMatDescr,
- &cusparseDestroyDnMat> {
- public:
- explicit CuSparseDnMatDescriptor(
- const Tensor& input,
- int64_t batch_offset = -1);
- };
- class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor
- : public ConstCuSparseDescriptor<
- cusparseDnVecDescr,
- &cusparseDestroyDnVec> {
- public:
- explicit CuSparseDnVecDescriptor(const Tensor& input);
- };
- class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor
- : public ConstCuSparseDescriptor<
- cusparseSpMatDescr,
- &cusparseDestroySpMat> {};
- #endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS()
- class TORCH_CUDA_CPP_API CuSparseSpMatCsrDescriptor
- : public CuSparseSpMatDescriptor {
- public:
- explicit CuSparseSpMatCsrDescriptor(const Tensor& input, int64_t batch_offset = -1);
- std::tuple<int64_t, int64_t, int64_t> get_size() {
- int64_t rows, cols, nnz;
- TORCH_CUDASPARSE_CHECK(cusparseSpMatGetSize(
- this->descriptor(),
- &rows,
- &cols,
- &nnz));
- return std::make_tuple(rows, cols, nnz);
- }
- void set_tensor(const Tensor& input) {
- auto crow_indices = input.crow_indices();
- auto col_indices = input.col_indices();
- auto values = input.values();
- TORCH_INTERNAL_ASSERT_DEBUG_ONLY(crow_indices.is_contiguous());
- TORCH_INTERNAL_ASSERT_DEBUG_ONLY(col_indices.is_contiguous());
- TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_contiguous());
- TORCH_CUDASPARSE_CHECK(cusparseCsrSetPointers(
- this->descriptor(),
- crow_indices.data_ptr(),
- col_indices.data_ptr(),
- values.data_ptr()));
- }
- #if AT_USE_CUSPARSE_GENERIC_SPSV()
- void set_mat_fill_mode(bool upper) {
- cusparseFillMode_t fill_mode =
- upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER;
- TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute(
- this->descriptor(),
- CUSPARSE_SPMAT_FILL_MODE,
- &fill_mode,
- sizeof(fill_mode)));
- }
- void set_mat_diag_type(bool unit) {
- cusparseDiagType_t diag_type =
- unit ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT;
- TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute(
- this->descriptor(),
- CUSPARSE_SPMAT_DIAG_TYPE,
- &diag_type,
- sizeof(diag_type)));
- }
- #endif
- };
- #if AT_USE_CUSPARSE_GENERIC_SPSV()
- class TORCH_CUDA_CPP_API CuSparseSpSVDescriptor
- : public CuSparseDescriptor<cusparseSpSVDescr, &cusparseSpSV_destroyDescr> {
- public:
- CuSparseSpSVDescriptor() {
- cusparseSpSVDescr_t raw_descriptor;
- TORCH_CUDASPARSE_CHECK(cusparseSpSV_createDescr(&raw_descriptor));
- descriptor_.reset(raw_descriptor);
- }
- };
- #endif
- #if AT_USE_CUSPARSE_GENERIC_SPSM()
- class TORCH_CUDA_CPP_API CuSparseSpSMDescriptor
- : public CuSparseDescriptor<cusparseSpSMDescr, &cusparseSpSM_destroyDescr> {
- public:
- CuSparseSpSMDescriptor() {
- cusparseSpSMDescr_t raw_descriptor;
- TORCH_CUDASPARSE_CHECK(cusparseSpSM_createDescr(&raw_descriptor));
- descriptor_.reset(raw_descriptor);
- }
- };
- #endif
- #if (defined(USE_ROCM) && ROCM_VERSION >= 50200) || !defined(USE_ROCM)
- class TORCH_CUDA_CPP_API CuSparseSpGEMMDescriptor
- : public CuSparseDescriptor<cusparseSpGEMMDescr, &cusparseSpGEMM_destroyDescr> {
- public:
- CuSparseSpGEMMDescriptor() {
- cusparseSpGEMMDescr_t raw_descriptor;
- TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&raw_descriptor));
- descriptor_.reset(raw_descriptor);
- }
- };
- #endif
- #endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
- } // namespace sparse
- } // namespace cuda
- } // namespace at
|