123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186 |
- #pragma once
- #include <c10/core/MemoryFormat.h>
- #include <c10/core/SymIntArrayRef.h>
- #include <c10/core/TensorImpl.h>
- #include <c10/util/Exception.h>
- namespace at {
- // An "Opaque" TensorImpl -- there are no strides and (for now)
- // even data() is not supported (thus no pointer arithmetic).
- // NOTE: We could allow data() in the future, but would have to ensure pointer
- // arithmetic code is properly guarded.
- //
- // NOTE: This does not support resize_ (and other metadata-changing ops) because
- // of `shallow_copy_and_detach`. We would need to define an interface to
- // "shallow copy" in order to add support.
- template <typename OpaqueHandle>
- struct TORCH_API OpaqueTensorImpl : public TensorImpl {
- // public constructor for now...
- OpaqueTensorImpl(
- at::DispatchKeySet key_set,
- const caffe2::TypeMeta data_type,
- c10::Device device,
- OpaqueHandle opaque_handle,
- c10::IntArrayRef sizes,
- bool is_non_overlapping_and_dense = true)
- : TensorImpl(key_set, data_type, device),
- opaque_handle_(std::move(opaque_handle)) {
- set_storage_access_should_throw();
- set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
- sizes_and_strides_.set_sizes(sizes);
- refresh_numel();
- is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
- }
- // Destructor doesn't call release_resources because it's
- // unnecessary; don't forget to change that if needed!
- void release_resources() override {
- TensorImpl::release_resources();
- opaque_handle_ = {};
- }
- void set_size(int64_t dim, int64_t new_size) override {
- AT_ERROR("opaque tensors do not have set_size");
- }
- void set_stride(int64_t dim, int64_t new_stride) override {
- AT_ERROR("opaque tensors do not have set_stride");
- }
- void set_storage_offset(int64_t storage_offset) override {
- AT_ERROR("opaque tensors do not have set_storage_offset");
- }
- #ifdef DEBUG
- bool has_storage() const override {
- TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
- !storage_, "OpaqueTensorImpl assumes that storage_ is never set");
- return false;
- }
- #endif
- /**
- * Return a TensorImpl that is a shallow-copy of this TensorImpl.
- *
- * For usage of `version_counter` and `allow_tensor_metadata_change`,
- * see NOTE [ TensorImpl Shallow-Copying ].
- */
- c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
- const c10::VariableVersion& version_counter,
- bool allow_tensor_metadata_change) const override {
- auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
- key_set(),
- dtype(),
- device(),
- opaque_handle_,
- sizes_and_strides_.sizes_arrayref());
- copy_tensor_metadata(
- /*src_opaque_impl=*/this,
- /*dest_opaque_impl=*/impl.get(),
- /*version_counter=*/version_counter,
- /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
- impl->refresh_numel();
- return impl;
- }
- /**
- * Return a TensorImpl that is a shallow-copy of this TensorImpl.
- *
- * For usage of `version_counter` and `allow_tensor_metadata_change`,
- * see NOTE [ TensorImpl Shallow-Copying ].
- */
- c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
- c10::VariableVersion&& version_counter,
- bool allow_tensor_metadata_change) const override {
- auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
- key_set(),
- dtype(),
- device(),
- opaque_handle_,
- sizes_and_strides_.sizes_arrayref());
- copy_tensor_metadata(
- /*src_opaque_impl=*/this,
- /*dest_opaque_impl=*/impl.get(),
- /*version_counter=*/std::move(version_counter),
- /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
- impl->refresh_numel();
- return impl;
- }
- /**
- * Shallow-copies data from another TensorImpl into this TensorImpl.
- *
- * For why this function doesn't check this TensorImpl's
- * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
- */
- void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
- AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
- auto opaque_impl =
- static_cast<const OpaqueTensorImpl<OpaqueHandle>*>(impl.get());
- copy_tensor_metadata(
- /*src_impl=*/opaque_impl,
- /*dest_impl=*/this,
- /*version_counter=*/version_counter(),
- /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
- refresh_numel();
- }
- const OpaqueHandle& opaque_handle() const {
- return opaque_handle_;
- }
- OpaqueHandle& unsafe_opaque_handle() {
- return opaque_handle_;
- }
- protected:
- /**
- * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
- * storage_offset) from one TensorImpl to another TensorImpl.
- *
- * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
- * [ TensorImpl Shallow-Copying ].
- */
- static void copy_tensor_metadata(
- const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
- OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
- const c10::VariableVersion& version_counter,
- bool allow_tensor_metadata_change) {
- TensorImpl::copy_tensor_metadata(
- src_opaque_impl,
- dest_opaque_impl,
- version_counter,
- allow_tensor_metadata_change);
- // OpaqueTensorImpl-specific fields.
- dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
- }
- static void copy_tensor_metadata(
- const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
- OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
- c10::VariableVersion&& version_counter,
- bool allow_tensor_metadata_change) {
- TensorImpl::copy_tensor_metadata(
- src_opaque_impl,
- dest_opaque_impl,
- std::move(version_counter),
- allow_tensor_metadata_change);
- // OpaqueTensorImpl-specific fields.
- dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
- }
- private:
- const char* tensorimpl_type_name() const override {
- return "OpaqueTensorImpl";
- }
- OpaqueHandle opaque_handle_;
- };
- } // namespace at
|