CUDADevice.h 539 B

123456789101112131415161718192021222324
  1. #pragma once
  2. #include <ATen/cuda/Exceptions.h>
  3. #include <cuda.h>
  4. #include <cuda_runtime.h>
  5. namespace at {
  6. namespace cuda {
  7. inline Device getDeviceFromPtr(void* ptr) {
  8. cudaPointerAttributes attr{};
  9. AT_CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr));
  10. #if !defined(USE_ROCM)
  11. TORCH_CHECK(attr.type != cudaMemoryTypeUnregistered,
  12. "The specified pointer resides on host memory and is not registered with any CUDA device.");
  13. #endif
  14. return {DeviceType::CUDA, static_cast<DeviceIndex>(attr.device)};
  15. }
  16. }} // namespace at::cuda