// CUDATensorMethods.cuh (270 B)

  1. #pragma once
  2. #include <ATen/Tensor.h>
  3. #include <c10/util/Half.h>
  4. #include <cuda.h>
  5. #include <cuda_runtime.h>
  6. #include <cuda_fp16.h>
  namespace at {

  // Explicit specialization of the member template Tensor::data<T>() for the
  // CUDA half-precision type `__half` (declared in <cuda_fp16.h>).
  //
  // Lets callers write `tensor.data<__half>()` and receive a `__half*`. The
  // pointer is obtained from the existing `data<Half>()` instantiation and
  // reinterpreted as `__half*`; no data is copied or converted.
  //
  // NOTE(review): presumably data<Half>() checks that the tensor's scalar type
  // is Half — confirm against its definition in ATen/Tensor.h.
  // NOTE(review): newer ATen releases deprecate data<T>() in favour of
  // data_ptr<T>(); consider migrating if the targeted ATen version has it.
  template <>
  inline __half* Tensor::data() const {
    // The reinterpret_cast below is only sound if an at::Half element can be
    // viewed in place as a __half: same size, and Half storage aligned at
    // least as strictly as __half requires. Fail the build otherwise.
    static_assert(sizeof(__half) == sizeof(Half),
                  "__half and at::Half must have the same size");
    static_assert(alignof(Half) >= alignof(__half),
                  "at::Half storage must satisfy __half alignment");
    return reinterpret_cast<__half*>(data<Half>());
  }

  } // namespace at