123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211 |
- #pragma once
- #include <c10/core/DeviceGuard.h>
- #include <c10/core/impl/DeviceGuardImplInterface.h>
- #include <c10/core/impl/GPUTrace.h>
- #include <c10/macros/Macros.h>
- #include <c10/util/Exception.h>
- #include <c10/cuda/CUDACachingAllocator.h>
- #include <c10/cuda/CUDAException.h>
- #include <c10/cuda/CUDAFunctions.h>
- #include <c10/cuda/CUDAStream.h>
- #include <cuda_runtime_api.h>
- namespace c10 {
- namespace cuda {
- namespace impl {
- struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
- static constexpr DeviceType static_type = DeviceType::CUDA;
- CUDAGuardImpl() = default;
- explicit CUDAGuardImpl(DeviceType t) {
- TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA);
- }
- DeviceType type() const override {
- return DeviceType::CUDA;
- }
- Device exchangeDevice(Device d) const override {
- TORCH_INTERNAL_ASSERT(d.is_cuda());
- int old_device_index = c10::cuda::ExchangeDevice(d.index());
- return Device(DeviceType::CUDA, old_device_index);
- }
- Device getDevice() const override {
- int device;
- C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
- return Device(DeviceType::CUDA, device);
- }
- c10::optional<Device> uncheckedGetDevice() const noexcept {
- int device;
- const auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDevice(&device));
- C10_CUDA_CHECK_WARN(err);
- if (err != cudaSuccess) {
- return c10::nullopt;
- }
- return Device(DeviceType::CUDA, device);
- }
- void setDevice(Device d) const override {
- TORCH_INTERNAL_ASSERT(d.is_cuda());
- C10_CUDA_CHECK(c10::cuda::SetDevice(d.index()));
- }
- void uncheckedSetDevice(Device d) const noexcept override {
- C10_CUDA_CHECK_WARN(c10::cuda::MaybeSetDevice(d.index()));
- }
- Stream getStream(Device d) const noexcept override {
- return getCurrentCUDAStream(d.index()).unwrap();
- }
- Stream getDefaultStream(Device d) const override {
- return getDefaultCUDAStream(d.index());
- }
- Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
- const override {
- return getStreamFromPool(isHighPriority, d.index());
- }
- // NB: These do NOT set the current device
- Stream exchangeStream(Stream s) const noexcept override {
- CUDAStream cs(s);
- auto old_stream = getCurrentCUDAStream(s.device().index());
- setCurrentCUDAStream(cs);
- return old_stream.unwrap();
- }
- DeviceIndex deviceCount() const noexcept override {
- return device_count();
- }
- // Event-related functions
- void createEvent(cudaEvent_t* cuda_event, const EventFlag flag) const {
- // Maps PyTorch's Event::Flag to CUDA flag
- auto cuda_flag = cudaEventDefault;
- switch (flag) {
- case EventFlag::PYTORCH_DEFAULT:
- case EventFlag::CUDA_EVENT_DISABLE_TIMING:
- cuda_flag = cudaEventDisableTiming;
- break;
- case EventFlag::BACKEND_DEFAULT:
- case EventFlag::CUDA_EVENT_DEFAULT:
- cuda_flag = cudaEventDefault;
- break;
- default:
- TORCH_CHECK(false, "CUDA event received unknown flag");
- }
- C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag));
- const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
- if (C10_UNLIKELY(interp)) {
- (*interp)->trace_gpu_event_creation(
- reinterpret_cast<uintptr_t>(cuda_event));
- }
- }
- void destroyEvent(void* event, const DeviceIndex device_index)
- const noexcept override {
- if (!event)
- return;
- auto cuda_event = static_cast<cudaEvent_t>(event);
- int orig_device;
- C10_CUDA_CHECK_WARN(c10::cuda::GetDevice(&orig_device));
- C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(device_index));
- const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
- if (C10_UNLIKELY(interp)) {
- (*interp)->trace_gpu_event_deletion(
- reinterpret_cast<uintptr_t>(cuda_event));
- }
- C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event));
- C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(orig_device));
- }
- void record(
- void** event,
- const Stream& stream,
- const DeviceIndex device_index,
- const EventFlag flag) const override {
- TORCH_CHECK(
- device_index == -1 || device_index == stream.device_index(),
- "Event device index ",
- device_index,
- " does not match recording stream's device index ",
- stream.device_index(),
- ".");
- cudaEvent_t cuda_event = static_cast<cudaEvent_t>(*event);
- CUDAStream cuda_stream{stream};
- // Moves to stream's device to record
- const auto orig_device = getDevice();
- setDevice(stream.device());
- // Creates the event (lazily)
- if (!cuda_event)
- createEvent(&cuda_event, flag);
- C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream));
- // Makes the void* point to the (possibly just allocated) CUDA event
- *event = cuda_event;
- const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
- if (C10_UNLIKELY(interp)) {
- (*interp)->trace_gpu_event_record(
- reinterpret_cast<uintptr_t>(cuda_event),
- reinterpret_cast<uintptr_t>(cuda_stream.stream()));
- }
- // Resets device
- setDevice(orig_device);
- }
- void block(void* event, const Stream& stream) const override {
- if (!event)
- return;
- cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
- CUDAStream cuda_stream{stream};
- const auto orig_device = getDevice();
- setDevice(stream.device());
- C10_CUDA_CHECK(cudaStreamWaitEvent(
- cuda_stream,
- cuda_event,
- /*flags (must be zero)=*/0));
- const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
- if (C10_UNLIKELY(interp)) {
- (*interp)->trace_gpu_event_wait(
- reinterpret_cast<uintptr_t>(cuda_event),
- reinterpret_cast<uintptr_t>(cuda_stream.stream()));
- }
- setDevice(orig_device);
- }
- // May be called from any device
- bool queryEvent(void* event) const override {
- if (!event)
- return true;
- cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
- const cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(cuda_event));
- if (err != cudaErrorNotReady) {
- C10_CUDA_CHECK(err);
- } else {
- // ignore and clear the error if not ready
- (void)cudaGetLastError();
- }
- return (err == cudaSuccess);
- }
- // Stream-related functions
- bool queryStream(const Stream& stream) const override {
- CUDAStream cuda_stream{stream};
- return cuda_stream.query();
- }
- void synchronizeStream(const Stream& stream) const override {
- CUDAStream cuda_stream{stream};
- cuda_stream.synchronize();
- }
- void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
- const override {
- CUDAStream cuda_stream{stream};
- CUDACachingAllocator::recordStream(data_ptr, cuda_stream);
- }
- };
- } // namespace impl
- } // namespace cuda
- } // namespace c10
|