123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101 |
- #include <array>
- #include <ATen/Functions.h>
- #include <ATen/Utils.h>
- namespace at {
- Tensor TensorMaker::make_tensor() {
- AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
- tracer::impl::NoTracerDispatchMode tracer_guard{};
- check_size_nonnegative(sizes_);
- TORCH_CHECK_VALUE(
- !deleter_ || !ctx_,
- "The deleter and context arguments are mutually exclusive.");
- if (device_ == nullopt) {
- device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type());
- }
- if (opts_.device().has_index()) {
- // clang-format off
- TORCH_CHECK_VALUE(
- opts_.device() == *device_,
- "Specified device ", opts_.device(), " does not match device of data ", *device_);
- // clang-format on
- }
- std::size_t size_bytes = computeStorageSize();
- DataPtr data_ptr{};
- if (deleter_) {
- data_ptr = makeDataPtrFromDeleter();
- } else {
- data_ptr = makeDataPtrFromContext();
- }
- Storage storage{Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr)};
- Tensor tensor = detail::make_tensor<TensorImpl>(
- std::move(storage), opts_.computeDispatchKey(), opts_.dtype());
- TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
- if (strides_) {
- tensor_impl->set_sizes_and_strides(sizes_, *strides_);
- } else {
- tensor_impl->set_sizes_contiguous(sizes_);
- }
- if (storage_offset_) {
- tensor_impl->set_storage_offset(*storage_offset_);
- }
- return tensor;
- }
- std::size_t TensorMaker::computeStorageSize() const noexcept {
- std::size_t itemsize = opts_.dtype().itemsize();
- if (strides_) {
- auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize);
- if (storage_offset_) {
- storage_size += storage_offset_.value();
- }
- return storage_size;
- }
- std::size_t size = 1;
- for (std::int64_t s : sizes_) {
- size *= static_cast<std::size_t>(s);
- }
- auto storage_size = size * itemsize;
- if (storage_offset_) {
- storage_size += storage_offset_.value();
- }
- return storage_size;
- }
- inline DataPtr TensorMaker::makeDataPtrFromDeleter() const {
- return InefficientStdFunctionContext::makeDataPtr(data_, deleter_, *device_);
- }
- inline DataPtr TensorMaker::makeDataPtrFromContext() noexcept {
- return DataPtr{data_, ctx_.release(), ctx_.get_deleter(), *device_};
- }
- IntArrayRef TensorMaker::makeTempSizes() const noexcept {
- static std::int64_t zeros[5] = {0, 0, 0, 0, 0};
- if (opts_.has_memory_format()) {
- MemoryFormat format = *opts_.memory_format_opt();
- if (format == MemoryFormat::ChannelsLast) {
- return IntArrayRef(zeros, 4);
- }
- if (format == MemoryFormat::ChannelsLast3d) {
- return IntArrayRef(zeros, 5);
- }
- }
- return IntArrayRef(zeros, 1);
- }
- } // namespace at
|