SparseTensorUtils.h

#pragma once

#include <ATen/Parallel.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/core/Tensor.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif
namespace at {
namespace sparse {

// Just for documentary purposes
using SparseTensor = Tensor;
using SparseType = Type;

// This is an internal utility function for getting at the SparseTensorImpl,
// so that we can write sparse tensor specific accessors for special fields
// in SparseTensor. You should only use this for writing low level
// setters/getters for SparseTensorImpl fields; otherwise, you should use
// the low level setters/getters that were implemented using this.
//
// This may be called repeatedly, so make sure it's pretty cheap.
inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) {
  TORCH_INTERNAL_ASSERT(
      self.is_sparse(), "_internal_get_SparseTensorImpl: not a sparse tensor");
  return static_cast<SparseTensorImpl*>(self.unsafeGetTensorImpl());
}
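
// Hypothetical usage sketch (illustration only, not part of this header):
// the raw impl pointer exposes low-level fields such as nnz().
//
//   SparseTensor t = at::empty({4, 4}, at::TensorOptions().layout(at::kSparse));
//   SparseTensorImpl* impl = get_sparse_impl(t);
//   int64_t nnz = impl->nnz();  // 0 for a freshly created empty sparse tensor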

// Takes indices and values and directly puts them into the sparse tensor, no
// copy. This used to be called THSTensor_(_move)
inline void alias_into_sparse(
    const SparseTensor& self,
    const Tensor& indices,
    const Tensor& values) {
  get_sparse_impl(self)->set_indices_and_values_unsafe(indices, values);
}
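
// Hypothetical usage sketch (illustration only): because no copy is made,
// later in-place mutation of `indices` or `values` is visible through `self`.
//
//   Tensor indices = at::tensor({0, 1, 1, 0}, at::kLong).reshape({2, 2});
//   Tensor values = at::tensor({3.0, 4.0});
//   alias_into_sparse(self, indices, values);  // self now aliases both tensors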

// Takes indices and values and makes a (data) copy of them to put into the
// sparse indices/values. This used to be called THSTensor_(_set)
inline void copy_into_sparse(
    const SparseTensor& self,
    const Tensor& indices,
    const Tensor& values,
    bool non_blocking) {
  alias_into_sparse(
      self,
      indices.to(self._indices().options(), non_blocking, /*copy=*/true),
      values.to(self._values().options(), non_blocking, /*copy=*/true));
}
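
// Hypothetical usage sketch (illustration only): unlike alias_into_sparse,
// the caller's tensors can be mutated afterwards without affecting `self`,
// since `.to(..., /*copy=*/true)` always materializes fresh storage.
//
//   copy_into_sparse(self, indices, values, /*non_blocking=*/false);
//   indices.fill_(0);  // does not change self._indices()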

// TODO: put this into the public API
inline bool is_same_tensor(const Tensor& lhs, const Tensor& rhs) {
  return lhs.unsafeGetTensorImpl() == rhs.unsafeGetTensorImpl();
}

inline bool is_same_density(const SparseTensor& self, const SparseTensor& src) {
  return self.sparse_dim() == src.sparse_dim() &&
      self.dense_dim() == src.dense_dim();
}
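
// Illustration (assumed semantics): identity is checked on the TensorImpl
// pointer, so two handles to one impl compare equal without touching data;
// value equality is a separate, more expensive check.
//
//   Tensor a = at::zeros({2});
//   Tensor b = a;                              // shares the TensorImpl
//   bool same = is_same_tensor(a, b);          // true
//   bool copy = is_same_tensor(a, a.clone());  // false: distinct impls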

// Returns a new values tensor with the same dimensionality as 'values' but
// with a new number of non-zero elements.
// TODO: Expose this for real in ATen, some day?
// NB: Doesn't preserve data.
inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) {
  std::vector<int64_t> size = values.sizes().vec();
  size[0] = nnz;
  return at::empty(size, values.options());
}
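
// Illustration (assumed shapes): only dimension 0 (the nnz dimension) is
// resized; trailing dense dimensions are preserved.
//
//   Tensor values = at::rand({5, 3});                  // nnz = 5, dense dim of 3
//   Tensor grown = new_values_with_size_of(values, 8); // shape [8, 3], uninitialized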

// NOTE [ Flatten Sparse Indices ]
// This helper function flattens a sparse indices tensor (a Tensor) into a 1D
// indices tensor. E.g.,
//   input = [[2, 4, 0],
//            [3, 1, 10]]
//   full_size = [2, 12]
//   output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10]
//
// In other words, assuming that each column `indices[:, i]` is a valid index
// into a tensor `t` of shape `full_size`, this returns the corresponding
// indices into the flattened tensor
// `t.reshape( prod(full_size[:indices.size(0)]), -1 )`.
// If force_clone is true, the result is forced to be a clone of the input.
TORCH_API Tensor flatten_indices(
    const Tensor& indices,
    IntArrayRef full_size,
    bool force_clone = false);
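
// Hypothetical usage sketch (illustration only), following the example
// above: for sparse_dim == 2 the row-major offset is i0 * full_size[1] + i1.
//
//   Tensor indices = at::tensor({2, 4, 0, 3, 1, 10}, at::kLong).reshape({2, 3});
//   Tensor flat = flatten_indices(indices, {2, 12});
//   // flat is [27, 49, 10]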

// Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten
// Sparse Indices ], except this one allows partial flatten: only flatten on
// specified dims. Note that the flattened indices might be uncoalesced if
// dims_to_flatten.size() < sparse_dim. Also, if the input indices are
// already coalesced, the flattened indices will also be sorted.
//
// args:
//   indices: sparse tensor indices
//   sizes: sparse tensor sizes
//   dims_to_flatten: a list of dim indices to flatten
//
// Ex1:
//   indices = [[2, 4, 0],
//              [3, 1, 3]]
//   sizes = [2, 12]
//   dims_to_flatten = [0, 1]
//   new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3]
//
// Ex2:
//   dims_to_flatten = [1]
//   new_indices = [ 3, 1, 3 ] # uncoalesced
TORCH_API Tensor flatten_indices_by_dims(
    const Tensor& indices,
    const IntArrayRef& sizes,
    const IntArrayRef& dims_to_flatten);
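
// Hypothetical usage sketch (illustration only), matching Ex2 above:
//
//   Tensor indices = at::tensor({2, 4, 0, 3, 1, 3}, at::kLong).reshape({2, 3});
//   Tensor partial = flatten_indices_by_dims(indices, {2, 12}, {1});
//   // partial is [3, 1, 3] (possibly uncoalesced)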

// Find the CSR representation for a row `indices` from the COO format
TORCH_API Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz);
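
// Illustration (assumed CSR layout): the result is a compressed row-pointer
// tensor of length dim + 1, where result[r + 1] - result[r] counts the
// nonzeros in row r of the sorted COO row indices.
//
//   int64_t rows[] = {0, 0, 1, 3};  // coalesced row indices, nnz = 4
//   Tensor crow = coo_to_csr(rows, /*dim=*/4, /*nnz=*/4);
//   // crow is [0, 2, 3, 3, 4]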

} // namespace sparse
} // namespace at