- #pragma once
- #include <ATen/core/Tensor.h>
- #include <ATen/native/ResizeCommon.h>
- #include <ATen/EmptyTensor.h>
- #include <ATen/TensorUtils.h>
- #include <c10/core/CPUAllocator.h>
- #include <utility>
- namespace at { namespace native {
// TODO: make all operations that resize given outputs use this function
// for consistency and maintainability.
// Some operations like `cat` might not be able to make the use of
// resize_output directly. For more details to understand how it works in `cat`,
// see https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362
//
// Resizes outputs
// Functions accepting output tensors, like with the "out" kwarg, should
// call this function to handle resizing their output tensor.
// Issues a warning if the output tensor has one or more elements and
// needs resizing
// NOTE: In the future the warning will become an error
// Returns a bool saying whether or not the resize actually happened or not
TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);
// Symbolic-int variant of resize_output, for shapes carrying SymInts.
TORCH_API bool resize_output_symint(const Tensor& output, SymIntArrayRef shape);

// Utility for resize_output
// Returns a bool saying resize should happen or not and
// raises a warning if resizing for one or more elements
TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape);
// Symbolic-int variant of resize_output_check.
TORCH_API bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape);

// Resizes the CPU storage backing `storage` to exactly `size_bytes` bytes.
TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);
- static inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) {
- // It does not make sense to try to resize a storage
- // to hold 0 elements, and this can break
- // if storage_offset is positive but
- // new_size is 0, so just bail in that case
- // (same comment is in cuda/Resize.h)
- if (self->numel() == 0) {
- return;
- }
- const Storage& storage = self->unsafe_storage();
- if (!storage) {
- auto new_storage = c10::make_intrusive<StorageImpl>(
- StorageImpl::use_byte_size_t(),
- new_size_bytes,
- c10::GetCPUAllocator(),
- true);
- self->set_storage_keep_dtype(std::move(new_storage));
- } else if (new_size_bytes > storage.nbytes()) {
- resize_bytes_cpu(storage.unsafeGetStorageImpl(), new_size_bytes);
- }
- }
// In-place resize of a CPU TensorImpl to `size` (and `stride`, when given).
// When `resize_storage` is true, the backing storage may be (re)allocated to
// fit the new shape. Returns `self` for chaining.
TORCH_API TensorImpl* resize_impl_cpu_(
    TensorImpl* self,
    IntArrayRef size,
    at::OptionalIntArrayRef stride,
    bool resize_storage = true);
// Converts a SymInt to the target type T. Only the SymInt and int64_t
// specializations exist; any other instantiation is a compile error
// (primary template is deleted).
template <typename T>
T maybe_convert_symint(c10::SymInt) = delete;

// SymInt -> SymInt: pass-through, keeps the value symbolic.
template <>
inline c10::SymInt maybe_convert_symint(c10::SymInt x) { return x; }

// SymInt -> int64_t: requires a concrete (non-symbolic) value;
// expect_int() asserts that and extracts it.
template <>
inline int64_t maybe_convert_symint(c10::SymInt x) { return x.expect_int(); }
- template <typename T>
- static inline void checkInBoundsForStorage(
- ArrayRef<T> size,
- ArrayRef<T> stride,
- T storage_offset,
- const caffe2::TypeMeta data_type,
- const Storage& new_storage) {
- T storage_size_bytes =
- at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
- T storage_offset_bytes = storage_offset * data_type.itemsize();
- if (storage_size_bytes == 0) {
- // NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
- return;
- }
- T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes());
- TORCH_CHECK(
- storage_size_bytes + storage_offset_bytes <= new_storage_size_bytes,
- "setStorage: sizes ",
- size,
- ", strides ",
- stride,
- ","
- " storage offset ",
- storage_offset,
- ", and itemsize ",
- data_type.itemsize(),
- " requiring a storage size of ",
- storage_size_bytes + storage_offset_bytes,
- " are out of bounds for storage of size ",
- new_storage_size_bytes);
- }
// Validates the arguments of a set_(storage, offset, size, stride)-style call
// and, when `storage` is not already aliased by `result`, installs it on
// `result` (keeping the dtype). Raises via TORCH_CHECK on any violation.
// NOTE: this does NOT set sizes/strides/offset itself; callers do that after.
template <typename T>
static inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
                                   ArrayRef<T> size, ArrayRef<T> stride) {
  // FIXME: stride should be optional
  if (stride.data()) {
    TORCH_CHECK(size.size() == stride.size(), "unequal size length (", size.size(),
                                              ") and stride length (", stride.size(), ")");
  }

#ifdef DEBUG
  TORCH_CHECK(size.size() <= INT_MAX, "size length (", size.size(), ") greater than INT_MAX");
#endif

  // storage: note this can't be replaced with result.set_(storage) as the semantics of that
  // function is to set the tensor size to be equal to the size of the storage.
  if (!result.storage().is_alias_of(storage)) {
    // Caffe2 might have tensors whose storages are null, but we
    // don't allow it in PyTorch.
    TORCH_INTERNAL_ASSERT(storage);
    TORCH_INTERNAL_ASSERT(result.storage());

    // We used to allow this, but this breaks device caching.
    // Let's put an actual error message for this one.
    TORCH_CHECK(result.storage().device() == storage.device(),
                "Attempted to set the storage of a tensor on device \"", result.storage().device(),
                "\" to a storage on different device \"", storage.device(),
                "\". This is no longer allowed; the devices must match.");
    // Must come after the device check above: std::move consumes `storage`.
    result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage));
  }

  // storageOffset: must be non-negative.
  TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
}
- /**
- * Set self's sizes, strides, and storage_offset.
- * (size, stride, storage_offset) must be in bounds for self's storage.
- */
- template <typename T>
- inline void setStrided(
- const Tensor& self,
- ArrayRef<T> size,
- ArrayRef<T> stride,
- T storage_offset) {
- TORCH_CHECK(size.size() == stride.size(), "mismatch in length of strides and shape");
- for (const auto& val : stride) {
- TORCH_CHECK(val >= 0,
- "as_strided: Negative strides are not supported at the moment, "
- "got strides: ", stride);
- }
- auto* self_ = self.unsafeGetTensorImpl();
- checkInBoundsForStorage(
- size, stride, storage_offset, self_->dtype(), self_->storage());
- /* storage offset */
- TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
- self_->set_sizes_and_strides(size, stride, c10::make_optional(storage_offset));
- }
- }}
|