ResizeCommon.h 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. #pragma once
  2. #include <ATen/core/Tensor.h>
  3. #include <ATen/NamedTensorUtils.h>
  4. #include <c10/util/irange.h>
  5. namespace at { namespace native {
  6. template <typename T>
  7. inline T storage_size_for(ArrayRef<T> size, ArrayRef<T> stride) {
  8. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(size.size() == stride.size(),
  9. "storage_size_for(size, stride) requires that size and stride ",
  10. "have the same size as a precondition.");
  11. T storage_size = 1;
  12. for (const auto dim : c10::irange(size.size())) {
  13. if (size[dim] == 0) {
  14. storage_size = 0;
  15. break;
  16. }
  17. storage_size += (size[dim] - 1) * stride[dim];
  18. }
  19. return storage_size;
  20. }
  21. inline const Tensor& resize_named_tensor_(
  22. const Tensor& self,
  23. IntArrayRef size,
  24. c10::optional<MemoryFormat> optional_memory_format) {
  25. TORCH_INTERNAL_ASSERT(self.has_names());
  26. TORCH_CHECK(
  27. self.sizes() == size,
  28. "Cannot resize named tensor with resize_ or resize_as_ (tried to resize "
  29. "Tensor",
  30. self.names(),
  31. " with size ",
  32. self.sizes(),
  33. " to ",
  34. size,
  35. "). This may be caused by passing a named tensor ",
  36. "as an `out=` argument; please ensure that the sizes are the same. ");
  37. TORCH_CHECK(
  38. !optional_memory_format.has_value(),
  39. "Unsupported memory format for named tensor resize ",
  40. optional_memory_format.value());
  41. return self;
  42. }
  43. }}