CUDAGraphsC10Utils.h 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. #pragma once
  2. #include <c10/cuda/CUDAStream.h>
  3. #include <utility>
  4. // CUDA Graphs utils used by c10 and aten.
  5. // aten/cuda/CUDAGraphsUtils.cuh adds utils used by aten only.
  6. namespace c10 {
  7. namespace cuda {
  // Identifier type used for graph captures (a plain unsigned long long).
  using CaptureId_t = unsigned long long;
  // Pair of capture ids naming a memory pool:
  //   .first  -> filled in when the instance comes from CUDAGraph::capture_begin;
  //   .second -> filled in when the instance comes from at::cuda::graph_pool_handle.
  using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>;
  12. // RAII guard for "cudaStreamCaptureMode", a thread-local value
  13. // that controls the error-checking strictness of a capture.
  14. #if !defined(USE_ROCM) || ROCM_VERSION >= 50300
  15. struct C10_CUDA_API CUDAStreamCaptureModeGuard {
  16. CUDAStreamCaptureModeGuard(cudaStreamCaptureMode desired) {
  17. strictness_ = desired;
  18. C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&strictness_));
  19. }
  20. ~CUDAStreamCaptureModeGuard() {
  21. C10_CUDA_CHECK_WARN(cudaThreadExchangeStreamCaptureMode(&strictness_));
  22. }
  23. private:
  24. cudaStreamCaptureMode strictness_;
  25. };
  26. #endif
  27. #if !defined(USE_ROCM) || ROCM_VERSION >= 50300
  28. // Protects against enum cudaStreamCaptureStatus implementation changes.
  29. // Some compilers seem not to like static_assert without the messages.
  30. static_assert(
  31. int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0,
  32. "unexpected int(cudaStreamCaptureStatusNone) value");
  33. static_assert(
  34. int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1,
  35. "unexpected int(cudaStreamCaptureStatusActive) value");
  36. static_assert(
  37. int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2,
  38. "unexpected int(cudaStreamCaptureStatusInvalidated) value");
  39. #endif
  40. enum class CaptureStatus : int {
  41. #if !defined(USE_ROCM) || ROCM_VERSION >= 50300
  42. None = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone),
  43. Active = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive),
  44. Invalidated = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated)
  45. #else
  46. None = 0
  47. #endif
  48. };
  49. inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
  50. switch (status) {
  51. case CaptureStatus::None:
  52. os << "cudaStreamCaptureStatusNone";
  53. break;
  54. #if !defined(USE_ROCM) || ROCM_VERSION >= 50300
  55. case CaptureStatus::Active:
  56. os << "cudaStreamCaptureStatusActive";
  57. break;
  58. case CaptureStatus::Invalidated:
  59. os << "cudaStreamCaptureStatusInvalidated";
  60. break;
  61. #endif
  62. default:
  63. TORCH_INTERNAL_ASSERT(
  64. false, "Unknown CUDA graph CaptureStatus", int(status));
  65. }
  66. return os;
  67. }
  68. // Use this version where you're sure a CUDA context exists already.
  69. inline CaptureStatus currentStreamCaptureStatusMayInitCtx() {
  70. #if !defined(USE_ROCM) || ROCM_VERSION >= 50300
  71. cudaStreamCaptureStatus is_capturing;
  72. C10_CUDA_CHECK(
  73. cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &is_capturing));
  74. return CaptureStatus(is_capturing);
  75. #else
  76. return CaptureStatus::None;
  77. #endif
  78. }
  79. } // namespace cuda
  80. } // namespace c10