123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
- #pragma once
- #include <ATen/Tensor.h>
- #include <c10/core/TensorImpl.h>
- #include <c10/util/Exception.h>
- namespace at {
- // Struct implementing a sparse CSR tensor. It uses three 1-D tensors for
- // denoting the data: `crow_indices_`, `col_indices_` and `values_`.
- // The `crow_indices_` tensor is a integer tensor of shape `(size(0) + 1)`
- // that represents the compressed row indices of the CSR tensor. The
- // `col_indices_` tensor is an integer tensor of shape `(nnz())`
- // that explicitly stores the column indices of each value of the sparse
- // tensor. The `values_` tensor can be of any pytorch-supported data type
- // and has shape `(nnz())`.
- //
- // Since the main advantage of the CSR format over the COO format is speed of
- // computation, care must be taken to facilitate smooth interfacing of
- // these data structures with optimized libraries such as MKL and MAGMA.
- // Since the MKL interface for pytorch currently uses indexing with int32
- // type, it is important to make sure that the `crow_indices` and `col_indices`
- // are of type int32 when calling MKL routines such as SPMM or SPMV.
- //
- // If not calling MKL, it should be alright to use 64 bit integer tensors
- // for indexing.
- struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
- Tensor crow_indices_;
- Tensor col_indices_;
- Tensor values_;
- Layout layout_;
- public:
- explicit SparseCsrTensorImpl(
- at::DispatchKeySet,
- at::Device device,
- Layout layout,
- const caffe2::TypeMeta);
- void resize_(int64_t nnz, IntArrayRef size);
- void resize_and_clear_(int64_t sparse_dim, IntArrayRef size);
- void resize_as_sparse_compressed_tensor_(const Tensor& src);
- void set_member_tensors(
- const Tensor& crow_indices,
- const Tensor& col_indices,
- const Tensor& values,
- IntArrayRef size);
- const Tensor& compressed_indices() const {
- return crow_indices_;
- }
- const Tensor& plain_indices() const {
- return col_indices_;
- }
- const Tensor& values() const {
- return values_;
- }
- int nnz() {
- return col_indices_.size(-1);
- }
- inline int64_t batch_dim() const noexcept {
- return crow_indices_.dim() - 1;
- }
- inline int64_t sparse_dim() const noexcept {
- return 2;
- }
- inline int64_t dense_dim() const noexcept {
- return values_.dim() - batch_dim() - block_dim() - 1;
- }
- private:
- inline int64_t block_dim() const noexcept {
- return (layout_ == kSparseBsr || layout_ == kSparseBsc ? 2 : 0);
- }
- protected:
- IntArrayRef strides_custom() const override;
- SymIntArrayRef sym_strides_custom() const override;
- bool is_contiguous_custom(MemoryFormat) const override;
- public:
- void set_size(int64_t dim, int64_t new_size) override;
- void set_stride(int64_t dim, int64_t new_stride) override;
- void set_storage_offset(int64_t storage_offset) override;
- Layout layout_impl() const override {
- return layout_;
- }
- void set_layout(Layout layout) {
- switch (layout) {
- case kSparseCsr:
- case kSparseCsc:
- case kSparseBsr:
- case kSparseBsc:
- layout_ = layout;
- break;
- default:
- TORCH_CHECK(false, "unsupported layout ", layout);
- }
- }
- /**
- * Return a TensorImpl that is a shallow-copy of this TensorImpl.
- *
- * For usage of `version_counter` and `allow_tensor_metadata_change`,
- * see NOTE [ TensorImpl Shallow-Copying ].
- */
- c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
- const c10::VariableVersion& version_counter,
- bool allow_tensor_metadata_change) const override {
- auto impl = c10::make_intrusive<SparseCsrTensorImpl>(
- key_set(), device(), layout_impl(), dtype());
- copy_tensor_metadata(
- /*src_impl=*/this,
- /*dest_impl=*/impl.get(),
- /*version_counter=*/version_counter,
- /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
- impl->refresh_numel();
- return impl;
- }
- /**
- * Return a TensorImpl that is a shallow-copy of this TensorImpl.
- *
- * For usage of `version_counter` and `allow_tensor_metadata_change`,
- * see NOTE [ TensorImpl Shallow-Copying ].
- */
- c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
- c10::VariableVersion&& version_counter,
- bool allow_tensor_metadata_change) const override {
- auto impl = c10::make_intrusive<SparseCsrTensorImpl>(
- key_set(), device(), layout_impl(), dtype());
- copy_tensor_metadata(
- /*src_impl=*/this,
- /*dest_impl=*/impl.get(),
- /*version_counter=*/std::move(version_counter),
- /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
- impl->refresh_numel();
- return impl;
- }
- private:
- explicit SparseCsrTensorImpl(
- at::DispatchKeySet key_set,
- const caffe2::TypeMeta data_type,
- at::Tensor crow_indices,
- at::Tensor col_indices,
- at::Tensor values,
- at::Layout layout);
- const char* tensorimpl_type_name() const override;
- /**
- * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
- * storage_offset) from one TensorImpl to another TensorImpl.
- *
- * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
- * [ TensorImpl Shallow-Copying ].
- */
- static void copy_tensor_metadata(
- const SparseCsrTensorImpl* src_sparse_impl,
- SparseCsrTensorImpl* dest_sparse_impl,
- const c10::VariableVersion& version_counter,
- bool allow_tensor_metadata_change) {
- TensorImpl::copy_tensor_metadata(
- src_sparse_impl,
- dest_sparse_impl,
- version_counter,
- allow_tensor_metadata_change);
- // Sparse-specific fields
- dest_sparse_impl->crow_indices_ = src_sparse_impl->compressed_indices();
- dest_sparse_impl->col_indices_ = src_sparse_impl->plain_indices();
- dest_sparse_impl->values_ = src_sparse_impl->values();
- dest_sparse_impl->layout_ = src_sparse_impl->layout_impl();
- }
- };
- } // namespace at
|