UnpackRaw.cuh 1.5 KB

1234567891011121314151617181920212223242526272829303132
  // No "#pragma once" because this is a raw definition that can be copied by jit codegen.
  // Eager mode clients should not include this file directly; instead,
  // they should #include <ATen/cuda/CUDAGraphsUtils.cuh>, which has a #pragma once.
  namespace at {
  namespace cuda {
  namespace philox {

  // unpack: in-kernel accessor that yields the Philox (seed, offset) pair held by a
  // PhiloxCudaState. It works whether or not that state was built while CUDA graph
  // capture was underway. See Note [CUDA Graph-safe RNG states].
  //
  // Why it is defined here: CUDAGeneratorImpl.h is in ATen, so a __device__ function
  // cannot be written there. Also, whatever unpacks PhiloxCudaState inside consumer
  // kernels must be inlineable. A __device__ __forceinline__ helper in ATen/cuda
  // satisfies both constraints.
  // The raw definition lives in its own file so jit codegen can easily copy it.
  //
  // Returns: std::tuple<seed, offset>, both as uint64_t.
  //   - arg.captured_ == false: seed_.val and offset_.val are returned as-is.
  //   - arg.captured_ == true:  seed and base offset are loaded through seed_.ptr and
  //     offset_.ptr, and offset_intragraph_ is added to the loaded offset.
  __device__ __forceinline__ std::tuple<uint64_t, uint64_t>
  unpack(at::PhiloxCudaState arg) {
    if (!arg.captured_) {
      // Not captured: both values are read directly from arg's members.
      return std::make_tuple(arg.seed_.val, arg.offset_.val);
    }

    // Captured: both values are loaded through pointers when the kernel runs.
    // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
    // Most threads' reads should hit in cache, so it shouldn't hurt performance.
    // The static_casts avoid "warning: invalid narrowing conversion from "long" to
    // "unsigned long"". The intragraph offset is added before the cast, as a signed sum.
    const uint64_t seed = static_cast<uint64_t>(*arg.seed_.ptr);
    const uint64_t offset =
        static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_);
    return std::make_tuple(seed, offset);
  }

  } // namespace philox
  } // namespace cuda
  } // namespace at