// No "#pragma once" because this is a raw definition that can be copied by jit codegen.
// Eager mode clients should not include this file directly; instead,
// they should #include <ATen/cuda/CUDAGeneratorImpl.h>, which has a #pragma once.

// Stores RNG state values. Passed as a kernel argument.
// See Note [CUDA Graph-safe RNG states].
//
// The raw definition lives in its own file so jit codegen can easily copy it.
namespace at {

// Plain value type bundling a Philox seed and offset.
//
// The seed and offset are each stored in a `Payload` union and can be held
// in one of two representations, selected by which constructor was used:
//   * by value   (`seed_.val`, `offset_.val`), with `captured_ == false`;
//   * by pointer (`seed_.ptr`, `offset_.ptr`) plus `offset_intragraph_`,
//     with `captured_ == true`.
// Readers must check `captured_` to know which union member is active.
struct PhiloxCudaState {
  // Produces a non-captured state whose seed and offset values are zero
  // (all members have default initializers, so nothing is indeterminate).
  PhiloxCudaState() = default;

  // Called if graph capture is not underway.
  // Stores `seed` and `offset` by value; `captured_` stays false and
  // `offset_intragraph_` stays 0.
  PhiloxCudaState(uint64_t seed,
                  uint64_t offset) {
    seed_.val = seed;
    offset_.val = offset;
  }

  // Called if graph capture is underway.
  // Stores the `seed` and `offset_extragraph` pointers (they are not
  // dereferenced here), records `offset_intragraph`, and marks the state
  // as captured.
  PhiloxCudaState(int64_t* seed,
                  int64_t* offset_extragraph,
                  uint32_t offset_intragraph) {
    seed_.ptr = seed;
    offset_.ptr = offset_extragraph;
    offset_intragraph_ = offset_intragraph;
    captured_ = true;
  }

  // Public members, directly accessible by at::cuda::philox::unpack.
  // If we made them private with getters/setters, the getters/setters
  // would have to be __device__, and we can't declare __device__ in ATen.

  // Holds either an immediate value or a pointer; which member is active
  // is tracked externally by `captured_`.
  union Payload {
    uint64_t val;
    int64_t* ptr;
  };

  // Value-initialized so a default-constructed state never carries
  // indeterminate bits (`{}` zero-initializes the first member, `val`).
  Payload seed_{};
  Payload offset_{};
  // Only assigned by the pointer-taking constructor; 0 otherwise.
  uint32_t offset_intragraph_ = 0;
  // True iff the pointer-taking constructor was used, i.e. `seed_.ptr` and
  // `offset_.ptr` are the active union members.
  bool captured_ = false;
};

} // namespace at
|