CUDAGuardImpl.h

#pragma once

#include <c10/core/DeviceGuard.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>

#include <cuda_runtime_api.h>

namespace c10 {
namespace cuda {
namespace impl {
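
// CUDAGuardImpl provides the CUDA backend for the generic c10 device-guard
// machinery (DeviceGuard, StreamGuard, Event, and friends). The virtual
// methods below are the CUDA-specific hooks those generic RAII types dispatch
// to; the impl itself is registered for DeviceType::CUDA elsewhere
// (typically in the matching .cpp via C10_REGISTER_GUARD_IMPL).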
struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = DeviceType::CUDA;

  CUDAGuardImpl() = default;
  explicit CUDAGuardImpl(DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA);
  }
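
  // Device management. The checked variants throw on CUDA errors via
  // C10_CUDA_CHECK; the unchecked variants are noexcept and only warn
  // (C10_CUDA_CHECK_WARN), so they are safe to call from destructors,
  // e.g. when a guard unwinds during exception propagation.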
  DeviceType type() const override {
    return DeviceType::CUDA;
  }
  Device exchangeDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    int old_device_index = c10::cuda::ExchangeDevice(d.index());
    return Device(DeviceType::CUDA, old_device_index);
  }
  Device getDevice() const override {
    int device;
    C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
    return Device(DeviceType::CUDA, device);
  }
  c10::optional<Device> uncheckedGetDevice() const noexcept {
    int device;
    const auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDevice(&device));
    C10_CUDA_CHECK_WARN(err);
    if (err != cudaSuccess) {
      return c10::nullopt;
    }
    return Device(DeviceType::CUDA, device);
  }
  void setDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    C10_CUDA_CHECK(c10::cuda::SetDevice(d.index()));
  }
  void uncheckedSetDevice(Device d) const noexcept override {
    C10_CUDA_CHECK_WARN(c10::cuda::MaybeSetDevice(d.index()));
  }
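
  // Stream management. The "current" CUDA stream is tracked per thread and
  // per device by CUDAStream; these methods wrap that state behind the
  // type-erased c10::Stream interface.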
  Stream getStream(Device d) const noexcept override {
    return getCurrentCUDAStream(d.index()).unwrap();
  }
  Stream getDefaultStream(Device d) const override {
    return getDefaultCUDAStream(d.index());
  }
  Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
      const override {
    return getStreamFromPool(isHighPriority, d.index());
  }
  // NB: These do NOT set the current device
  Stream exchangeStream(Stream s) const noexcept override {
    CUDAStream cs(s);
    auto old_stream = getCurrentCUDAStream(s.device().index());
    setCurrentCUDAStream(cs);
    return old_stream.unwrap();
  }
  DeviceIndex deviceCount() const noexcept override {
    return device_count();
  }

  // Event-related functions
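  // Events cross the generic interface as type-erased void* handles; the
  // underlying cudaEvent_t is created lazily, on the first record() call,
  // so merely constructing a c10::Event stays cheap.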
  void createEvent(cudaEvent_t* cuda_event, const EventFlag flag) const {
    // Maps PyTorch's Event::Flag to CUDA flag
    auto cuda_flag = cudaEventDefault;
    switch (flag) {
      case EventFlag::PYTORCH_DEFAULT:
      case EventFlag::CUDA_EVENT_DISABLE_TIMING:
        cuda_flag = cudaEventDisableTiming;
        break;
      case EventFlag::BACKEND_DEFAULT:
      case EventFlag::CUDA_EVENT_DEFAULT:
        cuda_flag = cudaEventDefault;
        break;
      default:
        TORCH_CHECK(false, "CUDA event received unknown flag");
    }

    C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_creation(
          reinterpret_cast<uintptr_t>(cuda_event));
    }
  }

  void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept override {
    if (!event)
      return;
    auto cuda_event = static_cast<cudaEvent_t>(event);
    int orig_device;
    C10_CUDA_CHECK_WARN(c10::cuda::GetDevice(&orig_device));
    C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(device_index));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_deletion(
          reinterpret_cast<uintptr_t>(cuda_event));
    }
    C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event));
    C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(orig_device));
  }
  void record(
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
    TORCH_CHECK(
        device_index == -1 || device_index == stream.device_index(),
        "Event device index ",
        device_index,
        " does not match recording stream's device index ",
        stream.device_index(),
        ".");

    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(*event);
    CUDAStream cuda_stream{stream};

    // Moves to stream's device to record
    const auto orig_device = getDevice();
    setDevice(stream.device());

    // Creates the event (lazily)
    if (!cuda_event)
      createEvent(&cuda_event, flag);
    C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream));
    // Makes the void* point to the (possibly just allocated) CUDA event
    *event = cuda_event;
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_record(
          reinterpret_cast<uintptr_t>(cuda_event),
          reinterpret_cast<uintptr_t>(cuda_stream.stream()));
    }

    // Resets device
    setDevice(orig_device);
  }
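
  // block() enqueues a wait on the stream rather than blocking the calling
  // host thread: cudaStreamWaitEvent makes all work submitted to the stream
  // after this call wait until the event has completed.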
  void block(void* event, const Stream& stream) const override {
    if (!event)
      return;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    CUDAStream cuda_stream{stream};
    const auto orig_device = getDevice();
    setDevice(stream.device());
    C10_CUDA_CHECK(cudaStreamWaitEvent(
        cuda_stream,
        cuda_event,
        /*flags (must be zero)=*/0));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_wait(
          reinterpret_cast<uintptr_t>(cuda_event),
          reinterpret_cast<uintptr_t>(cuda_stream.stream()));
    }
    setDevice(orig_device);
  }

  // May be called from any device
  bool queryEvent(void* event) const override {
    if (!event)
      return true;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    const cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(cuda_event));
    if (err != cudaErrorNotReady) {
      C10_CUDA_CHECK(err);
    } else {
      // ignore and clear the error if not ready
      (void)cudaGetLastError();
    }
    return (err == cudaSuccess);
  }

  // Stream-related functions
  bool queryStream(const Stream& stream) const override {
    CUDAStream cuda_stream{stream};
    return cuda_stream.query();
  }
  void synchronizeStream(const Stream& stream) const override {
    CUDAStream cuda_stream{stream};
    cuda_stream.synchronize();
  }
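
  // Tells the CUDA caching allocator that data_ptr is in use on the given
  // stream, so its memory is not handed back out until that stream's
  // pending work has finished.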
  void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
      const override {
    CUDAStream cuda_stream{stream};
    CUDACachingAllocator::recordStream(data_ptr, cuda_stream);
  }
};

} // namespace impl
} // namespace cuda
} // namespace c10
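
// A minimal usage sketch (illustrative only, not part of this header): user
// code never calls CUDAGuardImpl directly. Once the impl is registered for
// DeviceType::CUDA, the generic RAII guards route through it:
//
//   #include <c10/core/DeviceGuard.h>
//   #include <c10/cuda/CUDAStream.h>
//
//   void f() {
//     // setDevice()/exchangeDevice() above run here, and again on scope
//     // exit to restore the original device.
//     c10::DeviceGuard guard(c10::Device(c10::DeviceType::CUDA, 1));
//     // getStream() above reports the thread-local current stream.
//     auto stream = c10::cuda::getCurrentCUDAStream();
//   }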