#pragma once

#include <c10/core/DeviceType.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>

#include <cuda_runtime_api.h>

namespace c10 {
namespace cuda {
namespace impl {

struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = DeviceType::CUDA;

  CUDAGuardImpl() = default;
  explicit CUDAGuardImpl(DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA);
  }
  DeviceType type() const override {
    return DeviceType::CUDA;
  }
  Device exchangeDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    int old_device_index = c10::cuda::ExchangeDevice(d.index());
    return Device(DeviceType::CUDA, old_device_index);
  }
  Device getDevice() const override {
    int device;
    C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
    return Device(DeviceType::CUDA, device);
  }
  c10::optional<Device> uncheckedGetDevice() const noexcept {
    int device;
    const auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDevice(&device));
    C10_CUDA_CHECK_WARN(err);
    if (err != cudaSuccess) {
      return c10::nullopt;
    }
    return Device(DeviceType::CUDA, device);
  }
  void setDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    C10_CUDA_CHECK(c10::cuda::SetDevice(d.index()));
  }
  void uncheckedSetDevice(Device d) const noexcept override {
    C10_CUDA_CHECK_WARN(c10::cuda::MaybeSetDevice(d.index()));
  }
  Stream getStream(Device d) const noexcept override {
    return getCurrentCUDAStream(d.index()).unwrap();
  }
  Stream getDefaultStream(Device d) const override {
    return getDefaultCUDAStream(d.index());
  }
  Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
      const override {
    return getStreamFromPool(isHighPriority, d.index());
  }
  // NB: These do NOT set the current device
  Stream exchangeStream(Stream s) const noexcept override {
    CUDAStream cs(s);
    auto old_stream = getCurrentCUDAStream(s.device().index());
    setCurrentCUDAStream(cs);
    return old_stream.unwrap();
  }
  DeviceIndex deviceCount() const noexcept override {
    return device_count();
  }

  // Event-related functions
  void createEvent(cudaEvent_t* cuda_event, const EventFlag flag) const {
    // Maps PyTorch's Event::Flag to CUDA flag
    auto cuda_flag = cudaEventDefault;
    switch (flag) {
      case EventFlag::PYTORCH_DEFAULT:
      case EventFlag::CUDA_EVENT_DISABLE_TIMING:
        cuda_flag = cudaEventDisableTiming;
        break;
      case EventFlag::BACKEND_DEFAULT:
      case EventFlag::CUDA_EVENT_DEFAULT:
        cuda_flag = cudaEventDefault;
        break;
      default:
        TORCH_CHECK(false, "CUDA event received unknown flag");
    }

    C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_creation(
          reinterpret_cast<uintptr_t>(*cuda_event));
    }
  }

  void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept override {
    if (!event)
      return;
    auto cuda_event = static_cast<cudaEvent_t>(event);
    int orig_device;
    C10_CUDA_CHECK_WARN(c10::cuda::GetDevice(&orig_device));
    C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(device_index));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_deletion(
          reinterpret_cast<uintptr_t>(cuda_event));
    }
    C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event));
    C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(orig_device));
  }

  void record(
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
    TORCH_CHECK(
        device_index == -1 || device_index == stream.device_index(),
        "Event device index ",
        device_index,
        " does not match recording stream's device index ",
        stream.device_index(),
        ".");
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(*event);
    CUDAStream cuda_stream{stream};

    // Moves to stream's device to record
    const auto orig_device = getDevice();
    setDevice(stream.device());

    // Creates the event (lazily)
    if (!cuda_event)
      createEvent(&cuda_event, flag);
    C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream));
    // Makes the void* point to the (possibly just allocated) CUDA event
    *event = cuda_event;
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_record(
          reinterpret_cast<uintptr_t>(cuda_event),
          reinterpret_cast<uintptr_t>(cuda_stream.stream()));
    }

    // Resets device
    setDevice(orig_device);
  }

  void block(void* event, const Stream& stream) const override {
    if (!event)
      return;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    CUDAStream cuda_stream{stream};
    const auto orig_device = getDevice();
    setDevice(stream.device());
    C10_CUDA_CHECK(cudaStreamWaitEvent(
        cuda_stream,
        cuda_event,
        /*flags (must be zero)=*/0));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_wait(
          reinterpret_cast<uintptr_t>(cuda_event),
          reinterpret_cast<uintptr_t>(cuda_stream.stream()));
    }
    setDevice(orig_device);
  }

  // May be called from any device
  bool queryEvent(void* event) const override {
    if (!event)
      return true;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    const cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(cuda_event));
    if (err != cudaErrorNotReady) {
      C10_CUDA_CHECK(err);
    } else {
      // ignore and clear the error if not ready
      (void)cudaGetLastError();
    }
    return (err == cudaSuccess);
  }

  // Stream-related functions
  bool queryStream(const Stream& stream) const override {
    CUDAStream cuda_stream{stream};
    return cuda_stream.query();
  }

  void synchronizeStream(const Stream& stream) const override {
    CUDAStream cuda_stream{stream};
    cuda_stream.synchronize();
  }

  void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
      const override {
    CUDAStream cuda_stream{stream};
    CUDACachingAllocator::recordStream(data_ptr, cuda_stream);
  }
};

} // namespace impl
} // namespace cuda
} // namespace c10
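
// ---------------------------------------------------------------------------
// Usage sketch (editorial addition, not part of the upstream header).
// CUDAGuardImpl is not used directly: it is registered with the generic
// device-guard registry (C10_REGISTER_GUARD_IMPL(CUDA, CUDAGuardImpl) in
// CUDAGuardImpl.cpp) and reached through wrappers such as
// c10::cuda::CUDAGuard. The snippet below shows a typical path under that
// assumption; `example` is a hypothetical driver function, everything else
// is real c10 API.
//
//   #include <c10/cuda/CUDAGuard.h>
//   #include <c10/cuda/CUDAStream.h>
//
//   void example() {
//     // RAII device switch: calls CUDAGuardImpl::exchangeDevice on entry
//     // and restores the original device via setDevice on destruction.
//     c10::cuda::CUDAGuard guard(/*device_index=*/1);
//
//     // Streams come from the per-device pool in CUDAStream.h;
//     // getStreamFromGlobalPool above forwards to getStreamFromPool.
//     c10::cuda::CUDAStream stream =
//         c10::cuda::getStreamFromPool(/*isHighPriority=*/false, /*device=*/1);
//     c10::cuda::setCurrentCUDAStream(stream);
//   } // guard restores the original current device here
// ---------------------------------------------------------------------------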