123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118 |
- #pragma once
- // This header provides C++ wrappers around commonly used CUDA API functions.
- // The benefit of using C++ here is that we can raise an exception in the
- // event of an error, rather than explicitly pass around error codes. This
- // leads to more natural APIs.
- //
- // The naming convention used here matches the naming convention of torch.cuda
- #include <c10/core/Device.h>
- #include <c10/core/impl/GPUTrace.h>
- #include <c10/cuda/CUDAException.h>
- #include <c10/cuda/CUDAMacros.h>
- #include <cuda_runtime_api.h>
- namespace c10 {
- namespace cuda {
- // NB: In the past, we were inconsistent about whether or not this reported
- // an error if there were driver problems are not. Based on experience
- // interacting with users, it seems that people basically ~never want this
- // function to fail; it should just return zero if things are not working.
- // Oblige them.
- // It still might log a warning for user first time it's invoked
- C10_CUDA_API DeviceIndex device_count() noexcept;
- // Version of device_count that throws is no devices are detected
- C10_CUDA_API DeviceIndex device_count_ensure_non_zero();
- C10_CUDA_API DeviceIndex current_device();
- C10_CUDA_API void set_device(DeviceIndex device);
- C10_CUDA_API void device_synchronize();
- C10_CUDA_API void warn_or_error_on_sync();
- // Raw CUDA device management functions
- C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count);
- C10_CUDA_API cudaError_t GetDevice(int* device);
- C10_CUDA_API cudaError_t SetDevice(int device);
- C10_CUDA_API cudaError_t MaybeSetDevice(int device);
- C10_CUDA_API int ExchangeDevice(int device);
- C10_CUDA_API int MaybeExchangeDevice(int device);
- C10_CUDA_API void SetTargetDevice();
- enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR };
- // this is a holder for c10 global state (similar to at GlobalContext)
- // currently it's used to store cuda synchronization warning state,
- // but can be expanded to hold other related global state, e.g. to
- // record stream usage
- class WarningState {
- public:
- void set_sync_debug_mode(SyncDebugMode l) {
- sync_debug_mode = l;
- }
- SyncDebugMode get_sync_debug_mode() {
- return sync_debug_mode;
- }
- private:
- SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED;
- };
- C10_CUDA_API __inline__ WarningState& warning_state() {
- static WarningState warning_state_;
- return warning_state_;
- }
- // the subsequent functions are defined in the header because for performance
- // reasons we want them to be inline
- C10_CUDA_API void __inline__ memcpy_and_sync(
- void* dst,
- void* src,
- int64_t nbytes,
- cudaMemcpyKind kind,
- cudaStream_t stream) {
- if (C10_UNLIKELY(
- warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
- warn_or_error_on_sync();
- }
- const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
- if (C10_UNLIKELY(interp)) {
- (*interp)->trace_gpu_stream_synchronization(
- reinterpret_cast<uintptr_t>(stream));
- }
- #if defined(TORCH_HIP_VERSION) && (TORCH_HIP_VERSION >= 301)
- C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
- #else
- C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
- C10_CUDA_CHECK(cudaStreamSynchronize(stream));
- #endif
- }
- C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
- if (C10_UNLIKELY(
- warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
- warn_or_error_on_sync();
- }
- const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
- if (C10_UNLIKELY(interp)) {
- (*interp)->trace_gpu_stream_synchronization(
- reinterpret_cast<uintptr_t>(stream));
- }
- C10_CUDA_CHECK(cudaStreamSynchronize(stream));
- }
- C10_CUDA_API bool hasPrimaryContext(int64_t device_index);
- C10_CUDA_API c10::optional<int64_t> getDeviceIndexWithPrimaryContext();
- } // namespace cuda
- } // namespace c10
|