123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312 |
- #pragma once
- #include <ATen/SparseCsrTensorImpl.h>
- #include <ATen/SparseTensorImpl.h>
- #include <ATen/SparseTensorUtils.h>
- #include <ATen/core/Tensor.h>
// Runs the callable passed as the trailing macro argument(s) when LAYOUT is
// kSparseCsr, kSparseCsc, kSparseBsr or kSparseBsc, and evaluates to whatever
// that callable returns. Any other layout raises an error whose message is
// NAME followed by " expected sparse compressed tensor layout but got " and
// the layout.
//
// The callable is taken through __VA_ARGS__, so it may contain top-level
// commas. The failure path uses TORCH_CHECK(false, ...) rather than the
// deprecated AT_ERROR alias; the message text is the same.
#define AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(LAYOUT, NAME, ...) \
  [&] { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseCsc: \
      case kSparseBsr: \
      case kSparseBsc: \
        return __VA_ARGS__(); \
      default: \
        TORCH_CHECK( \
            false, \
            NAME, \
            " expected sparse compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
// Two-way dispatch over the sparse compressed layouts:
//   kSparseCsr / kSparseBsr -> evaluates to ROW_DIM_ACTION()
//   kSparseCsc / kSparseBsc -> evaluates to COLUMN_DIM_ACTION()
// Any other layout raises an error whose message is NAME followed by
// " expected sparse compressed tensor layout but got " and the layout.
//
// Both actions are invoked from the same lambda, so they must return the same
// type. The failure path uses TORCH_CHECK(false, ...) rather than the
// deprecated AT_ERROR alias; the message text is the same.
#define AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, ROW_DIM_ACTION, COLUMN_DIM_ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseBsr: \
        return (ROW_DIM_ACTION)(); \
      case kSparseCsc: \
      case kSparseBsc: \
        return (COLUMN_DIM_ACTION)(); \
      default: \
        TORCH_CHECK( \
            false, \
            NAME, \
            " expected sparse compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
// Two-way dispatch over the sparse compressed layouts by blockedness:
//   kSparseCsr / kSparseCsc -> evaluates to NO_BLOCK_ACTION()
//   kSparseBsr / kSparseBsc -> evaluates to BLOCK_ACTION()
// Any other layout raises an error whose message is NAME followed by
// " expected sparse compressed tensor layout but got " and the layout.
//
// Both actions are invoked from the same lambda, so they must return the same
// type. The failure path uses TORCH_CHECK(false, ...) rather than the
// deprecated AT_ERROR alias; the message text is the same.
#define AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, NO_BLOCK_ACTION, BLOCK_ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseCsc: \
        return (NO_BLOCK_ACTION)(); \
      case kSparseBsr: \
      case kSparseBsc: \
        return (BLOCK_ACTION)(); \
      default: \
        TORCH_CHECK( \
            false, \
            NAME, \
            " expected sparse compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
// Evaluates to ROW_DIM_ACTION() when LAYOUT is kSparseCsr or kSparseBsr.
// Every other layout (including kSparseCsc / kSparseBsc) raises an error whose
// message is NAME followed by
// " expected sparse row compressed tensor layout but got " and the layout.
//
// The failure path uses TORCH_CHECK(false, ...) rather than the deprecated
// AT_ERROR alias; the message text is the same.
#define AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, ROW_DIM_ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseBsr: \
        return (ROW_DIM_ACTION)(); \
      default: \
        TORCH_CHECK( \
            false, \
            NAME, \
            " expected sparse row compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
// Evaluates to COL_DIM_ACTION() when LAYOUT is kSparseCsc or kSparseBsc.
// Every other layout (including kSparseCsr / kSparseBsr) raises an error whose
// message is NAME followed by
// " expected sparse column compressed tensor layout but got " and the layout.
//
// The failure path uses TORCH_CHECK(false, ...) rather than the deprecated
// AT_ERROR alias; the message text is the same.
#define AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, COL_DIM_ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsc: \
      case kSparseBsc: \
        return (COL_DIM_ACTION)(); \
      default: \
        TORCH_CHECK( \
            false, \
            NAME, \
            " expected sparse column compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
// Evaluates to ACTION() when LAYOUT is kSparseCsr or kSparseCsc.
// Every other layout (including the block layouts kSparseBsr / kSparseBsc)
// raises an error whose message is NAME followed by
// " expected sparse compressed (non-block) tensor layout but got " and the
// layout.
//
// The failure path uses TORCH_CHECK(false, ...) rather than the deprecated
// AT_ERROR alias; the message text is the same.
#define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseCsc: \
        return (ACTION)(); \
      default: \
        TORCH_CHECK( \
            false, \
            NAME, \
            " expected sparse compressed (non-block) tensor layout but got ", \
            the_layout); \
    } \
  }()
// Evaluates to ACTION() when LAYOUT is kSparseBsr or kSparseBsc.
// Every other layout (including the non-block layouts kSparseCsr / kSparseCsc)
// raises an error whose message is NAME followed by
// " expected sparse compressed block tensor layout but got " and the layout.
//
// The failure path uses TORCH_CHECK(false, ...) rather than the deprecated
// AT_ERROR alias; the message text is the same.
#define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseBsr: \
      case kSparseBsc: \
        return (ACTION)(); \
      default: \
        TORCH_CHECK( \
            false, \
            NAME, \
            " expected sparse compressed block tensor layout but got ", \
            the_layout); \
    } \
  }()
// Forwards TYPE and NAME to AT_DISPATCH_SWITCH. The case list is produced by
// AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4, which receives kComplexHalf,
// kHalf, kBool and kBFloat16 followed by the caller's trailing arguments.
#define AT_DISPATCH_SPARSE_VALUE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
      kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__))
- namespace at {
- namespace sparse_csr {
- using SparseCsrTensor = Tensor;
- inline bool is_sparse_compressed(const Layout& layout) {
- switch (layout) {
- case kSparseCsr:
- case kSparseCsc:
- case kSparseBsr:
- case kSparseBsc:
- return true;
- default:;
- }
- return false;
- }
- inline bool is_sparse_compressed(const Tensor& self) {
- return is_sparse_compressed(self.layout());
- }
- inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) {
- AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
- self.layout(), "get_sparse_csr_impl", [&] {});
- return static_cast<SparseCsrTensorImpl*>(self.unsafeGetTensorImpl());
- }
- inline std::string layoutToString(
- Layout layout,
- bool upper = false,
- bool lower = false) {
- switch (layout) {
- case kSparseCsr:
- return (upper ? "CSR" : (lower ? "csr" : "Csr"));
- case kSparseCsc:
- return (upper ? "CSC" : (lower ? "csc" : "Csc"));
- case kSparseBsr:
- return (upper ? "BSR" : (lower ? "bsr" : "Bsr"));
- case kSparseBsc:
- return (upper ? "BSC" : (lower ? "bsc" : "Bsc"));
- default:
- TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
- return "";
- }
- }
- inline bool isCompressedRow(Layout layout) {
- return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
- layout, "isCompressedRow", [&] { return true; }, [&] { return false; });
- }
- inline bool isCompressedColumn(Layout layout) {
- return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
- layout,
- "isCompressedColumn",
- [&] { return false; },
- [&] { return true; });
- }
- inline std::string compressedIndicesName(Layout layout) {
- return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
- layout,
- "compressedIndicesName",
- [&] { return "crow_indices"; },
- [&] { return "ccol_indices"; });
- }
- inline std::string plainIndicesName(Layout layout) {
- return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
- layout,
- "plainIndicesName",
- [&] { return "col_indices"; },
- [&] { return "row_indices"; });
- }
- inline std::string compressedDimName(Layout layout) {
- switch (layout) {
- case kSparseCsr:
- return "row";
- case kSparseCsc:
- return "column";
- case kSparseBsr:
- return "row block";
- case kSparseBsc:
- return "column block";
- default:
- TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
- return "";
- }
- }
- inline std::string plainDimName(Layout layout) {
- switch (layout) {
- case kSparseCsr:
- return "column";
- case kSparseCsc:
- return "row";
- case kSparseBsr:
- return "column block";
- case kSparseBsc:
- return "row block";
- default:
- TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
- return "";
- }
- }
- inline int rowDimension(Layout layout, IntArrayRef size) {
- return size.size() - (isCompressedRow(layout) ? 2 : 1);
- }
- inline int columnDimension(Layout layout, IntArrayRef size) {
- return size.size() - (isCompressedColumn(layout) ? 2 : 1);
- }
- inline int compressedDimension(
- Layout layout,
- IntArrayRef size,
- size_t dense_ndim = 0) {
- return size.size() - dense_ndim - (isCompressedRow(layout) ? 2 : 1);
- }
- inline int plainDimension(
- Layout layout,
- IntArrayRef size,
- size_t dense_ndim = 0) {
- return size.size() - dense_ndim - (isCompressedRow(layout) ? 1 : 2);
- }
- inline int64_t numBatchDimensions(Tensor const& self) {
- return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
- self.layout(),
- "numBatchDimensions",
- [&self] { return self.crow_indices().dim() - 1; },
- [&self] { return self.ccol_indices().dim() - 1; });
- }
- inline std::pair<Tensor, Tensor> getCompressedPlainIndices(Tensor const& self) {
- return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
- self.layout(),
- "getCompressedPlainIndices",
- [&self] {
- return std::make_pair(self.crow_indices(), self.col_indices());
- },
- [&self] {
- return std::make_pair(self.ccol_indices(), self.row_indices());
- });
- }
- inline Layout flip_compressed_layout(Layout layout) {
- switch (layout) {
- case kSparseCsr:
- return kSparseCsc;
- case kSparseCsc:
- return kSparseCsr;
- case kSparseBsr:
- return kSparseBsc;
- case kSparseBsc:
- return kSparseBsr;
- default:
- TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
- return kSparseCsr;
- }
- }
- inline DimVector getBlockSize(Tensor const& self) {
- int64_t n_batch = numBatchDimensions(self);
- return at::DimVector(self.values().sizes().slice(n_batch + 1, 2));
- }
- inline at::OptionalArray<at::SymInt> getSymIntBlockSize(Tensor const& self) {
- if (self.layout() == at::kSparseBsr || self.layout() == at::kSparseBsc) {
- int64_t n_batch = numBatchDimensions(self);
- return self.values().sym_sizes().slice(n_batch + 1, 2).vec();
- } else {
- return {};
- }
- }
- } // namespace sparse_csr
- } // namespace at
|