1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192 |
- #pragma once
- #include <c10/cuda/CUDAStream.h>
- #include <utility>
- // CUDA Graphs utils used by c10 and aten.
- // aten/cuda/CUDAGraphsUtils.cuh adds utils used by aten only.
- namespace c10 {
- namespace cuda {
// Integer id type used by the CUDA graph utilities; it is also the element
// type of both halves of MempoolId_t.
using CaptureId_t = unsigned long long;

// Pair of ids:
//   * `first` is set if the instance is created by CUDAGraph::capture_begin.
//   * `second` is set if the instance is created by
//     at::cuda::graph_pool_handle.
using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>;
- // RAII guard for "cudaStreamCaptureMode", a thread-local value
- // that controls the error-checking strictness of a capture.
- #if !defined(USE_ROCM) || ROCM_VERSION >= 50300
- struct C10_CUDA_API CUDAStreamCaptureModeGuard {
- CUDAStreamCaptureModeGuard(cudaStreamCaptureMode desired) {
- strictness_ = desired;
- C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&strictness_));
- }
- ~CUDAStreamCaptureModeGuard() {
- C10_CUDA_CHECK_WARN(cudaThreadExchangeStreamCaptureMode(&strictness_));
- }
- private:
- cudaStreamCaptureMode strictness_;
- };
- #endif
- #if !defined(USE_ROCM) || ROCM_VERSION >= 50300
- // Protects against enum cudaStreamCaptureStatus implementation changes.
- // Some compilers seem not to like static_assert without the messages.
- static_assert(
- int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0,
- "unexpected int(cudaStreamCaptureStatusNone) value");
- static_assert(
- int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1,
- "unexpected int(cudaStreamCaptureStatusActive) value");
- static_assert(
- int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2,
- "unexpected int(cudaStreamCaptureStatusInvalidated) value");
- #endif
- enum class CaptureStatus : int {
- #if !defined(USE_ROCM) || ROCM_VERSION >= 50300
- None = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone),
- Active = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive),
- Invalidated = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated)
- #else
- None = 0
- #endif
- };
- inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
- switch (status) {
- case CaptureStatus::None:
- os << "cudaStreamCaptureStatusNone";
- break;
- #if !defined(USE_ROCM) || ROCM_VERSION >= 50300
- case CaptureStatus::Active:
- os << "cudaStreamCaptureStatusActive";
- break;
- case CaptureStatus::Invalidated:
- os << "cudaStreamCaptureStatusInvalidated";
- break;
- #endif
- default:
- TORCH_INTERNAL_ASSERT(
- false, "Unknown CUDA graph CaptureStatus", int(status));
- }
- return os;
- }
- // Use this version where you're sure a CUDA context exists already.
- inline CaptureStatus currentStreamCaptureStatusMayInitCtx() {
- #if !defined(USE_ROCM) || ROCM_VERSION >= 50300
- cudaStreamCaptureStatus is_capturing;
- C10_CUDA_CHECK(
- cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &is_capturing));
- return CaptureStatus(is_capturing);
- #else
- return CaptureStatus::None;
- #endif
- }
- } // namespace cuda
- } // namespace c10
|