from_blob.h 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. #pragma once
  2. #include <ATen/core/Tensor.h>
  3. namespace at {
  4. namespace detail {
  5. TORCH_API inline void noopDelete(void*) {}
  6. } // namespace detail
  7. /// Provides a fluent API to construct tensors from external data.
  8. ///
  9. /// The fluent API can be used instead of `from_blob` functions in case the
  10. /// required set of parameters does not align with the existing overloads.
  11. ///
  12. /// at::Tensor tensor = at::for_blob(data, sizes)
  13. /// .strides(strides)
  14. /// .context(context, [](void *ctx) { delete static_cast<Ctx*>(ctx); })
  15. /// .options(...)
  16. /// .make_tensor();
  17. ///
  18. class TORCH_API TensorMaker {
  19. friend TensorMaker for_blob(void* data, IntArrayRef sizes) noexcept;
  20. public:
  21. using ContextDeleter = DeleterFnPtr;
  22. TensorMaker& strides(OptionalIntArrayRef value) noexcept {
  23. strides_ = value;
  24. return *this;
  25. }
  26. TensorMaker& storage_offset(optional<int64_t> value) noexcept {
  27. storage_offset_ = value;
  28. return *this;
  29. }
  30. TensorMaker& deleter(std::function<void(void*)> value) noexcept {
  31. deleter_ = std::move(value);
  32. return *this;
  33. }
  34. TensorMaker& context(void* value, ContextDeleter deleter = nullptr) noexcept {
  35. ctx_ = std::unique_ptr<void, ContextDeleter>{
  36. value, deleter != nullptr ? deleter : detail::noopDelete};
  37. return *this;
  38. }
  39. TensorMaker& target_device(optional<Device> value) noexcept {
  40. device_ = value;
  41. return *this;
  42. }
  43. TensorMaker& options(TensorOptions value) noexcept {
  44. opts_ = value;
  45. return *this;
  46. }
  47. Tensor make_tensor();
  48. private:
  49. explicit TensorMaker(void* data, IntArrayRef sizes) noexcept
  50. : data_{data}, sizes_{sizes} {}
  51. std::size_t computeStorageSize() const noexcept;
  52. DataPtr makeDataPtrFromDeleter() const;
  53. DataPtr makeDataPtrFromContext() noexcept;
  54. IntArrayRef makeTempSizes() const noexcept;
  55. void* data_;
  56. IntArrayRef sizes_;
  57. OptionalIntArrayRef strides_{};
  58. optional<int64_t> storage_offset_{};
  59. std::function<void(void*)> deleter_{};
  60. std::unique_ptr<void, ContextDeleter> ctx_{nullptr, detail::noopDelete};
  61. optional<Device> device_{};
  62. TensorOptions opts_{};
  63. };
  64. inline TensorMaker for_blob(void* data, IntArrayRef sizes) noexcept {
  65. return TensorMaker{data, sizes};
  66. }
  67. inline Tensor from_blob(
  68. void* data,
  69. IntArrayRef sizes,
  70. IntArrayRef strides,
  71. const std::function<void(void*)>& deleter,
  72. const TensorOptions& options = {},
  73. const c10::optional<Device> target_device = c10::nullopt) {
  74. return for_blob(data, sizes)
  75. .strides(strides)
  76. .deleter(deleter)
  77. .options(options)
  78. .target_device(target_device)
  79. .make_tensor();
  80. }
  81. inline Tensor from_blob(
  82. void* data,
  83. IntArrayRef sizes,
  84. IntArrayRef strides,
  85. int64_t storage_offset,
  86. const std::function<void(void*)>& deleter,
  87. const TensorOptions& options = {},
  88. const c10::optional<Device> target_device = c10::nullopt) {
  89. return for_blob(data, sizes)
  90. .strides(strides)
  91. .storage_offset(storage_offset)
  92. .deleter(deleter)
  93. .options(options)
  94. .target_device(target_device)
  95. .make_tensor();
  96. }
  97. inline Tensor from_blob(
  98. void* data,
  99. IntArrayRef sizes,
  100. const std::function<void(void*)>& deleter,
  101. const TensorOptions& options = {}) {
  102. return for_blob(data, sizes)
  103. .deleter(deleter)
  104. .options(options)
  105. .make_tensor();
  106. }
  107. inline Tensor from_blob(
  108. void* data,
  109. IntArrayRef sizes,
  110. IntArrayRef strides,
  111. const TensorOptions& options = {}) {
  112. return for_blob(data, sizes)
  113. .strides(strides)
  114. .options(options)
  115. .make_tensor();
  116. }
  117. inline Tensor from_blob(
  118. void* data,
  119. IntArrayRef sizes,
  120. const TensorOptions& options = {}) {
  121. return for_blob(data, sizes).options(options).make_tensor();
  122. }
  123. } // namespace at