Functions.cpp 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. #include <array>
  2. #include <ATen/Functions.h>
  3. #include <ATen/Utils.h>
  4. namespace at {
  5. Tensor TensorMaker::make_tensor() {
  6. AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
  7. tracer::impl::NoTracerDispatchMode tracer_guard{};
  8. check_size_nonnegative(sizes_);
  9. TORCH_CHECK_VALUE(
  10. !deleter_ || !ctx_,
  11. "The deleter and context arguments are mutually exclusive.");
  12. if (device_ == nullopt) {
  13. device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type());
  14. }
  15. if (opts_.device().has_index()) {
  16. // clang-format off
  17. TORCH_CHECK_VALUE(
  18. opts_.device() == *device_,
  19. "Specified device ", opts_.device(), " does not match device of data ", *device_);
  20. // clang-format on
  21. }
  22. std::size_t size_bytes = computeStorageSize();
  23. DataPtr data_ptr{};
  24. if (deleter_) {
  25. data_ptr = makeDataPtrFromDeleter();
  26. } else {
  27. data_ptr = makeDataPtrFromContext();
  28. }
  29. Storage storage{Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr)};
  30. Tensor tensor = detail::make_tensor<TensorImpl>(
  31. std::move(storage), opts_.computeDispatchKey(), opts_.dtype());
  32. TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
  33. if (strides_) {
  34. tensor_impl->set_sizes_and_strides(sizes_, *strides_);
  35. } else {
  36. tensor_impl->set_sizes_contiguous(sizes_);
  37. }
  38. if (storage_offset_) {
  39. tensor_impl->set_storage_offset(*storage_offset_);
  40. }
  41. return tensor;
  42. }
  43. std::size_t TensorMaker::computeStorageSize() const noexcept {
  44. std::size_t itemsize = opts_.dtype().itemsize();
  45. if (strides_) {
  46. auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize);
  47. if (storage_offset_) {
  48. storage_size += storage_offset_.value();
  49. }
  50. return storage_size;
  51. }
  52. std::size_t size = 1;
  53. for (std::int64_t s : sizes_) {
  54. size *= static_cast<std::size_t>(s);
  55. }
  56. auto storage_size = size * itemsize;
  57. if (storage_offset_) {
  58. storage_size += storage_offset_.value();
  59. }
  60. return storage_size;
  61. }
  62. inline DataPtr TensorMaker::makeDataPtrFromDeleter() const {
  63. return InefficientStdFunctionContext::makeDataPtr(data_, deleter_, *device_);
  64. }
  65. inline DataPtr TensorMaker::makeDataPtrFromContext() noexcept {
  66. return DataPtr{data_, ctx_.release(), ctx_.get_deleter(), *device_};
  67. }
  68. IntArrayRef TensorMaker::makeTempSizes() const noexcept {
  69. static std::int64_t zeros[5] = {0, 0, 0, 0, 0};
  70. if (opts_.has_memory_format()) {
  71. MemoryFormat format = *opts_.memory_format_opt();
  72. if (format == MemoryFormat::ChannelsLast) {
  73. return IntArrayRef(zeros, 4);
  74. }
  75. if (format == MemoryFormat::ChannelsLast3d) {
  76. return IntArrayRef(zeros, 5);
  77. }
  78. }
  79. return IntArrayRef(zeros, 1);
  80. }
  81. } // namespace at