jiterator_macros.h 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. #pragma once
  2. #include <c10/macros/Macros.h>
  3. #include <string>
  // JITERATOR_HOST_DEVICE expands to C10_HOST_DEVICE, except when building with
  // MSVC and CUDA (_MSC_VER and __CUDACC__ both defined), where it expands to
  // nothing.
  //
  // The macro is defined exactly once per configuration. Defining it
  // unconditionally and then defining it again with a different replacement
  // list (without an #undef in between) is an incompatible macro redefinition
  // and triggers redefinition diagnostics.
  #if defined(_MSC_VER) && defined(__CUDACC__)
  // NVRTC on Windows errors if the __host__ __device__ attribute is
  // present on a kernel:
  //   error: attribute "__host__" does not apply here
  //   error: attribute "__device__" does not apply here
  #define JITERATOR_HOST_DEVICE
  #else
  #define JITERATOR_HOST_DEVICE C10_HOST_DEVICE
  #endif
  // jiterator_also_stringify_as(code, str_name) emits `code` verbatim, so
  // whatever it defines is defined in the current translation unit. When the
  // translation unit is built by a GPU compiler (__CUDACC__ or __HIPCC__ is
  // defined) it additionally defines `const std::string str_name` holding the
  // text of `code`, for use by `jiterator`.
  //
  // Usage:
  //   jiterator_also_stringify_as(
  //       jiterator_code(template <typename T> T identity(T x) { return x; }),
  //       identity_string);
  //
  // This defines the template `identity` and, under a GPU compiler, also
  // `std::string identity_string` containing the code as a string.
  //
  // Wrap the code in `jiterator_code(...)`: any `,` in the kernel code that is
  // not enclosed in parentheses would otherwise make the preprocessor treat the
  // code as several macro arguments.
  //
  // Notes:
  //  - `code` is macro-expanded before it is turned into a string.
  //  - The GPU expansion already ends with `;`.
  #define jiterator_code(...) __VA_ARGS__

  #if !defined(__CUDACC__) && !defined(__HIPCC__)
  // Host-only compilation: only the function itself is needed.
  #define jiterator_also_stringify_as(code, str_name) code
  #else
  // CUDA or ROCm compilation: define the function and keep its text as a string.
  #define stringify_code(...) #__VA_ARGS__
  #define jiterator_also_stringify_as(code, str_name) \
    code /* define the function */                    \
    const std::string str_name(stringify_code(code));
  #endif