CUDAFunctions.h
#pragma once

// This header provides C++ wrappers around commonly used CUDA API functions.
// The benefit of using C++ here is that we can raise an exception in the
// event of an error, rather than explicitly pass around error codes. This
// leads to more natural APIs.
//
// The naming convention used here matches the naming convention of torch.cuda

#include <c10/core/Device.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAMacros.h>

#include <cuda_runtime_api.h>

namespace c10 {
namespace cuda {
// NB: In the past, we were inconsistent about whether or not this reported
// an error if there were driver problems. Based on experience interacting
// with users, it seems that people basically ~never want this function to
// fail; it should just return zero if things are not working. Oblige them.
// It still might log a warning the first time it is invoked.
C10_CUDA_API DeviceIndex device_count() noexcept;

// Version of device_count that throws if no devices are detected.
C10_CUDA_API DeviceIndex device_count_ensure_non_zero();
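// Usage sketch (illustrative only, not part of this header): how the
// never-failing device_count() differs from the throwing variant. The
// function name `example_report_devices` is hypothetical.
//
//   #include <c10/cuda/CUDAFunctions.h>
//   #include <iostream>
//
//   void example_report_devices() {
//     // Never throws; returns 0 if no GPU is present or the driver is broken.
//     c10::DeviceIndex n = c10::cuda::device_count();
//     std::cout << "visible CUDA devices: " << static_cast<int>(n) << "\n";
//     // Throws instead of returning 0 if no CUDA devices are detected.
//     c10::cuda::device_count_ensure_non_zero();
//     c10::cuda::set_device(0);
//     std::cout << "current device: "
//               << static_cast<int>(c10::cuda::current_device()) << "\n";
//   }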
C10_CUDA_API DeviceIndex current_device();

C10_CUDA_API void set_device(DeviceIndex device);

C10_CUDA_API void device_synchronize();

C10_CUDA_API void warn_or_error_on_sync();

// Raw CUDA device management functions
C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count);

C10_CUDA_API cudaError_t GetDevice(int* device);

C10_CUDA_API cudaError_t SetDevice(int device);

C10_CUDA_API cudaError_t MaybeSetDevice(int device);

C10_CUDA_API int ExchangeDevice(int device);

C10_CUDA_API int MaybeExchangeDevice(int device);

C10_CUDA_API void SetTargetDevice();
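// Usage sketch (illustrative only): the raw GetDeviceCount/SetDevice family
// returns cudaError_t, so callers that cannot tolerate exceptions can inspect
// the error code themselves, whereas the wrappers above throw on failure.
// The function name `example_try_set_device` is hypothetical.
//
//   #include <c10/cuda/CUDAFunctions.h>
//
//   bool example_try_set_device(int device) {
//     int count = 0;
//     if (c10::cuda::GetDeviceCount(&count) != cudaSuccess ||
//         device < 0 || device >= count) {
//       return false; // no throw: the caller decides how to handle failure
//     }
//     return c10::cuda::SetDevice(device) == cudaSuccess;
//   }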
enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR };

// This is a holder for c10 global state (similar to at GlobalContext).
// Currently it's used to store the CUDA synchronization warning state,
// but can be expanded to hold other related global state, e.g. to
// record stream usage
class WarningState {
 public:
  void set_sync_debug_mode(SyncDebugMode l) {
    sync_debug_mode = l;
  }

  SyncDebugMode get_sync_debug_mode() {
    return sync_debug_mode;
  }

 private:
  SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED;
};

C10_CUDA_API __inline__ WarningState& warning_state() {
  static WarningState warning_state_;
  return warning_state_;
}
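// Usage sketch (illustrative only): warning_state() is the process-wide
// switch consulted by the synchronizing helpers below; it is analogous to
// torch.cuda.set_sync_debug_mode on the Python side. The function name
// `example_warn_on_sync` is hypothetical.
//
//   #include <c10/cuda/CUDAFunctions.h>
//
//   void example_warn_on_sync() {
//     using c10::cuda::SyncDebugMode;
//     // Warn (rather than error or stay silent) whenever a synchronizing
//     // call such as memcpy_and_sync or stream_synchronize is made.
//     c10::cuda::warning_state().set_sync_debug_mode(SyncDebugMode::L_WARN);
//   }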
// The subsequent functions are defined in the header because for performance
// reasons we want them to be inline
C10_CUDA_API void __inline__ memcpy_and_sync(
    void* dst,
    void* src,
    int64_t nbytes,
    cudaMemcpyKind kind,
    cudaStream_t stream) {
  if (C10_UNLIKELY(
          warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
    warn_or_error_on_sync();
  }
  const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  if (C10_UNLIKELY(interp)) {
    (*interp)->trace_gpu_stream_synchronization(
        reinterpret_cast<uintptr_t>(stream));
  }
#if defined(TORCH_HIP_VERSION) && (TORCH_HIP_VERSION >= 301)
  C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
#else
  C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
  C10_CUDA_CHECK(cudaStreamSynchronize(stream));
#endif
}
C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
  if (C10_UNLIKELY(
          warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
    warn_or_error_on_sync();
  }
  const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  if (C10_UNLIKELY(interp)) {
    (*interp)->trace_gpu_stream_synchronization(
        reinterpret_cast<uintptr_t>(stream));
  }
  C10_CUDA_CHECK(cudaStreamSynchronize(stream));
}
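// Usage sketch (illustrative only): a blocking device-to-host copy on the
// current stream. getCurrentCUDAStream() comes from c10/cuda/CUDAStream.h;
// the function and buffer names are hypothetical.
//
//   #include <c10/cuda/CUDAFunctions.h>
//   #include <c10/cuda/CUDAStream.h>
//
//   void example_copy_to_host(void* host_dst, void* device_src, int64_t nbytes) {
//     cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
//     // Copies and then synchronizes the stream; unlike calling cudaMemcpy
//     // directly, this respects the sync debug mode and GPU trace hooks.
//     c10::cuda::memcpy_and_sync(
//         host_dst, device_src, nbytes, cudaMemcpyDeviceToHost, stream);
//     // The "just wait" primitive without a copy:
//     c10::cuda::stream_synchronize(stream);
//   }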
C10_CUDA_API bool hasPrimaryContext(int64_t device_index);
C10_CUDA_API c10::optional<int64_t> getDeviceIndexWithPrimaryContext();
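// Usage sketch (illustrative only): hasPrimaryContext() and
// getDeviceIndexWithPrimaryContext() let callers avoid initializing a GPU
// that nothing has touched yet. The function name `example_pick_device`
// is hypothetical.
//
//   #include <c10/cuda/CUDAFunctions.h>
//
//   int64_t example_pick_device() {
//     // Prefer a device whose primary context already exists, so we don't
//     // pay CUDA context creation on an idle GPU; fall back to device 0.
//     c10::optional<int64_t> idx = c10::cuda::getDeviceIndexWithPrimaryContext();
//     return idx.has_value() ? *idx : 0;
//   }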

} // namespace cuda
} // namespace c10