CUDAGraph.h

#pragma once

#include <ATen/Tensor.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAStream.h>

namespace at {

struct CUDAGeneratorImpl;

namespace cuda {

// Standalone way to get a unique mempool id usable as a pool=... argument
// to CUDAGraph::capture_begin.
TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle();

struct TORCH_CUDA_CPP_API CUDAGraph {
  CUDAGraph();
  ~CUDAGraph();

  void capture_begin(MempoolId_t pool={0, 0});
  void capture_end();
  void replay();
  void reset();
  MempoolId_t pool();
  void enable_debug_mode();
  void debug_dump(const std::string& debug_path);

 protected:
#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
  cudaGraph_t graph_ = NULL;
  cudaGraphExec_t graph_exec_ = NULL;
#endif

  // Internal state so reset() can do its best to clean up.

  // Set to true in capture_end if cudaStreamEndCapture succeeded.
  // Set back to false soon after, when graph_ is consumed by cudaGraphInstantiate
  // to create graph_exec_, then graph_ is deleted.
  bool has_graph_ = false;
  // Set to true in capture_end if cudaGraphInstantiate succeeded.
  bool has_graph_exec_ = false;

  // UUID of this instance's current capture, retrieved from CUDA.
  CaptureId_t id_;

  // UUID used to request a particular private mempool from CUDACachingAllocator.
  // By default, this will be set to {id_, 0}.
  //
  // If capture_begin is called with "pool=other_graph.pool()", this graph's
  // mempool_id_ will be set to the other graph's mempool_id_, and therefore
  // share a mempool with the other graph.
  //
  // If capture_begin is called with "pool=handle" where "handle" came from
  // graph_pool_handle(), it will share a mempool with any other captures
  // that used "pool=handle".
  //
  // Sharing a mempool across graphs saves memory, and it's safe if you
  // know you'll replay those graphs in the same order you captured them.
  MempoolId_t mempool_id_;

  // Stream on which capture began
  at::cuda::CUDAStream capture_stream_;

  // Default generator on device where capture began
  at::CUDAGeneratorImpl* capture_gen_;

  // Device where capture occurred. Right now, for simplicity, we require all
  // ops in a capture to run on the same device. This is a limitation of
  // CUDAGraph, not CUDA itself; we could straightforwardly modify CUDAGraph
  // to support multi-device captures if needed.
  int capture_dev_;

  // RNG state trackers
  at::Tensor seed_extragraph_;
  at::Tensor offset_extragraph_;
  uint64_t wholegraph_increment_;
};

} // namespace cuda
} // namespace at
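
For context, here is a minimal usage sketch of this API (not part of the header): it captures two small graphs that share one private mempool via graph_pool_handle(), then replays them in capture order, which is the condition the mempool_id_ comment above calls safe. The function name capture_and_replay_example, the tensor shapes, the warmup step, and the use of the c10::cuda stream helpers are illustrative assumptions, not taken from this file.

// Usage sketch only; names and shapes are made up for illustration.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

void capture_and_replay_example() {
  // Capture must happen on a non-default stream.
  c10::cuda::CUDAStream stream = c10::cuda::getStreamFromPool();
  c10::cuda::CUDAStreamGuard guard(stream);

  // Static input tensor: replays reuse fixed memory addresses, so keep
  // this tensor alive and copy fresh data into it between replays.
  auto x = at::zeros({1024}, at::device(at::kCUDA).dtype(at::kFloat));

  // Warm up once so lazy CUDA/allocator initialization does not happen
  // during capture (real code often warms up on a side stream).
  auto warmup = x * 2.0f + 1.0f;
  (void)warmup;

  // Standalone mempool id; every capture that passes it shares one pool.
  auto handle = at::cuda::graph_pool_handle();

  at::cuda::CUDAGraph graph_a;
  graph_a.capture_begin(handle);  // work below is recorded, not executed
  auto y = x * 2.0f + 1.0f;
  graph_a.capture_end();

  at::cuda::CUDAGraph graph_b;
  graph_b.capture_begin(handle);  // shares graph_a's private mempool
  auto z = y.sin();
  graph_b.capture_end();

  // Replay in the same order the graphs were captured.
  x.copy_(at::randn({1024}, x.options()));
  graph_a.replay();
  graph_b.replay();
  stream.synchronize();
}

Replaying in capture order is what makes the shared pool safe here: graph_b reads y, whose storage lives in the shared mempool and is rewritten by graph_a's replay immediately beforehand.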