CUDAGraphsUtils.cuh 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. #pragma once
  2. #include <ATen/cuda/CUDAGeneratorImpl.h>
  3. #include <ATen/cuda/CUDAEvent.h>
  4. #include <ATen/cuda/detail/UnpackRaw.cuh>
  5. #include <ATen/cuda/detail/CUDAHooks.h>
  6. #include <ATen/detail/CUDAHooksInterface.h>
  7. #include <c10/core/StreamGuard.h>
  8. #include <c10/cuda/CUDAGraphsC10Utils.h>
  9. #include <c10/cuda/CUDAGuard.h>
  10. // c10/cuda/CUDAGraphsC10Utils.h has utils used by both c10 and aten.
  11. // This file adds utils used by aten only.
  12. namespace at {
  13. namespace cuda {
  // Make the c10::cuda capture id and capture status types available inside
  // at::cuda under the same names.
  14. using CaptureId_t = c10::cuda::CaptureId_t;
  15. using CaptureStatus = c10::cuda::CaptureStatus;
  16. // Use this version where you don't want to create a CUDA context if none exists.
  17. inline CaptureStatus currentStreamCaptureStatus() {
  18. #if !defined(USE_ROCM) || ROCM_VERSION >= 50300
  19. // don't create a context if we don't have to
  20. if (c10::cuda::hasPrimaryContext(c10::cuda::current_device())) {
  21. return c10::cuda::currentStreamCaptureStatusMayInitCtx();
  22. } else {
  23. return CaptureStatus::None;
  24. }
  25. #else
  26. return CaptureStatus::None;
  27. #endif
  28. }
  29. inline void assertNotCapturing(std::string attempt) {
  30. auto status = currentStreamCaptureStatus();
  31. TORCH_CHECK(status == CaptureStatus::None,
  32. attempt,
  33. " during CUDA graph capture. If you need this call to be captured, "
  34. "please file an issue. "
  35. "Current cudaStreamCaptureStatus: ",
  36. status);
  37. }
  38. inline void errorIfCapturingCudnnBenchmark(std::string version_specific) {
  39. auto status = currentStreamCaptureStatus();
  40. TORCH_CHECK(status == CaptureStatus::None,
  41. "Current cudaStreamCaptureStatus: ",
  42. status,
  43. "\nCapturing ",
  44. version_specific,
  45. "is prohibited. Possible causes of this error:\n"
  46. "1. No warmup iterations occurred before capture.\n"
  47. "2. The convolutions you're trying to capture use dynamic shapes, "
  48. "in which case capturing them is generally prohibited.");
  49. }
  50. } // namespace cuda
  51. } // namespace at