// TensorMethods.cpp
  1. #include <c10/core/Scalar.h>
  2. #include <ATen/core/TensorBody.h>
  3. namespace at {
  4. #define DEFINE_CAST(T, name) \
  5. template <> \
  6. TORCH_API T* TensorBase::data_ptr() const { \
  7. TORCH_CHECK( \
  8. scalar_type() == ScalarType::name \
  9. || (isQIntType(scalar_type()) \
  10. && toUnderlying(scalar_type()) == ScalarType::name), \
  11. "expected scalar type " \
  12. #name \
  13. " but found ", \
  14. scalar_type()); \
  15. return this->unsafeGetTensorImpl()->data_ptr_impl<T>(); \
  16. }
  17. AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CAST)
  18. AT_FORALL_QINT_TYPES(DEFINE_CAST)
  19. #undef DEFINE_CAST
  20. #define DEFINE_ITEM(T, name) \
  21. template <> \
  22. TORCH_API T Tensor::item() const { \
  23. return item().to##name(); \
  24. }
  25. AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ITEM)
  26. #undef DEFINE_ITEM
  27. } //namespace at