// CUDAGeneratorImpl.h
#pragma once

#include <ATen/core/Generator.h>
#include <ATen/cuda/detail/PhiloxCudaStateRaw.cuh>
#include <ATen/Context.h>
#include <limits>
#include <atomic>

namespace at {
/**
 * Note [CUDA Graph-safe RNG states]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 *
 * Strategy:
 * ~~~~~~~~~
 * (It helps to look at
 * cuda/detail/PhiloxCudaStateRaw.cuh and
 * cuda/detail/UnpackRaw.cuh
 * while you read this.)
 *
 * A CUDA graph containing multiple RNG ops behaves like a
 * single giant kernel from the perspective of ops external
 * to the graph. During graph capture, logic in CUDAGeneratorImpl
 * records the total of all offset increments that occur in the
 * graphed region, and records the final total as the offset for
 * the entire graph.
 *
 * When the graph reruns, the logic that reruns it
 * increments this device's CUDA generator's offset
 * by that total.
 *
 * Meanwhile, within the graph, at capture time, instead of
 * populating PhiloxCudaStates with the uint64_t offset pulled
 * directly from the global state, PhiloxCudaState uses a pointer
 * to a one-element stream-local int64_t device tensor
 * holding an initial offset value, and a uint64_t holding an
 * intra-graph offset. (The intra-graph offset starts from zero
 * when capture begins.) In each consumer kernel,
 * at::cuda::philox::unpack computes the offset to use for this kernel
 * as intra-graph offset + *initial offset.
 *
 * When the graph reruns, the logic that reruns it first
 * fill_s the initial offset tensor with this device's
 * CUDA generator's current offset.
 *
 * The control flow above ensures graphed execution is bitwise
 * identical to eager execution as long as RNG ops are enqueued
 * from a single thread, even if RNG ops and graphs containing
 * RNG ops are enqueued and run simultaneously on multiple streams.
 *
 * Usage:
 * ~~~~~~
 * PhiloxCudaState in this file, and unpack() in
 * cuda/CUDAGraphsUtils.cuh allow non-divergent use of
 * CUDAGeneratorImpl whether graph capture is underway or not.
 *
 * Each PhiloxCudaState instance should be used for one and only one
 * consumer kernel.
 *
 * Example (see e.g. native/cuda/Dropout.cu):
 *
 * #include <ATen/cuda/CUDAGeneratorImpl.h>
 * #include <ATen/cuda/CUDAGraphsUtils.cuh>
 *
 * __global__ void kernel(..., PhiloxCudaState philox_args) {
 *   auto seeds = at::cuda::philox::unpack(philox_args);
 *   IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
 *   curandStatePhilox4_32_10_t state;
 *   curand_init(std::get<0>(seeds), // seed
 *               idx,                // per-thread subsequence
 *               std::get<1>(seeds), // offset in subsequence
 *               &state);
 *   ...
 * }
 *
 * host_caller(...) {
 *   PhiloxCudaState rng_engine_inputs;
 *   {
 *     // See Note [Acquire lock when using random generators]
 *     std::lock_guard<std::mutex> lock(gen->mutex_);
 *
 *     // gen could be HostState or DevState here! No divergent code needed!
 *     rng_engine_inputs = gen->philox_cuda_state(offset_increment);
 *   }
 *   kernel<<<...>>>(..., rng_engine_inputs);
 * }
 *
 */
  87. struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
  88. // Constructors
  89. CUDAGeneratorImpl(DeviceIndex device_index = -1);
  90. ~CUDAGeneratorImpl() override = default;
  91. // CUDAGeneratorImpl methods
  92. std::shared_ptr<CUDAGeneratorImpl> clone() const;
  93. void set_current_seed(uint64_t seed) override;
  94. uint64_t current_seed() const override;
  95. uint64_t seed() override;
  96. void set_state(const c10::TensorImpl& new_state) override;
  97. c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
  98. void set_philox_offset_per_thread(uint64_t offset);
  99. uint64_t philox_offset_per_thread() const;
  100. void capture_prologue(int64_t* seed_extragraph, int64_t* offset_extragraph);
  101. uint64_t capture_epilogue();
  102. PhiloxCudaState philox_cuda_state(uint64_t increment);
  103. bool reset_rnn_state() {
  104. return !no_reset_rnn_state_.test_and_set();
  105. }
  106. // Temporarily accommodates call sites that use philox_engine_inputs.
  107. // Allows incremental refactor of call sites to use philox_cuda_state.
  108. std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment);
  109. static DeviceType device_type();
  110. private:
  111. CUDAGeneratorImpl* clone_impl() const override;
  112. uint64_t seed_ = default_rng_seed_val;
  113. uint64_t philox_offset_per_thread_ = 0;
  114. int64_t* seed_extragraph_{};
  115. int64_t* offset_extragraph_{};
  116. uint32_t offset_intragraph_ = 0;
  117. bool graph_expects_this_gen_ = false;
  118. std::atomic_flag no_reset_rnn_state_;
  119. };
  120. namespace cuda {
  121. namespace detail {
  122. TORCH_CUDA_CPP_API const Generator& getDefaultCUDAGenerator(
  123. DeviceIndex device_index = -1);
  124. TORCH_CUDA_CPP_API Generator createCUDAGenerator(DeviceIndex device_index = -1);
  125. } // namespace detail
  126. } // namespace cuda
  127. } // namespace at