// SparseTensorImpl.h (14 KB)
  1. #pragma once
  2. #include <ATen/Tensor.h>
  3. #include <c10/core/TensorImpl.h>
  4. #include <c10/util/Exception.h>
  5. #include <c10/util/irange.h>
  6. #ifndef AT_PER_OPERATOR_HEADERS
  7. #include <ATen/Functions.h>
  8. #else
  9. #include <ATen/ops/empty.h>
  10. #include <ATen/ops/resize.h>
  11. #endif
  12. namespace at {
  13. struct TORCH_API SparseTensorImpl : public TensorImpl {
  14. // Stored in COO format, indices + values.
  15. // INVARIANTS:
  16. // sparse_dim: range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
  17. // dense_dim : range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
  18. // _indices.shape: dimensionality: 2, shape: (sparse_dim, nnz)
  19. // _values.shape: dimensionality: 1 + dense_dim. shape: (nnz,
  20. // shape[sparse_dim:])
  21. int64_t sparse_dim_ = 0; // number of sparse dimensions
  22. int64_t dense_dim_ = 0; // number of dense dimensions
  23. Tensor indices_; // always a LongTensor
  24. Tensor values_;
  25. // A sparse tensor is 'coalesced' if every index occurs at most once in
  26. // the indices tensor, and the indices are in sorted order. (This means
  27. // that it is very easy to convert a coalesced tensor to CSR format: you
  28. // need only compute CSR format indices.)
  29. //
  30. // Most math operations can only be performed on coalesced sparse tensors,
  31. // because many algorithms proceed by merging two sorted lists (of indices).
  32. bool coalesced_ = false;
  33. // compute_numel with integer multiplication overflow check, see gh-57542
  34. void refresh_numel() {
  35. TensorImpl::safe_refresh_numel();
  36. }
  37. public:
  38. // Public for now...
  39. explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta);
  40. void release_resources() override;
  41. int64_t nnz() const {
  42. return values_.size(0);
  43. }
  44. c10::SymInt sym_nnz() const {
  45. return values_.sym_size(0);
  46. }
  47. int64_t sparse_dim() const {
  48. return sparse_dim_;
  49. }
  50. int64_t dense_dim() const {
  51. return dense_dim_;
  52. }
  53. bool coalesced() const {
  54. return coalesced_;
  55. }
  56. Tensor indices() const {
  57. return indices_;
  58. }
  59. Tensor values() const {
  60. return values_;
  61. }
  62. void set_size(int64_t dim, int64_t new_size) override;
  63. void set_stride(int64_t dim, int64_t new_stride) override;
  64. void set_storage_offset(int64_t storage_offset) override;
  65. #ifdef DEBUG
  66. bool has_storage() const override;
  67. #endif
  68. // WARNING: This function does NOT preserve invariants of sparse_dim/dense_dim
  69. // with respect to indices and values
  70. void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
  71. TORCH_CHECK(
  72. allow_tensor_metadata_change(),
  73. "raw_resize_ ",
  74. err_msg_tensor_metadata_change_not_allowed);
  75. TORCH_CHECK(
  76. !has_symbolic_sizes_strides_,
  77. "raw_resize_ called on tensor with symbolic shape")
  78. set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
  79. sparse_dim_ = sparse_dim;
  80. dense_dim_ = dense_dim;
  81. refresh_numel();
  82. }
  83. // NOTE: This function preserves invariants of sparse_dim/dense_dim with
  84. // respect to indices and values.
  85. //
  86. // NOTE: This function supports the following cases:
  87. // 1. When we keep the number of dense dimensions unchanged, and NOT shrinking
  88. // the size of any of the dense dimensions.
  89. // 2. When we keep the number of sparse dimensions unchanged, and NOT
  90. // shrinking the size of any of the sparse dimensions.
  91. // 3. When the sparse tensor has zero nnz, in which case we are free to change
  92. // the shapes of both its sparse and dense dimensions.
  93. //
  94. // This function DOESN'T support (and will throw an error) the following
  95. // cases:
  96. // 1. When we attempt to change the number of sparse dimensions on a non-empty
  97. // sparse tensor (such an operation will invalidate the indices stored).
  98. // 2. When we attempt to change the number of dense dimensions on a non-empty
  99. // sparse tensor (such an operation will behave differently from an equivalent
  100. // dense tensor's resize method, and for API consistency we don't support it).
  101. // 3. When we attempt to shrink the size of any of the dense dimensions on a
  102. // non-empty sparse tensor (such an operation will behave differently from an
  103. // equivalent dense tensor's resize method, and for API consistency we don't
  104. // support it).
  105. // 4. When we attempt to shrink the size of any of the sparse dimensions on a
  106. // non-empty sparse tensor (this could make some of the stored indices
  107. // out-of-bound and thus unsafe).
  108. template <typename T>
  109. void _resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<T> size) {
  110. TORCH_CHECK(
  111. allow_tensor_metadata_change(),
  112. "resize_ ",
  113. err_msg_tensor_metadata_change_not_allowed);
  114. TORCH_CHECK(
  115. !has_symbolic_sizes_strides_,
  116. "resize_ called on tensor with symbolic shape")
  117. TORCH_CHECK(
  118. sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
  119. "number of dimensions must be sparse_dim (",
  120. sparse_dim,
  121. ") + dense_dim (",
  122. dense_dim,
  123. "), but got ",
  124. size.size());
  125. if (nnz() > 0) {
  126. auto alt_options_msg =
  127. "You could try the following options:\n\
  128. 1. If you need an empty sparse tensor of this size, call `x = torch.sparse_coo_tensor(size)`.\n\
  129. 2. If you need to resize this tensor, you have the following options:\n\
  130. 1. For both sparse and dense dimensions, keep the number of them constant and the size of them non-shrinking, and then try the same call again.\n\
  131. 2. Or, create a new sparse tensor with the correct indices and values from this sparse tensor.";
  132. TORCH_CHECK(
  133. sparse_dim == sparse_dim_,
  134. "changing the number of sparse dimensions (from ",
  135. sparse_dim_,
  136. " to ",
  137. sparse_dim,
  138. ") on a non-empty sparse tensor is not supported.\n",
  139. alt_options_msg);
  140. TORCH_CHECK(
  141. dense_dim == dense_dim_,
  142. "changing the number of dense dimensions (from ",
  143. dense_dim_,
  144. " to ",
  145. dense_dim,
  146. ") on a non-empty sparse tensor is not supported.\n",
  147. alt_options_msg);
  148. bool shrinking_sparse_dims = false;
  149. bool shrinking_dense_dim = false;
  150. auto sparse_size_original = generic_sizes<T>().slice(0, sparse_dim);
  151. auto sparse_size_new = size.slice(0, sparse_dim);
  152. for (const auto i : c10::irange(sparse_dim)) {
  153. if (sparse_size_new[i] < sparse_size_original[i]) {
  154. shrinking_sparse_dims = true;
  155. break;
  156. }
  157. }
  158. auto dense_size_original = generic_sizes<T>().slice(sparse_dim);
  159. auto dense_size_new = size.slice(sparse_dim);
  160. for (const auto i : c10::irange(dense_dim)) {
  161. if (dense_size_new[i] < dense_size_original[i]) {
  162. shrinking_dense_dim = true;
  163. break;
  164. }
  165. }
  166. TORCH_CHECK(
  167. !shrinking_sparse_dims,
  168. "shrinking the size of sparse dimensions (from ",
  169. sparse_size_original,
  170. " to ",
  171. sparse_size_new,
  172. ") on a non-empty sparse tensor is not supported.\n",
  173. alt_options_msg);
  174. TORCH_CHECK(
  175. !shrinking_dense_dim,
  176. "shrinking the size of dense dimensions (from ",
  177. dense_size_original,
  178. " to ",
  179. dense_size_new,
  180. ") on a non-empty sparse tensor is not supported.\n",
  181. alt_options_msg);
  182. }
  183. auto sizes_and_strides = generic_sizes<T>();
  184. const bool size_equals_sizes = std::equal(
  185. size.begin(),
  186. size.end(),
  187. sizes_and_strides.begin(),
  188. sizes_and_strides.end());
  189. if ((!size_equals_sizes) || (sparse_dim != sparse_dim_) ||
  190. (dense_dim != dense_dim_)) {
  191. auto nnz = at::symint::sizes<T>(values())[0];
  192. std::vector<T> values_size = {nnz};
  193. auto dense_size = size.slice(sparse_dim);
  194. values_size.insert(
  195. values_size.end(), dense_size.begin(), dense_size.end());
  196. at::symint::resize_<T>(values_, values_size);
  197. at::symint::resize_<T>(indices_, {T(sparse_dim), nnz});
  198. }
  199. if (!size_equals_sizes) {
  200. set_sizes_and_strides(size, std::vector<T>(size.size()));
  201. }
  202. sparse_dim_ = sparse_dim;
  203. dense_dim_ = dense_dim;
  204. refresh_numel();
  205. }
  206. void resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<int64_t> size) {
  207. return _resize_(sparse_dim, dense_dim, size);
  208. }
  209. void resize_(
  210. int64_t sparse_dim,
  211. int64_t dense_dim,
  212. ArrayRef<c10::SymInt> size) {
  213. return _resize_(sparse_dim, dense_dim, size);
  214. }
  215. // NOTE: this function will resize the sparse tensor and also set `indices`
  216. // and `values` to empty.
  217. void resize_and_clear_(
  218. int64_t sparse_dim,
  219. int64_t dense_dim,
  220. IntArrayRef size) {
  221. TORCH_CHECK(
  222. allow_tensor_metadata_change(),
  223. "resize_and_clear_ ",
  224. err_msg_tensor_metadata_change_not_allowed);
  225. TORCH_CHECK(
  226. !has_symbolic_sizes_strides_,
  227. "resize_and_clear_ called on tensor with symbolic shape")
  228. TORCH_CHECK(
  229. sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
  230. "number of dimensions must be sparse_dim (",
  231. sparse_dim,
  232. ") + dense_dim (",
  233. dense_dim,
  234. "), but got ",
  235. size.size());
  236. set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
  237. sparse_dim_ = sparse_dim;
  238. dense_dim_ = dense_dim;
  239. auto empty_indices = at::empty({sparse_dim, 0}, indices().options());
  240. std::vector<int64_t> values_size = {0};
  241. auto dense_size = sizes().slice(sparse_dim);
  242. values_size.insert(values_size.end(), dense_size.begin(), dense_size.end());
  243. auto empty_values = at::empty(values_size, values().options());
  244. set_indices_and_values_unsafe(empty_indices, empty_values);
  245. refresh_numel();
  246. }
  247. void set_coalesced(bool coalesced) {
  248. TORCH_CHECK(
  249. allow_tensor_metadata_change(),
  250. "set_coalesced ",
  251. err_msg_tensor_metadata_change_not_allowed);
  252. coalesced_ = coalesced;
  253. }
  254. // NOTE: this function is only used internally and not exposed to Python
  255. // frontend
  256. void set_nnz_and_narrow(int64_t new_nnz) {
  257. TORCH_CHECK(
  258. allow_tensor_metadata_change(),
  259. "set_nnz_and_narrow ",
  260. err_msg_tensor_metadata_change_not_allowed);
  261. AT_ASSERT(new_nnz <= nnz());
  262. indices_ = indices_.narrow(1, 0, new_nnz);
  263. values_ = values_.narrow(0, 0, new_nnz);
  264. if (new_nnz < 2) {
  265. coalesced_ = true;
  266. }
  267. }
  268. // Takes indices and values and directly puts them into the sparse tensor, no
  269. // copy. NOTE: this function is unsafe because it doesn't check whether any
  270. // indices are out of boundaries of `sizes`, so it should ONLY be used where
  271. // we know that the indices are guaranteed to be within bounds. This used to
  272. // be called THSTensor_(_move) NB: This used to be able to avoid a refcount
  273. // bump, but I was too lazy to make it happen
  274. void set_indices_and_values_unsafe(
  275. const Tensor& indices,
  276. const Tensor& values);
  277. /**
  278. * Return a TensorImpl that is a shallow-copy of this TensorImpl.
  279. *
  280. * For usage of `version_counter` and `allow_tensor_metadata_change`,
  281. * see NOTE [ TensorImpl Shallow-Copying ].
  282. */
  283. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  284. const c10::VariableVersion& version_counter,
  285. bool allow_tensor_metadata_change) const override {
  286. auto impl = c10::make_intrusive<SparseTensorImpl>(key_set(), dtype());
  287. copy_tensor_metadata(
  288. /*src_impl=*/this,
  289. /*dest_impl=*/impl.get(),
  290. /*version_counter=*/version_counter,
  291. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  292. impl->refresh_numel();
  293. return impl;
  294. }
  295. /**
  296. * Return a TensorImpl that is a shallow-copy of this TensorImpl.
  297. *
  298. * For usage of `version_counter` and `allow_tensor_metadata_change`,
  299. * see NOTE [ TensorImpl Shallow-Copying ].
  300. */
  301. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  302. c10::VariableVersion&& version_counter,
  303. bool allow_tensor_metadata_change) const override {
  304. auto impl = c10::make_intrusive<SparseTensorImpl>(key_set(), dtype());
  305. copy_tensor_metadata(
  306. /*src_impl=*/this,
  307. /*dest_impl=*/impl.get(),
  308. /*version_counter=*/std::move(version_counter),
  309. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  310. impl->refresh_numel();
  311. return impl;
  312. }
  313. /**
  314. * Shallow-copies data from another TensorImpl into this TensorImpl.
  315. *
  316. * For why this function doesn't check this TensorImpl's
  317. * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
  318. */
  319. void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
  320. AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
  321. auto sparse_impl = static_cast<const SparseTensorImpl*>(impl.get());
  322. copy_tensor_metadata(
  323. /*src_impl=*/sparse_impl,
  324. /*dest_impl=*/this,
  325. /*version_counter=*/version_counter(),
  326. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
  327. refresh_numel();
  328. }
  329. private:
  330. explicit SparseTensorImpl(
  331. at::DispatchKeySet,
  332. const caffe2::TypeMeta,
  333. at::Tensor indices,
  334. at::Tensor values);
  335. /**
  336. * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
  337. * storage_offset) from one TensorImpl to another TensorImpl.
  338. *
  339. * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
  340. * [ TensorImpl Shallow-Copying ].
  341. */
  342. static void copy_tensor_metadata(
  343. const SparseTensorImpl* src_sparse_impl,
  344. SparseTensorImpl* dest_sparse_impl,
  345. const c10::VariableVersion& version_counter,
  346. bool allow_tensor_metadata_change) {
  347. TensorImpl::copy_tensor_metadata(
  348. src_sparse_impl,
  349. dest_sparse_impl,
  350. version_counter,
  351. allow_tensor_metadata_change);
  352. // Sparse-specific fields
  353. dest_sparse_impl->sparse_dim_ = src_sparse_impl->sparse_dim();
  354. dest_sparse_impl->dense_dim_ = src_sparse_impl->dense_dim();
  355. dest_sparse_impl->indices_ = src_sparse_impl->indices();
  356. dest_sparse_impl->values_ = src_sparse_impl->values();
  357. dest_sparse_impl->coalesced_ = src_sparse_impl->coalesced();
  358. }
  359. const char* tensorimpl_type_name() const override;
  360. };
  361. } // namespace at