SparseCsrTensorImpl.h

#pragma once

#include <ATen/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>

namespace at {

// Struct implementing a sparse CSR tensor. It uses three 1-D tensors for
// denoting the data: `crow_indices_`, `col_indices_` and `values_`.
// The `crow_indices_` tensor is an integer tensor of shape `(size(0) + 1)`
// that represents the compressed row indices of the CSR tensor. The
// `col_indices_` tensor is an integer tensor of shape `(nnz())`
// that explicitly stores the column indices of each value of the sparse
// tensor. The `values_` tensor can be of any pytorch-supported data type
// and has shape `(nnz())`.
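//
// For example, the 3 x 4 matrix
//
//   [[0, 0, 1, 0],
//    [2, 0, 0, 3],
//    [0, 4, 0, 0]]
//
// is stored as
//
//   crow_indices_ = [0, 1, 3, 4]
//   col_indices_  = [2, 0, 3, 1]
//   values_       = [1, 2, 3, 4]
//
// where crow_indices_[i + 1] - crow_indices_[i] is the number of nonzero
// entries in row i, and the last entry of crow_indices_ equals nnz() == 4.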
//
// Since the main advantage of the CSR format over the COO format is speed of
// computation, care must be taken to facilitate smooth interfacing of
// these data structures with optimized libraries such as MKL and MAGMA.
// Since the MKL interface for pytorch currently uses indexing with int32
// type, it is important to make sure that the `crow_indices` and `col_indices`
// are of type int32 when calling MKL routines such as SPMM or SPMV.
//
// If not calling MKL, it should be alright to use 64 bit integer tensors
// for indexing.
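//
// As a rough sketch (not a prescribed API usage), a caller could downcast the
// index tensors of a CSR tensor `csr` before handing them to an MKL-backed
// routine:
//
//   auto crow_int32 = csr.crow_indices().to(at::kInt);
//   auto col_int32 = csr.col_indices().to(at::kInt);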
struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
  Tensor crow_indices_;
  Tensor col_indices_;
  Tensor values_;
  Layout layout_;

 public:
  explicit SparseCsrTensorImpl(
      at::DispatchKeySet,
      at::Device device,
      Layout layout,
      const caffe2::TypeMeta);

  void resize_(int64_t nnz, IntArrayRef size);
  void resize_and_clear_(int64_t sparse_dim, IntArrayRef size);
  void resize_as_sparse_compressed_tensor_(const Tensor& src);
  void set_member_tensors(
      const Tensor& crow_indices,
      const Tensor& col_indices,
      const Tensor& values,
      IntArrayRef size);
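
  // Rough construction sketch (hypothetical call site; the real factories in
  // ATen's native sparse code perform additional validation). `key_set` is a
  // placeholder for an appropriate sparse DispatchKeySet:
  //
  //   auto impl = c10::make_intrusive<SparseCsrTensorImpl>(
  //       key_set, values.device(), kSparseCsr, values.dtype());
  //   impl->set_member_tensors(crow_indices, col_indices, values, size);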

  const Tensor& compressed_indices() const {
    return crow_indices_;
  }
  const Tensor& plain_indices() const {
    return col_indices_;
  }
  const Tensor& values() const {
    return values_;
  }
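  // Note: the member names follow the CSR convention, but this impl backs all
  // four compressed layouts (see set_layout below). For CSC/BSC the compressed
  // axis is the column axis, so `crow_indices_` then holds the compressed
  // column indices and `col_indices_` holds row indices, which is why the
  // accessors above use the layout-agnostic names compressed/plain.
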
  int64_t nnz() const {
    return col_indices_.size(-1);
  }

  inline int64_t batch_dim() const noexcept {
    return crow_indices_.dim() - 1;
  }

  inline int64_t sparse_dim() const noexcept {
    return 2;
  }

  inline int64_t dense_dim() const noexcept {
    return values_.dim() - batch_dim() - block_dim() - 1;
  }
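
  // Shape bookkeeping, for reference: `values_` has shape
  // (*batch_shape, nnz, *block_shape, *dense_shape). For example, a batched
  // BSR tensor whose values have shape (B, nnz, 2, 3, D) has batch_dim() == 1,
  // block_dim() == 2 and dense_dim() == 5 - 1 - 2 - 1 == 1.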

 private:
  inline int64_t block_dim() const noexcept {
    return (layout_ == kSparseBsr || layout_ == kSparseBsc ? 2 : 0);
  }

 protected:
  IntArrayRef strides_custom() const override;
  SymIntArrayRef sym_strides_custom() const override;
  bool is_contiguous_custom(MemoryFormat) const override;

 public:
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;

  Layout layout_impl() const override {
    return layout_;
  }

  void set_layout(Layout layout) {
    switch (layout) {
      case kSparseCsr:
      case kSparseCsc:
      case kSparseBsr:
      case kSparseBsc:
        layout_ = layout;
        break;
      default:
        TORCH_CHECK(false, "unsupported layout ", layout);
    }
  }

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override {
    auto impl = c10::make_intrusive<SparseCsrTensorImpl>(
        key_set(), device(), layout_impl(), dtype());
    copy_tensor_metadata(
        /*src_impl=*/this,
        /*dest_impl=*/impl.get(),
        /*version_counter=*/version_counter,
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    impl->refresh_numel();
    return impl;
  }

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override {
    auto impl = c10::make_intrusive<SparseCsrTensorImpl>(
        key_set(), device(), layout_impl(), dtype());
    copy_tensor_metadata(
        /*src_impl=*/this,
        /*dest_impl=*/impl.get(),
        /*version_counter=*/std::move(version_counter),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    impl->refresh_numel();
    return impl;
  }

 private:
  explicit SparseCsrTensorImpl(
      at::DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      at::Tensor crow_indices,
      at::Tensor col_indices,
      at::Tensor values,
      at::Layout layout);

  const char* tensorimpl_type_name() const override;

  /**
   * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
   * storage_offset) from one TensorImpl to another TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
   * [ TensorImpl Shallow-Copying ].
   */
  static void copy_tensor_metadata(
      const SparseCsrTensorImpl* src_sparse_impl,
      SparseCsrTensorImpl* dest_sparse_impl,
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) {
    TensorImpl::copy_tensor_metadata(
        src_sparse_impl,
        dest_sparse_impl,
        version_counter,
        allow_tensor_metadata_change);

    // Sparse-specific fields
    dest_sparse_impl->crow_indices_ = src_sparse_impl->compressed_indices();
    dest_sparse_impl->col_indices_ = src_sparse_impl->plain_indices();
    dest_sparse_impl->values_ = src_sparse_impl->values();
    dest_sparse_impl->layout_ = src_sparse_impl->layout_impl();
  }
};
} // namespace at