// tensor.h (1.6 KB)

  1. #pragma once
  2. #include <ATen/core/Tensor.h>
  3. #include <c10/core/ScalarType.h>
  4. namespace at {
  5. // These functions are defined in ATen/Utils.cpp.
  6. #define TENSOR(T, S) \
  7. TORCH_API Tensor tensor(ArrayRef<T> values, const TensorOptions& options); \
  8. inline Tensor tensor( \
  9. std::initializer_list<T> values, const TensorOptions& options) { \
  10. return at::tensor(ArrayRef<T>(values), options); \
  11. } \
  12. inline Tensor tensor(T value, const TensorOptions& options) { \
  13. return at::tensor(ArrayRef<T>(value), options); \
  14. } \
  15. inline Tensor tensor(ArrayRef<T> values) { \
  16. return at::tensor(std::move(values), at::dtype(k##S)); \
  17. } \
  18. inline Tensor tensor(std::initializer_list<T> values) { \
  19. return at::tensor(ArrayRef<T>(values)); \
  20. } \
  21. inline Tensor tensor(T value) { \
  22. return at::tensor(ArrayRef<T>(value)); \
  23. }
  24. AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
  25. AT_FORALL_COMPLEX_TYPES(TENSOR)
  26. #undef TENSOR
  27. } // namespace at