Resize.h

#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/EmptyTensor.h>
#include <ATen/TensorUtils.h>
#include <c10/core/CPUAllocator.h>

#include <climits> // for INT_MAX, used under #ifdef DEBUG below
#include <utility>

namespace at { namespace native {

// TODO: make all operations that resize given outputs use this function
// for consistency and maintainability.
// Some operations like `cat` might not be able to make use of
// resize_output directly. For more details on how it works in `cat`,
// see https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362
//
// Resizes outputs.
// Functions accepting output tensors, like those with the "out" kwarg, should
// call this function to handle resizing their output tensor.
// Issues a warning if the output tensor has one or more elements and
// needs resizing.
// NOTE: In the future the warning will become an error.
// Returns a bool reporting whether the resize actually happened.
TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);
TORCH_API bool resize_output_symint(const Tensor& output, SymIntArrayRef shape);

// Utility for resize_output.
// Returns a bool reporting whether the resize should happen, and raises a
// warning if the tensor being resized has one or more elements.
TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape);
TORCH_API bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape);
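
// A minimal usage sketch (illustrative only; `my_op_out` is a hypothetical
// kernel, not part of this header). An out-variant function resizes its
// output before writing into it:
//
//   Tensor& my_op_out(const Tensor& self, Tensor& out) {
//     at::native::resize_output(out, self.sizes());  // warns if a non-empty
//                                                    // `out` needs resizing
//     // ... compute the result directly into `out` ...
//     return out;
//   }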

TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);

static inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) {
  // It does not make sense to try to resize a storage
  // to hold 0 elements, and this can break
  // if storage_offset is positive but
  // new_size is 0, so just bail in that case
  // (same comment is in cuda/Resize.h)
  if (self->numel() == 0) {
    return;
  }

  const Storage& storage = self->unsafe_storage();
  if (!storage) {
    auto new_storage = c10::make_intrusive<StorageImpl>(
        StorageImpl::use_byte_size_t(),
        new_size_bytes,
        c10::GetCPUAllocator(),
        true);
    self->set_storage_keep_dtype(std::move(new_storage));
  } else if (new_size_bytes > storage.nbytes()) {
    resize_bytes_cpu(storage.unsafeGetStorageImpl(), new_size_bytes);
  }
}
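
// Note: maybe_resize_storage_cpu only ever grows the allocation. When the
// existing buffer is already large enough (new_size_bytes <= storage.nbytes())
// the call is a no-op; this path never shrinks storage.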

TORCH_API TensorImpl* resize_impl_cpu_(
    TensorImpl* self,
    IntArrayRef size,
    at::OptionalIntArrayRef stride,
    bool resize_storage = true);
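
// A hypothetical call-site sketch (`grow_to` is illustrative, not part of
// this header): in-place resizing goes through resize_impl_cpu_, which sets
// the new sizes/strides and grows the underlying storage if needed.
//
//   void grow_to(const at::Tensor& t, at::IntArrayRef new_size) {
//     at::native::resize_impl_cpu_(
//         t.unsafeGetTensorImpl(), new_size, /*stride=*/c10::nullopt);
//   }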

template <typename T>
T maybe_convert_symint(c10::SymInt) = delete;

template <>
inline c10::SymInt maybe_convert_symint(c10::SymInt x) { return x; }

template <>
inline int64_t maybe_convert_symint(c10::SymInt x) { return x.expect_int(); }
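
// Illustrative behavior of the two specializations above: the SymInt
// instantiation is the identity, while the int64_t instantiation asserts the
// value holds a concrete integer.
//
//   c10::SymInt s{8};
//   auto a = maybe_convert_symint<c10::SymInt>(s);  // passes through
//   auto b = maybe_convert_symint<int64_t>(s);      // 8, via expect_int()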

template <typename T>
static inline void checkInBoundsForStorage(
    ArrayRef<T> size,
    ArrayRef<T> stride,
    T storage_offset,
    const caffe2::TypeMeta data_type,
    const Storage& new_storage) {
  T storage_size_bytes =
      at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
  T storage_offset_bytes = storage_offset * data_type.itemsize();
  if (storage_size_bytes == 0) {
    // NB: the storage of a tensor with any zero-sized dim can have any numel,
    // since such a tensor occupies no storage.
    return;
  }

  T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes());
  TORCH_CHECK(
      storage_size_bytes + storage_offset_bytes <= new_storage_size_bytes,
      "setStorage: sizes ",
      size,
      ", strides ",
      stride,
      ", storage offset ",
      storage_offset,
      ", and itemsize ",
      data_type.itemsize(),
      " requiring a storage size of ",
      storage_size_bytes + storage_offset_bytes,
      " are out of bounds for storage of size ",
      new_storage_size_bytes);
}
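
// Worked example of the bound above (assuming itemsize 4, e.g. float): for
// size = {2, 3}, stride = {3, 1}, storage_offset = 2, the farthest reachable
// element is at linear index (2-1)*3 + (3-1)*1 = 5, so
//
//   storage_size_bytes   = (5 + 1) * 4 = 24
//   storage_offset_bytes = 2 * 4       = 8
//
// and new_storage must hold at least 32 bytes to pass the TORCH_CHECK.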

template <typename T>
static inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
                                   ArrayRef<T> size, ArrayRef<T> stride) {
  // FIXME: stride should be optional
  if (stride.data()) {
    TORCH_CHECK(size.size() == stride.size(), "unequal size length (", size.size(),
                ") and stride length (", stride.size(), ")");
  }

#ifdef DEBUG
  TORCH_CHECK(size.size() <= INT_MAX, "size length (", size.size(), ") greater than INT_MAX");
#endif

  // storage: note this can't be replaced with result.set_(storage) as the
  // semantics of that function are to set the tensor size to be equal to the
  // size of the storage.
  if (!result.storage().is_alias_of(storage)) {
    // Caffe2 might have tensors whose storages are null, but we
    // don't allow it in PyTorch.
    TORCH_INTERNAL_ASSERT(storage);
    TORCH_INTERNAL_ASSERT(result.storage());

    // We used to allow this, but it breaks device caching.
    // Let's put an actual error message for this one.
    TORCH_CHECK(result.storage().device() == storage.device(),
                "Attempted to set the storage of a tensor on device \"", result.storage().device(),
                "\" to a storage on different device \"", storage.device(),
                "\". This is no longer allowed; the devices must match.");
    result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage));
  }

  // storageOffset
  TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
}
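
// A sketch of how these helpers compose (`set_storage_sketch` is
// hypothetical, shown for the concrete int64_t instantiation): a
// set_(storage, offset, size, stride)-style kernel first validates and
// installs the storage with checkSetStorage, then applies the geometry with
// setStrided below.
//
//   void set_storage_sketch(Tensor& result, Storage storage, int64_t offset,
//                           IntArrayRef size, IntArrayRef stride) {
//     checkSetStorage<int64_t>(result, std::move(storage), offset, size, stride);
//     setStrided<int64_t>(result, size, stride, offset);
//   }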

/**
 * Set self's sizes, strides, and storage_offset.
 * (size, stride, storage_offset) must be in bounds for self's storage.
 */
template <typename T>
inline void setStrided(
    const Tensor& self,
    ArrayRef<T> size,
    ArrayRef<T> stride,
    T storage_offset) {
  TORCH_CHECK(size.size() == stride.size(), "mismatch in length of strides and shape");
  for (const auto& val : stride) {
    TORCH_CHECK(val >= 0,
                "as_strided: Negative strides are not supported at the moment, "
                "got strides: ", stride);
  }

  auto* self_ = self.unsafeGetTensorImpl();
  checkInBoundsForStorage(
      size, stride, storage_offset, self_->dtype(), self_->storage());

  /* storage offset */
  TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
  self_->set_sizes_and_strides(size, stride, c10::make_optional(storage_offset));
}
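
// An illustrative end-to-end path (assuming a contiguous CPU tensor):
//
//   auto t = at::arange(6);                 // storage of 6 elements
//   auto v = t.as_strided({2, 3}, {3, 1});  // reaches setStrided, in bounds
//
// By the bound checked in checkInBoundsForStorage,
// t.as_strided({2, 3}, {3, 1}, /*storage_offset=*/1) would need 7 elements
// of storage and fail, and negative strides such as {-3, 1} are rejected by
// the loop over `stride` above.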

}} // namespace at::native