123456789101112131415161718192021222324252627282930 |
#pragma once

#include <initializer_list>

#include <ATen/core/Tensor.h>
#include <c10/core/ScalarType.h>
- namespace at {
- // These functions are defined in ATen/Utils.cpp.
- #define TENSOR(T, S) \
- TORCH_API Tensor tensor(ArrayRef<T> values, const TensorOptions& options); \
- inline Tensor tensor( \
- std::initializer_list<T> values, const TensorOptions& options) { \
- return at::tensor(ArrayRef<T>(values), options); \
- } \
- inline Tensor tensor(T value, const TensorOptions& options) { \
- return at::tensor(ArrayRef<T>(value), options); \
- } \
- inline Tensor tensor(ArrayRef<T> values) { \
- return at::tensor(std::move(values), at::dtype(k##S)); \
- } \
- inline Tensor tensor(std::initializer_list<T> values) { \
- return at::tensor(ArrayRef<T>(values)); \
- } \
- inline Tensor tensor(T value) { \
- return at::tensor(ArrayRef<T>(value)); \
- }
- AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
- AT_FORALL_COMPLEX_TYPES(TENSOR)
- #undef TENSOR
- } // namespace at
|