CUDAContext.h
#pragma once

#include <cstdint>

#include <cuda_runtime_api.h>
#include <cusparse.h>
#include <cublas_v2.h>

#ifdef CUDART_VERSION
#include <cusolverDn.h>
#endif

#include <ATen/core/ATenGeneral.h>
#include <ATen/Context.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAFunctions.h>
#include <ATen/cuda/Exceptions.h>

namespace at {
namespace cuda {

/*
A common CUDA interface for ATen.

This interface is distinct from CUDAHooks, which defines an interface that links
to both CPU-only and CUDA builds. That interface is intended for runtime
dispatch and should be used from files that are included in both CPU-only and
CUDA builds.

CUDAContext, on the other hand, should be preferred by files only included in
CUDA builds. It is intended to expose CUDA functionality in a consistent
manner.

This means there is some overlap between CUDAContext and CUDAHooks, but the
choice of which to use is simple: use CUDAContext when in a CUDA-only file,
use CUDAHooks otherwise.

Note that CUDAContext simply defines an interface with no associated class.
It is expected that the modules whose functions compose this interface will
manage their own state. There is only a single CUDA context/state.
*/
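
/*
Example (illustrative sketch, not part of the interface declared below): a
CUDA-only .cu file would typically query device information through
CUDAContext and launch work on the current stream obtained from the c10
headers included above. `scale_kernel` and `scale` are hypothetical.

  #include <ATen/cuda/CUDAContext.h>

  __global__ void scale_kernel(float* data, int n, float alpha);

  void scale(float* data, int n, float alpha) {
    const int threads = 4 * at::cuda::warp_size();             // queried via this interface
    const int blocks = (n + threads - 1) / threads;
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();   // from c10/cuda/CUDAStream.h
    scale_kernel<<<blocks, threads, 0, stream>>>(data, n, alpha);
    AT_CUDA_CHECK(cudaGetLastError());                         // from ATen/cuda/Exceptions.h
  }
*/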

/**
 * DEPRECATED: use device_count() instead
 */
inline int64_t getNumGPUs() {
  return c10::cuda::device_count();
}

/**
 * CUDA is available if we compiled with CUDA, and there are one or more
 * devices. If we compiled with CUDA but there is a driver problem, etc.,
 * this function will report that CUDA is not available (rather than raise
 * an error).
 */
inline bool is_available() {
  return c10::cuda::device_count() > 0;
}
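
/*
Example (illustrative sketch): because is_available() reports false instead of
throwing when no usable device is present, it can serve as a cheap guard when
picking a device. `pick_device` is a hypothetical helper, not part of ATen.

  at::Device pick_device() {
    if (at::cuda::is_available()) {
      return at::Device(at::kCUDA, 0);
    }
    return at::Device(at::kCPU);
  }
*/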

/** Returns the cudaDeviceProp of the current device. */
TORCH_CUDA_CPP_API cudaDeviceProp* getCurrentDeviceProperties();
/** Returns the warp size of the current device. */
TORCH_CUDA_CPP_API int warp_size();
/** Returns the cudaDeviceProp of the given device. */
TORCH_CUDA_CPP_API cudaDeviceProp* getDeviceProperties(int64_t device);
/** Returns true if `device` can directly access memory on `peer_device`. */
TORCH_CUDA_CPP_API bool canDeviceAccessPeer(
    int64_t device,
    int64_t peer_device);
/** Returns the allocator ATen uses for CUDA device memory. */
TORCH_CUDA_CPP_API Allocator* getCUDADeviceAllocator();
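
/*
Example (illustrative sketch): checking peer access before planning a direct
device-to-device transfer, and reading a property of a specific device.
`peer_copy_ok` and `device_memory` are hypothetical helpers.

  bool peer_copy_ok(int64_t src, int64_t dst) {
    return src == dst ||
        (at::cuda::canDeviceAccessPeer(src, dst) &&
         at::cuda::canDeviceAccessPeer(dst, src));
  }

  size_t device_memory(int64_t device) {
    return at::cuda::getDeviceProperties(device)->totalGlobalMem;
  }
*/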

/* Handles */
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();

TORCH_CUDA_CPP_API void clearCublasWorkspaces();

#ifdef CUDART_VERSION
TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();
#endif
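
/*
Example (illustrative sketch): getCurrentCUDABlasHandle() returns the cuBLAS
handle ATen manages for the current device, so callers use it directly rather
than calling cublasCreate themselves. `scale_inplace` is a hypothetical
helper; TORCH_CUDABLAS_CHECK comes from the included ATen/cuda/Exceptions.h.

  void scale_inplace(float* x, int n, float alpha) {
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    TORCH_CUDABLAS_CHECK(cublasSscal(handle, n, &alpha, x, 1));
  }
*/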

} // namespace cuda
} // namespace at