// OpaqueTensorImpl.h

#pragma once

#include <c10/core/MemoryFormat.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>

#include <utility>
  6. namespace at {
  7. // An "Opaque" TensorImpl -- there are no strides and (for now)
  8. // even data() is not supported (thus no pointer arithmetic).
  9. // NOTE: We could allow data() in the future, but would have to ensure pointer
  10. // arithmetic code is properly guarded.
  11. //
  12. // NOTE: This does not support resize_ (and other metadata-changing ops) because
  13. // of `shallow_copy_and_detach`. We would need to define an interface to
  14. // "shallow copy" in order to add support.
  15. template <typename OpaqueHandle>
  16. struct TORCH_API OpaqueTensorImpl : public TensorImpl {
  17. // public constructor for now...
  18. OpaqueTensorImpl(
  19. at::DispatchKeySet key_set,
  20. const caffe2::TypeMeta data_type,
  21. c10::Device device,
  22. OpaqueHandle opaque_handle,
  23. c10::IntArrayRef sizes,
  24. bool is_non_overlapping_and_dense = true)
  25. : TensorImpl(key_set, data_type, device),
  26. opaque_handle_(std::move(opaque_handle)) {
  27. set_storage_access_should_throw();
  28. set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
  29. sizes_and_strides_.set_sizes(sizes);
  30. refresh_numel();
  31. is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
  32. }
  33. // Destructor doesn't call release_resources because it's
  34. // unnecessary; don't forget to change that if needed!
  35. void release_resources() override {
  36. TensorImpl::release_resources();
  37. opaque_handle_ = {};
  38. }
  39. void set_size(int64_t dim, int64_t new_size) override {
  40. AT_ERROR("opaque tensors do not have set_size");
  41. }
  42. void set_stride(int64_t dim, int64_t new_stride) override {
  43. AT_ERROR("opaque tensors do not have set_stride");
  44. }
  45. void set_storage_offset(int64_t storage_offset) override {
  46. AT_ERROR("opaque tensors do not have set_storage_offset");
  47. }
  48. #ifdef DEBUG
  49. bool has_storage() const override {
  50. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
  51. !storage_, "OpaqueTensorImpl assumes that storage_ is never set");
  52. return false;
  53. }
  54. #endif
  55. /**
  56. * Return a TensorImpl that is a shallow-copy of this TensorImpl.
  57. *
  58. * For usage of `version_counter` and `allow_tensor_metadata_change`,
  59. * see NOTE [ TensorImpl Shallow-Copying ].
  60. */
  61. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  62. const c10::VariableVersion& version_counter,
  63. bool allow_tensor_metadata_change) const override {
  64. auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
  65. key_set(),
  66. dtype(),
  67. device(),
  68. opaque_handle_,
  69. sizes_and_strides_.sizes_arrayref());
  70. copy_tensor_metadata(
  71. /*src_opaque_impl=*/this,
  72. /*dest_opaque_impl=*/impl.get(),
  73. /*version_counter=*/version_counter,
  74. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  75. impl->refresh_numel();
  76. return impl;
  77. }
  78. /**
  79. * Return a TensorImpl that is a shallow-copy of this TensorImpl.
  80. *
  81. * For usage of `version_counter` and `allow_tensor_metadata_change`,
  82. * see NOTE [ TensorImpl Shallow-Copying ].
  83. */
  84. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  85. c10::VariableVersion&& version_counter,
  86. bool allow_tensor_metadata_change) const override {
  87. auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
  88. key_set(),
  89. dtype(),
  90. device(),
  91. opaque_handle_,
  92. sizes_and_strides_.sizes_arrayref());
  93. copy_tensor_metadata(
  94. /*src_opaque_impl=*/this,
  95. /*dest_opaque_impl=*/impl.get(),
  96. /*version_counter=*/std::move(version_counter),
  97. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  98. impl->refresh_numel();
  99. return impl;
  100. }
  101. /**
  102. * Shallow-copies data from another TensorImpl into this TensorImpl.
  103. *
  104. * For why this function doesn't check this TensorImpl's
  105. * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
  106. */
  107. void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
  108. AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
  109. auto opaque_impl =
  110. static_cast<const OpaqueTensorImpl<OpaqueHandle>*>(impl.get());
  111. copy_tensor_metadata(
  112. /*src_impl=*/opaque_impl,
  113. /*dest_impl=*/this,
  114. /*version_counter=*/version_counter(),
  115. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
  116. refresh_numel();
  117. }
  118. const OpaqueHandle& opaque_handle() const {
  119. return opaque_handle_;
  120. }
  121. OpaqueHandle& unsafe_opaque_handle() {
  122. return opaque_handle_;
  123. }
  124. protected:
  125. /**
  126. * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
  127. * storage_offset) from one TensorImpl to another TensorImpl.
  128. *
  129. * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
  130. * [ TensorImpl Shallow-Copying ].
  131. */
  132. static void copy_tensor_metadata(
  133. const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
  134. OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
  135. const c10::VariableVersion& version_counter,
  136. bool allow_tensor_metadata_change) {
  137. TensorImpl::copy_tensor_metadata(
  138. src_opaque_impl,
  139. dest_opaque_impl,
  140. version_counter,
  141. allow_tensor_metadata_change);
  142. // OpaqueTensorImpl-specific fields.
  143. dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
  144. }
  145. static void copy_tensor_metadata(
  146. const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
  147. OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
  148. c10::VariableVersion&& version_counter,
  149. bool allow_tensor_metadata_change) {
  150. TensorImpl::copy_tensor_metadata(
  151. src_opaque_impl,
  152. dest_opaque_impl,
  153. std::move(version_counter),
  154. allow_tensor_metadata_change);
  155. // OpaqueTensorImpl-specific fields.
  156. dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
  157. }
  158. private:
  159. const char* tensorimpl_type_name() const override {
  160. return "OpaqueTensorImpl";
  161. }
  162. OpaqueHandle opaque_handle_;
  163. };
  164. } // namespace at