// SparseCsrTensorUtils.h
  1. #pragma once
  2. #include <ATen/SparseCsrTensorImpl.h>
  3. #include <ATen/SparseTensorImpl.h>
  4. #include <ATen/SparseTensorUtils.h>
  5. #include <ATen/core/Tensor.h>
// Runs a single nullary ACTION (passed as __VA_ARGS__) when LAYOUT is any
// of the four sparse compressed layouts (CSR, CSC, BSR, BSC); raises an
// error mentioning NAME for every other layout. Implemented as an
// immediately-invoked lambda so the macro is usable as an expression.
// NOTE: comments must stay outside the macro body — a `//` comment before
// a `\` continuation would splice the next line into the comment.
#define AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(LAYOUT, NAME, ...) \
  [&] {                                                              \
    const auto& the_layout = LAYOUT;                                 \
    switch (the_layout) {                                            \
      case kSparseCsr:                                               \
      case kSparseCsc:                                               \
      case kSparseBsr:                                               \
      case kSparseBsc:                                               \
        return __VA_ARGS__();                                        \
      default:                                                       \
        AT_ERROR(                                                    \
            NAME,                                                    \
            " expected sparse compressed tensor layout but got ",    \
            the_layout);                                             \
    }                                                                \
  }()
// Dispatches on the compressed dimension: runs ROW_DIM_ACTION for
// row-compressed layouts (CSR, BSR) and COLUMN_DIM_ACTION for
// column-compressed layouts (CSC, BSC); errors (mentioning NAME) for any
// other layout. Both actions are nullary callables; the macro evaluates
// to whatever the selected action returns.
#define AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(     \
    LAYOUT, NAME, ROW_DIM_ACTION, COLUMN_DIM_ACTION)   \
  [&]() {                                              \
    const auto& the_layout = LAYOUT;                   \
    switch (the_layout) {                              \
      case kSparseCsr:                                 \
      case kSparseBsr:                                 \
        return (ROW_DIM_ACTION)();                     \
      case kSparseCsc:                                 \
      case kSparseBsc:                                 \
        return (COLUMN_DIM_ACTION)();                  \
      default:                                         \
        AT_ERROR(                                      \
            NAME,                                      \
            " expected sparse compressed tensor layout but got ", \
            the_layout);                               \
    }                                                  \
  }()
// Dispatches on blockedness: runs NO_BLOCK_ACTION for non-block layouts
// (CSR, CSC) and BLOCK_ACTION for block layouts (BSR, BSC); errors
// (mentioning NAME) for any other layout. Both actions are nullary
// callables; the macro evaluates to the selected action's result.
#define AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(   \
    LAYOUT, NAME, NO_BLOCK_ACTION, BLOCK_ACTION)       \
  [&]() {                                              \
    const auto& the_layout = LAYOUT;                   \
    switch (the_layout) {                              \
      case kSparseCsr:                                 \
      case kSparseCsc:                                 \
        return (NO_BLOCK_ACTION)();                    \
      case kSparseBsr:                                 \
      case kSparseBsc:                                 \
        return (BLOCK_ACTION)();                       \
      default:                                         \
        AT_ERROR(                                      \
            NAME,                                      \
            " expected sparse compressed tensor layout but got ", \
            the_layout);                               \
    }                                                  \
  }()
// Runs ROW_DIM_ACTION only for row-compressed layouts (CSR, BSR);
// errors (mentioning NAME) for everything else, including the
// column-compressed layouts.
#define AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS(     \
    LAYOUT, NAME, ROW_DIM_ACTION)                      \
  [&]() {                                              \
    const auto& the_layout = LAYOUT;                   \
    switch (the_layout) {                              \
      case kSparseCsr:                                 \
      case kSparseBsr:                                 \
        return (ROW_DIM_ACTION)();                     \
      default:                                         \
        AT_ERROR(                                      \
            NAME,                                      \
            " expected sparse row compressed tensor layout but got ", \
            the_layout);                               \
    }                                                  \
  }()
// Runs COL_DIM_ACTION only for column-compressed layouts (CSC, BSC);
// errors (mentioning NAME) for everything else, including the
// row-compressed layouts.
#define AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS(     \
    LAYOUT, NAME, COL_DIM_ACTION)                      \
  [&]() {                                              \
    const auto& the_layout = LAYOUT;                   \
    switch (the_layout) {                              \
      case kSparseCsc:                                 \
      case kSparseBsc:                                 \
        return (COL_DIM_ACTION)();                     \
      default:                                         \
        AT_ERROR(                                      \
            NAME,                                      \
            " expected sparse column compressed tensor layout but got ", \
            the_layout);                               \
    }                                                  \
  }()
// Runs ACTION only for the non-block compressed layouts (CSR, CSC);
// errors (mentioning NAME) for block layouts and anything else.
#define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
  [&]() {                                              \
    const auto& the_layout = LAYOUT;                   \
    switch (the_layout) {                              \
      case kSparseCsr:                                 \
      case kSparseCsc:                                 \
        return (ACTION)();                             \
      default:                                         \
        AT_ERROR(                                      \
            NAME,                                      \
            " expected sparse compressed (non-block) tensor layout but got ", \
            the_layout);                               \
    }                                                  \
  }()
// Runs ACTION only for the block compressed layouts (BSR, BSC);
// errors (mentioning NAME) for non-block layouts and anything else.
#define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
  [&]() {                                              \
    const auto& the_layout = LAYOUT;                   \
    switch (the_layout) {                              \
      case kSparseBsr:                                 \
      case kSparseBsc:                                 \
        return (ACTION)();                             \
      default:                                         \
        AT_ERROR(                                      \
            NAME,                                      \
            " expected sparse compressed block tensor layout but got ", \
            the_layout);                               \
    }                                                  \
  }()
// Scalar-type dispatch for sparse values: all standard and complex
// dtypes plus four extras (ComplexHalf, Half, Bool, BFloat16),
// delegating to the generic AT_DISPATCH_SWITCH machinery.
#define AT_DISPATCH_SPARSE_VALUE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                   \
      TYPE,                                             \
      NAME,                                             \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4(      \
          kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__))
  122. namespace at {
  123. namespace sparse_csr {
// Sparse compressed tensors are plain at::Tensor handles; this alias only
// documents intent at call sites.
using SparseCsrTensor = Tensor;
  125. inline bool is_sparse_compressed(const Layout& layout) {
  126. switch (layout) {
  127. case kSparseCsr:
  128. case kSparseCsc:
  129. case kSparseBsr:
  130. case kSparseBsc:
  131. return true;
  132. default:;
  133. }
  134. return false;
  135. }
  136. inline bool is_sparse_compressed(const Tensor& self) {
  137. return is_sparse_compressed(self.layout());
  138. }
// Returns `self`'s impl downcast to SparseCsrTensorImpl. The dispatch
// macro is invoked with a no-op action purely for its side effect of
// erroring out when `self` does not have a sparse compressed layout,
// which is what makes the static_cast below safe to perform.
inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) {
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
      self.layout(), "get_sparse_csr_impl", [&] {});
  return static_cast<SparseCsrTensorImpl*>(self.unsafeGetTensorImpl());
}
  144. inline std::string layoutToString(
  145. Layout layout,
  146. bool upper = false,
  147. bool lower = false) {
  148. switch (layout) {
  149. case kSparseCsr:
  150. return (upper ? "CSR" : (lower ? "csr" : "Csr"));
  151. case kSparseCsc:
  152. return (upper ? "CSC" : (lower ? "csc" : "Csc"));
  153. case kSparseBsr:
  154. return (upper ? "BSR" : (lower ? "bsr" : "Bsr"));
  155. case kSparseBsc:
  156. return (upper ? "BSC" : (lower ? "bsc" : "Bsc"));
  157. default:
  158. TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
  159. return "";
  160. }
  161. }
// True for row-compressed layouts (CSR, BSR), false for column-compressed
// ones (CSC, BSC); any other layout errors inside the dispatch macro.
inline bool isCompressedRow(Layout layout) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      layout, "isCompressedRow", [&] { return true; }, [&] { return false; });
}
// True for column-compressed layouts (CSC, BSC), false for row-compressed
// ones (CSR, BSR); any other layout errors inside the dispatch macro.
inline bool isCompressedColumn(Layout layout) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      layout,
      "isCompressedColumn",
      [&] { return false; },
      [&] { return true; });
}
// Name of the compressed-indices member for the given layout:
// "crow_indices" for CSR/BSR, "ccol_indices" for CSC/BSC. Any other
// layout errors inside the dispatch macro.
inline std::string compressedIndicesName(Layout layout) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      layout,
      "compressedIndicesName",
      [&] { return "crow_indices"; },
      [&] { return "ccol_indices"; });
}
// Name of the plain (uncompressed) indices member for the given layout:
// "col_indices" for CSR/BSR, "row_indices" for CSC/BSC. Any other layout
// errors inside the dispatch macro.
inline std::string plainIndicesName(Layout layout) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      layout,
      "plainIndicesName",
      [&] { return "col_indices"; },
      [&] { return "row_indices"; });
}
  187. inline std::string compressedDimName(Layout layout) {
  188. switch (layout) {
  189. case kSparseCsr:
  190. return "row";
  191. case kSparseCsc:
  192. return "column";
  193. case kSparseBsr:
  194. return "row block";
  195. case kSparseBsc:
  196. return "column block";
  197. default:
  198. TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
  199. return "";
  200. }
  201. }
  202. inline std::string plainDimName(Layout layout) {
  203. switch (layout) {
  204. case kSparseCsr:
  205. return "column";
  206. case kSparseCsc:
  207. return "row";
  208. case kSparseBsr:
  209. return "column block";
  210. case kSparseBsc:
  211. return "row block";
  212. default:
  213. TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
  214. return "";
  215. }
  216. }
  217. inline int rowDimension(Layout layout, IntArrayRef size) {
  218. return size.size() - (isCompressedRow(layout) ? 2 : 1);
  219. }
  220. inline int columnDimension(Layout layout, IntArrayRef size) {
  221. return size.size() - (isCompressedColumn(layout) ? 2 : 1);
  222. }
  223. inline int compressedDimension(
  224. Layout layout,
  225. IntArrayRef size,
  226. size_t dense_ndim = 0) {
  227. return size.size() - dense_ndim - (isCompressedRow(layout) ? 2 : 1);
  228. }
  229. inline int plainDimension(
  230. Layout layout,
  231. IntArrayRef size,
  232. size_t dense_ndim = 0) {
  233. return size.size() - dense_ndim - (isCompressedRow(layout) ? 1 : 2);
  234. }
// Number of leading batch dimensions, derived from the dimensionality of
// the compressed indices: a non-batched tensor has 1-D crow/ccol indices,
// so batch count is indices.dim() - 1. Errors (via the dispatch macro)
// when `self` is not sparse compressed.
inline int64_t numBatchDimensions(Tensor const& self) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      self.layout(),
      "numBatchDimensions",
      [&self] { return self.crow_indices().dim() - 1; },
      [&self] { return self.ccol_indices().dim() - 1; });
}
// Returns the (compressed_indices, plain_indices) pair for `self`:
// (crow_indices, col_indices) for CSR/BSR and (ccol_indices, row_indices)
// for CSC/BSC. Errors (via the dispatch macro) for any other layout.
inline std::pair<Tensor, Tensor> getCompressedPlainIndices(Tensor const& self) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      self.layout(),
      "getCompressedPlainIndices",
      [&self] {
        return std::make_pair(self.crow_indices(), self.col_indices());
      },
      [&self] {
        return std::make_pair(self.ccol_indices(), self.row_indices());
      });
}
  253. inline Layout flip_compressed_layout(Layout layout) {
  254. switch (layout) {
  255. case kSparseCsr:
  256. return kSparseCsc;
  257. case kSparseCsc:
  258. return kSparseCsr;
  259. case kSparseBsr:
  260. return kSparseBsc;
  261. case kSparseBsc:
  262. return kSparseBsr;
  263. default:
  264. TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
  265. return kSparseCsr;
  266. }
  267. }
  268. inline DimVector getBlockSize(Tensor const& self) {
  269. int64_t n_batch = numBatchDimensions(self);
  270. return at::DimVector(self.values().sizes().slice(n_batch + 1, 2));
  271. }
  272. inline at::OptionalArray<at::SymInt> getSymIntBlockSize(Tensor const& self) {
  273. if (self.layout() == at::kSparseBsr || self.layout() == at::kSparseBsc) {
  274. int64_t n_batch = numBatchDimensions(self);
  275. return self.values().sym_sizes().slice(n_batch + 1, 2).vec();
  276. } else {
  277. return {};
  278. }
  279. }
  280. } // namespace sparse_csr
  281. } // namespace at