// PhiloxCudaStateRaw.cuh

// No "#pragma once" because this is a raw definition that can be copied by jit codegen.
// Eager mode clients should not include this file directly; instead,
// they should #include <ATen/cuda/CUDAGeneratorImpl.h>, which has a #pragma once.
//
// Stores RNG state values. Passed as a kernel argument.
// See Note [CUDA Graph-safe RNG states].
//
// The raw definition lives in its own file so jit codegen can easily copy it.
  8. namespace at {
  9. struct PhiloxCudaState {
  10. PhiloxCudaState() = default;
  11. // Called if graph capture is not underway
  12. PhiloxCudaState(uint64_t seed,
  13. uint64_t offset) {
  14. seed_.val = seed;
  15. offset_.val = offset;
  16. }
  17. // Called if graph capture is underway
  18. PhiloxCudaState(int64_t* seed,
  19. int64_t* offset_extragraph,
  20. uint32_t offset_intragraph) {
  21. seed_.ptr = seed;
  22. offset_.ptr = offset_extragraph;
  23. offset_intragraph_ = offset_intragraph;
  24. captured_ = true;
  25. }
  26. // Public members, directly accessible by at::cuda::philox::unpack.
  27. // If we made them private with getters/setters, the getters/setters
  28. // would have to be __device__, and we can't declare __device__ in ATen.
  29. union Payload {
  30. uint64_t val;
  31. int64_t* ptr;
  32. };
  33. Payload seed_;
  34. Payload offset_;
  35. uint32_t offset_intragraph_ = 0;
  36. bool captured_ = false;
  37. };
  38. } // namespace at