NestedTensorImpl.h 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. #pragma once
  2. #include <ATen/MemoryOverlap.h>
  3. #include <ATen/Tensor.h>
  4. #include <c10/core/DispatchKey.h>
  5. #include <c10/core/DispatchKeySet.h>
  6. #include <c10/core/MemoryFormat.h>
  7. #include <c10/core/TensorImpl.h>
  8. #include <c10/util/ArrayRef.h>
  9. #include <c10/util/Exception.h>
  10. #include <c10/util/Metaprogramming.h>
  11. #include <c10/util/irange.h>
  12. namespace at {
  13. namespace native {
  14. struct NestedTensorImpl;
  15. inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt);
  16. struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
  17. explicit NestedTensorImpl(
  18. Storage storage,
  19. c10::DispatchKeySet key_set,
  20. const caffe2::TypeMeta data_type,
  21. at::Tensor nested_size_tensor,
  22. at::Tensor nested_stride_tensor,
  23. std::vector<int64_t>&& offsets);
  24. explicit NestedTensorImpl(
  25. at::Tensor buffer,
  26. at::Tensor nested_size_tensor,
  27. at::Tensor nested_stride_tensor,
  28. std::vector<int64_t>&& offsets);
  29. // assume contiguous, `nested_stride_tensor` and `offsets`
  30. // can be infered from `nested_size_tensor`
  31. explicit NestedTensorImpl(at::Tensor buffer, at::Tensor nested_size_tensor);
  32. // This constructor is used creating view tensors from nested tensors
  33. explicit NestedTensorImpl(
  34. c10::TensorImpl::ImplType impl_type,
  35. const at::Tensor& base_tensor,
  36. at::Tensor nested_size_tensor,
  37. at::Tensor nested_stride_tensor,
  38. std::vector<int64_t>&& offsets);
  39. // TODO: don't expose private implementation details like this; in
  40. // particular, resizing this tensor will mess up our dim() and
  41. // callers cannot fix it.
  42. const Tensor& get_nested_size_tensor() const {
  43. return nested_size_tensor_;
  44. }
  45. // TODO: don't expose private implementation details like this
  46. const Tensor& get_nested_stride_tensor() const {
  47. return nested_stride_tensor_;
  48. }
  49. const std::vector<int64_t>& get_storage_offsets() const {
  50. return storage_offsets_;
  51. }
  52. // Returns nullopt if the ith dimension is irregular. The ith dimension
  53. // of a NestedTensor is regular if the unbound tensors match in
  54. // size at the (i-1)th dimension.
  55. c10::optional<int64_t> opt_size(int64_t d) const {
  56. d = at::maybe_wrap_dim(d, dim(), false);
  57. if (opt_sizes_[d] == -1) {
  58. return c10::nullopt;
  59. }
  60. return opt_sizes_[d];
  61. }
  62. int64_t size(int64_t d) const {
  63. c10::optional<int64_t> optional_size = this->opt_size(d);
  64. TORCH_CHECK(
  65. optional_size.has_value(),
  66. "Given dimension ",
  67. d,
  68. " is irregular and does not have a size.");
  69. return *optional_size;
  70. }
  71. /**
  72. * Return a view of the nested tensor as a 1 dimensional contiguous tensor.
  73. *
  74. * The buffer tensor created by this function shares the same storage_impl as
  75. * the original nested tensor, and therefore can be seen as a view.
  76. *
  77. * @return A newly constructed view tensor
  78. */
  79. at::Tensor get_buffer() const {
  80. TORCH_CHECK(
  81. nested_tensor_impl_is_contiguous(this),
  82. "NestedTensor must be contiguous to get buffer.");
  83. return get_unsafe_storage_as_tensor();
  84. }
  85. /**
  86. * If possible use get_buffer() instead. This function returns the storage
  87. * as a tensor directly, which is not safe to use in general. If using this
  88. * function, The caller must ensure to account for nested_sizes,
  89. * nested_strides and storage_offsets.
  90. *
  91. * @return A newly constructed view tensor
  92. */
  93. at::Tensor get_unsafe_storage_as_tensor() const {
  94. auto buffer_key_set_ = generate_buffer_key_set();
  95. const auto buffer_size = get_buffer_size();
  96. auto buffer_tensor_impl = c10::make_intrusive<TensorImpl>(
  97. c10::TensorImpl::VIEW, Storage(storage_), buffer_key_set_, data_type_);
  98. buffer_tensor_impl->set_sizes_contiguous(c10::makeArrayRef(buffer_size));
  99. return Tensor(buffer_tensor_impl);
  100. }
  101. int64_t get_buffer_size() const {
  102. return storage_.nbytes() / data_type_.itemsize();
  103. }
  104. protected:
  105. const char* tensorimpl_type_name() const override;
  106. // TODO: numel_custom and is_contiguous_custom can be profitably overridden
  107. // with real implementations
  108. int64_t numel_custom() const override;
  109. c10::SymInt sym_numel_custom() const override;
  110. bool is_contiguous_custom(MemoryFormat) const override;
  111. int64_t size_custom(int64_t d) const override {
  112. return this->size(d);
  113. }
  114. c10::SymInt sym_size_custom(int64_t d) const override {
  115. return c10::SymInt{this->size(d)};
  116. }
  117. IntArrayRef sizes_custom() const override;
  118. c10::SymIntArrayRef sym_sizes_custom() const override;
  119. IntArrayRef strides_custom() const override;
  120. c10::SymIntArrayRef sym_strides_custom() const override;
  121. // this one is real
  122. int64_t dim_custom() const override;
  123. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  124. const c10::VariableVersion& version_counter,
  125. bool allow_tensor_metadata_change) const override;
  126. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  127. c10::VariableVersion&& version_counter,
  128. bool allow_tensor_metadata_change) const override;
  129. void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
  130. copy_tensor_metadata(
  131. /*src_impl=*/impl.get(),
  132. /*dest_impl=*/this,
  133. /*version_counter=*/version_counter(),
  134. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
  135. }
  136. private:
  137. // Must be called after any changes to our dim() to sync the state
  138. // to TensorImpl.
  139. void refresh_dim();
  140. const at::Tensor nested_size_tensor_, nested_stride_tensor_;
  141. // The starting positions of the underlying tensors in contiguous buffer
  142. // i.e. the buffer memory offsets to get the underlying tensors
  143. // The reason to keep this metadata is that, without strong enough constraint
  144. // it cannot be derived from `nested_size_tensor_`
  145. // and `nested_stride_tensor_`:
  146. // 1. when buffer has blanks, e.g. [tensor1, blank, tensor2]
  147. // this can happen e.g. after slicing a nested tensor
  148. // 2. when multiple tensors share a same memory
  149. // 3. when the nesting ordering is changed, e.g. [tensor1, tensor3, tensor2]
  150. // Some strong enough constraints are:
  151. // 1. every underlying tensor is contiguous in memory
  152. // && nesting in ascending order
  153. std::vector<int64_t> storage_offsets_;
  154. // NOTE: -1 here means the size is missing
  155. // TODO: maybe we can remove this metadata since
  156. // we can compute it from `nested_size_tensor_`
  157. std::vector<int64_t> opt_sizes_;
  158. template <typename VariableVersion>
  159. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
  160. VariableVersion&& version_counter,
  161. bool allow_tensor_metadata_change) const;
  162. /**
  163. * Generates a non-nested key_set from a nested tensor.
  164. *
  165. * For many nested tensor kernel implementations a buffer tensor
  166. * is generated and redispatched to a non-nested kernel this function
  167. * generates the key set used by that buffer tensor
  168. *
  169. * @return Appropriate key set for non-nested tensor
  170. */
  171. inline c10::DispatchKeySet generate_buffer_key_set() const {
  172. auto buffer_key_set = this->key_set();
  173. const bool Autograd = buffer_key_set.has_any(c10::autograd_dispatch_keyset);
  174. // Remove nested tensor specific keys
  175. buffer_key_set = buffer_key_set -
  176. c10::DispatchKeySet{
  177. c10::DispatchKey::NestedTensor,
  178. c10::DispatchKey::AutogradNestedTensor};
  179. // Add dense tensor specific keys
  180. buffer_key_set =
  181. buffer_key_set | c10::DispatchKeySet{c10::DispatchKey::Dense};
  182. buffer_key_set = Autograd
  183. ? c10::DispatchKeySet{c10::DispatchKey::Autograd} | buffer_key_set
  184. : buffer_key_set;
  185. return buffer_key_set;
  186. }
  187. };
  188. inline NestedTensorImpl* get_nested_tensor_impl_or_null(
  189. const at::Tensor& tensor) {
  190. if (tensor.is_nested()) {
  191. return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
  192. }
  193. return nullptr;
  194. }
  195. inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) {
  196. TORCH_CHECK(
  197. tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor.");
  198. return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
  199. }
  200. inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) {
  201. int64_t ntensors = nt->size(0);
  202. if (ntensors == 0) {
  203. return true;
  204. }
  205. const Tensor &sizemat = nt->get_nested_size_tensor(),
  206. &stridemat = nt->get_nested_stride_tensor();
  207. const auto& offsets = nt->get_storage_offsets();
  208. int64_t orig_dim = sizemat.size(1);
  209. // nesting scalars
  210. if (orig_dim == 0) {
  211. // each scalar must be contiguous
  212. // if there is blanck memory between underlying scalars
  213. for (int64_t i = 0; i < ntensors; i++) {
  214. if (offsets[i] != i) {
  215. return false;
  216. }
  217. }
  218. }
  219. // nesting tensors
  220. else {
  221. // if any underlying tensor is noncontiguous
  222. const int64_t *sizemat_ptr = sizemat.data_ptr<int64_t>(),
  223. *stridemat_ptr = stridemat.data_ptr<int64_t>();
  224. for (int64_t i = 0; i < ntensors; i++) {
  225. if (stridemat_ptr[orig_dim - 1] != 1) {
  226. return false;
  227. }
  228. int64_t product = sizemat_ptr[orig_dim - 1];
  229. for (int64_t j = orig_dim - 2; j >= 0; j--) {
  230. if (stridemat_ptr[j] != product) {
  231. return false;
  232. }
  233. product *= sizemat_ptr[j];
  234. }
  235. sizemat_ptr += orig_dim;
  236. stridemat_ptr += orig_dim;
  237. }
  238. // if there is blanck memory between underlying tensors
  239. if (offsets[0] != 0) {
  240. return false;
  241. }
  242. sizemat_ptr = sizemat.data_ptr<int64_t>();
  243. stridemat_ptr = stridemat.data_ptr<int64_t>();
  244. for (int64_t i = 1; i < ntensors; i++) {
  245. if (offsets[i] != offsets[i - 1] + *sizemat_ptr * *stridemat_ptr) {
  246. return false;
  247. }
  248. sizemat_ptr += orig_dim;
  249. stridemat_ptr += orig_dim;
  250. }
  251. }
  252. // everything is fine
  253. return true;
  254. }
  255. inline const at::Tensor& get_nested_size_tensor(const at::Tensor& tensor) {
  256. return get_nested_tensor_impl(tensor)->get_nested_size_tensor();
  257. }
  258. } // namespace native
  259. } // namespace at