#pragma once

#include <cuda.h>
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>
#include <c10/util/complex.h>  // for the c10::complex<T> overloads below
#include <ATen/NumericUtils.h>

#if !(defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
#include <cuda_bf16.h>
#endif
template <typename T>
struct AtomicFPOp;

template <>
struct AtomicFPOp<at::Half> {
  template <typename func_t>
  inline __device__ at::Half operator() (at::Half *address, at::Half val, const func_t& func) {
    unsigned int * address_as_ui =
      (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;

    at::Half hsum;
    do {
      assumed = old;
      hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
      hsum = func(hsum, val);
      old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
      old = atomicCAS(address_as_ui, assumed, old);
    } while (assumed != old);
    hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
    return hsum;
  }
};

template <>
struct AtomicFPOp<at::BFloat16> {
  template <typename func_t>
  inline __device__ at::BFloat16 operator() (at::BFloat16 *address, at::BFloat16 val, const func_t& func) {
    unsigned int * address_as_ui =
      (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;

    at::BFloat16 bsum;
    do {
      assumed = old;
      bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
      bsum = func(bsum, val);
      old = (size_t)address & 2 ? (old & 0xffff) | (bsum.x << 16) : (old & 0xffff0000) | bsum.x;
      old = atomicCAS(address_as_ui, assumed, old);
    } while (assumed != old);
    bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
    return bsum;
  }
};

template <>
struct AtomicFPOp<double> {
  template <typename func_t>
  inline __device__ double operator() (double * address, double val, const func_t& func) {
    unsigned long long int* address_as_ull = (unsigned long long int*)address;
    unsigned long long int old = *address_as_ull;
    unsigned long long int assumed;

    do {
      assumed = old;
      old = atomicCAS(address_as_ull, assumed, func(val, assumed));
      // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
    } while (assumed != old);

    return __longlong_as_double(old);
  }
};
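
/* Usage sketch (illustrative, not part of the original header): AtomicFPOp
 * emulates a 16-bit atomic read-modify-write by CAS-ing the aligned 32-bit
 * word that contains the target half-word; the lambda supplies the binary op.
 * A hypothetical gpuAtomicSub for at::Half could be built the same way:
 *
 *   static inline __device__ at::Half gpuAtomicSub(at::Half *address, at::Half val) {
 *     return AtomicFPOp<at::Half>()(address, val,
 *                                   [](at::Half hsum, at::Half val) {
 *                                     return hsum - val;
 *                                   });
 *   }
 */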
#define ATOMIC_INTEGER_IMPL(NAME) \
template <typename T, size_t n> \
struct Atomic##NAME##IntegerImpl; \
 \
template<typename T> \
struct Atomic##NAME##IntegerImpl<T, 1> { \
  template <typename func_t> \
  inline __device__ void operator()(T *address, T val, const func_t& func) { \
    size_t offset = (size_t)address & 3; \
    uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \
    uint32_t old = *address_as_ui; \
    uint32_t shift = offset * 8; \
    uint32_t old_byte; \
    uint32_t newval; \
    uint32_t assumed; \
 \
    do { \
      assumed = old; \
      old_byte = (old >> shift) & 0xff; \
      newval = static_cast<uint8_t>(func(val, static_cast<T>(old_byte))); \
      newval = (old & ~(0x000000ff << shift)) | (newval << shift); \
      old = atomicCAS(address_as_ui, assumed, newval); \
    } while (assumed != old); \
  } \
}; \
 \
template<typename T> \
struct Atomic##NAME##IntegerImpl<T, 2> { \
  template <typename func_t> \
  inline __device__ void operator()(T *address, T val, const func_t& func) { \
    size_t offset = (size_t)address & 2; \
    uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \
    bool is_32_align = offset; \
    uint32_t old = *address_as_ui; \
    uint32_t old_bytes; \
    uint32_t newval; \
    uint32_t assumed; \
 \
    do { \
      assumed = old; \
      old_bytes = is_32_align ? old >> 16 : old & 0xffff; \
      newval = static_cast<uint16_t>(func(val, static_cast<T>(old_bytes))); \
      newval = is_32_align ? (old & 0xffff) | (newval << 16) : (old & 0xffff0000) | newval; \
      old = atomicCAS(address_as_ui, assumed, newval); \
    } while (assumed != old); \
  } \
}; \
 \
template<typename T> \
struct Atomic##NAME##IntegerImpl<T, 4> { \
  template <typename func_t> \
  inline __device__ void operator()(T *address, T val, const func_t& func) { \
    uint32_t * address_as_ui = (uint32_t *) (address); \
    uint32_t old = *address_as_ui; \
    uint32_t newval; \
    uint32_t assumed; \
 \
    do { \
      assumed = old; \
      newval = static_cast<uint32_t>(func(val, static_cast<T>(old))); \
      old = atomicCAS(address_as_ui, assumed, newval); \
    } while (assumed != old); \
  } \
}; \
 \
template<typename T> \
struct Atomic##NAME##IntegerImpl<T, 8> { \
  template <typename func_t> \
  inline __device__ void operator()(T *address, T val, const func_t& func) { \
    unsigned long long * address_as_ui = (unsigned long long *) (address); \
    unsigned long long old = *address_as_ui; \
    unsigned long long newval; \
    unsigned long long assumed; \
 \
    do { \
      assumed = old; \
      newval = static_cast<uint64_t>(func(val, static_cast<T>(old))); \
      old = atomicCAS(address_as_ui, assumed, newval); \
    } while (assumed != old); \
  } \
};

#define GPU_ATOMIC_INTEGER(NAME, OP, DTYPE) \
static inline __device__ void gpuAtomic##NAME(DTYPE *address, DTYPE val) { \
  Atomic##NAME##IntegerImpl<DTYPE, sizeof(DTYPE)>()(address, \
                                                    val, \
                                                    [](DTYPE a, DTYPE b) { \
                                                      return OP; \
                                                    }); \
}
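
/* Expansion sketch (illustrative): GPU_ATOMIC_INTEGER(Mul, a * b, int32_t)
 * generates, roughly,
 *
 *   static inline __device__ void gpuAtomicMul(int32_t *address, int32_t val) {
 *     AtomicMulIntegerImpl<int32_t, sizeof(int32_t)>()(address, val,
 *                                                      [](int32_t a, int32_t b) {
 *                                                        return a * b;
 *                                                      });
 *   }
 *
 * i.e. sizeof(DTYPE) selects the 1-, 2-, 4- or 8-byte CAS loop defined above.
 */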
ATOMIC_INTEGER_IMPL(Add)
GPU_ATOMIC_INTEGER(Add, a || b, bool)

// Don't instantiate gpuAtomicAdd with the macro as it seems non-standard (see int32, int64)
static inline __device__ void gpuAtomicAdd(uint8_t *address, uint8_t val) {
  AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address,
                                                   val,
                                                   [](uint8_t a, uint8_t b) {
                                                     return a + b;
                                                   });
}

static inline __device__ void gpuAtomicAdd(int8_t *address, int8_t val) {
  AtomicAddIntegerImpl<int8_t, sizeof(int8_t)>()(address,
                                                 val,
                                                 [](int8_t a, int8_t b) {
                                                   return a + b;
                                                 });
}

static inline __device__ void gpuAtomicAdd(int16_t *address, int16_t val) {
  AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address,
                                                   val,
                                                   [](int16_t a, int16_t b) {
                                                     return a + b;
                                                   });
}

static inline __device__ int32_t gpuAtomicAdd(int32_t *address, int32_t val) {
  return atomicAdd(address, val);
}

static inline __device__ void gpuAtomicAdd(int64_t *address, int64_t val) {
#if defined(USE_ROCM)
  __atomic_fetch_add(address, val, __ATOMIC_RELAXED);
#else
  AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address,
                                                   val,
                                                   [](int64_t a, int64_t b) {
                                                     return a + b;
                                                   });
#endif
}

static inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val) {
#if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
  return AtomicFPOp<at::Half>()(address, val,
                                [](at::Half hsum, at::Half val) {
                                  return hsum + val;
                                });
#else
  return atomicAdd(reinterpret_cast<__half*>(address), val);
#endif
}

static inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BFloat16 val) {
#if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
  return AtomicFPOp<at::BFloat16>()(address, val,
                                    [](at::BFloat16 bsum, at::BFloat16 val) {
                                      return bsum + val;
                                    });
#else
  __nv_bfloat16 r = atomicAdd(reinterpret_cast<__nv_bfloat16*>(address), *reinterpret_cast<__nv_bfloat16*>(&val));
  return *reinterpret_cast<c10::BFloat16*>(&r);
#endif
}
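
/* Example (illustrative sketch, not part of the original header): a kernel
 * accumulating a per-bin fp16 histogram with gpuAtomicAdd. Kernel and
 * parameter names are hypothetical.
 *
 *   __global__ void hist_fp16(at::Half* bins, const at::Half* vals,
 *                             const int* bin_idx, int n) {
 *     int i = blockIdx.x * blockDim.x + threadIdx.x;
 *     if (i < n) {
 *       gpuAtomicAdd(&bins[bin_idx[i]], vals[i]);  // hardware half atomic on sm_70+,
 *     }                                            // CAS emulation otherwise
 *   }
 */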
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600)
// from the CUDA C Programming Guide
static inline __device__ double atomicAdd(double* address, double val)
#if defined(__clang__) && defined(__CUDA__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wgcc-compat"
    __attribute__((enable_if(true, "")))
#pragma GCC diagnostic pop
#endif
{
  return AtomicFPOp<double>()(address, val,
                              [](double val, unsigned long long int assumed) {
                                return __double_as_longlong(val + __longlong_as_double(assumed));
                              });
}
#elif defined(USE_ROCM) || !(defined(__CUDA_ARCH__))

/* Note [hip-clang differences to hcc]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * The upcoming hip-clang compiler for ROCm differs from hcc in a few details.
 * It exports the __HIP__ macro, so we can differentiate between hcc and
 * hip-clang. In the below, hcc only received support for atomicAdd with double
 * typing after work week 18312. hip-clang had support from the first version.
 * In general, the code-visible differences between hip-clang and hcc will be
 * minimal.
 */

#if defined(USE_ROCM) && __hcc_workweek__ < 18312 && !__HIP__
  // This needs to be defined for the host side pass
  static inline __device__ double atomicAdd(double *address, double val) { }
#endif
#endif
static inline __device__ double gpuAtomicAdd(double *address, double val) {
  return atomicAdd(address, val);
}

static inline __device__ float gpuAtomicAdd(float *address, float val) {
  return atomicAdd(address, val);
}

template<typename T>
static inline __device__ void gpuAtomicAdd(c10::complex<T> *address, c10::complex<T> val) {
  gpuAtomicAdd(&address->real_, val.real_);
  gpuAtomicAdd(&address->imag_, val.imag_);
}
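
/* Example (illustrative): the complex overload issues two independent real
 * atomics, so the real and imaginary parts are each updated atomically but
 * not as a single unit. Hypothetical usage:
 *
 *   __global__ void accumulate_c64(c10::complex<float>* acc,
 *                                  const c10::complex<float>* vals, int n) {
 *     int i = blockIdx.x * blockDim.x + threadIdx.x;
 *     if (i < n) gpuAtomicAdd(acc, vals[i]);
 *   }
 */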
/* Note [gpuAtomicAdd vs atomicAdd]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * Some extensions such as torchvision call atomicAdd() directly and rely on
 * overloads for data types that the CUDA headers do not provide. Only for
 * these do we continue to provide atomicAdd overloads.
 */
static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
  return gpuAtomicAdd(address, val);
}

static inline __device__ at::BFloat16 atomicAdd(at::BFloat16 *address, at::BFloat16 val) {
  return gpuAtomicAdd(address, val);
}

static inline __device__ void atomicAdd(uint8_t *address, uint8_t val) {
  gpuAtomicAdd(address, val);
}

static inline __device__ void atomicAdd(int8_t *address, int8_t val) {
  gpuAtomicAdd(address, val);
}

static inline __device__ void atomicAdd(int16_t *address, int16_t val) {
  gpuAtomicAdd(address, val);
}

static inline __device__ void atomicAdd(int64_t *address, int64_t val) {
  gpuAtomicAdd(address, val);
}

static inline __device__ void atomicAdd(bool *address, bool val) {
  gpuAtomicAdd(address, val);
}
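
/* Example (illustrative): out-of-tree code that spells the CUDA-style name
 * directly still resolves to the overloads above. Kernel name is hypothetical.
 *
 *   __global__ void legacy_add(at::Half* out, at::Half v) {
 *     atomicAdd(out, v);  // dispatches to gpuAtomicAdd(at::Half*, at::Half)
 *   }
 */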
/* Note [explicitly non-returning atomics]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * AMD's MI100 (gfx908) provides an optimized fp32 atomicAdd, exposed via atomicAddNoRet().
 * Due to compiler limitations, callers must opt-in to guarantee the optimized instruction.
 * This non-returning atomicAddNoRet cannot be used to implement the returning atomicAdd,
 * therefore we need a new API 'gpuAtomicAddNoReturn'.
 */
template<typename T>
static inline __device__ void gpuAtomicAddNoReturn(c10::complex<T> *address, c10::complex<T> val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(uint8_t *address, uint8_t val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(int8_t *address, int8_t val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(int16_t *address, int16_t val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(int32_t *address, int32_t val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(int64_t *address, int64_t val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(bool *address, bool val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(at::Half *address, at::Half val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(at::BFloat16 *address, at::BFloat16 val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(double *address, double val) { gpuAtomicAdd(address, val); }

/* Special case fp32 atomic. */
#if defined(USE_ROCM)
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { atomicAddNoRet(address, val); }
#else
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { gpuAtomicAdd(address, val); }
#endif
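
/* Example (illustrative sketch): when the old value is not needed, prefer the
 * non-returning form so the ROCm build can emit atomicAddNoRet for fp32.
 * Kernel and parameter names are hypothetical.
 *
 *   __global__ void sum_all(float* acc, const float* vals, int n) {
 *     int i = blockIdx.x * blockDim.x + threadIdx.x;
 *     if (i < n) gpuAtomicAddNoReturn(acc, vals[i]);
 *   }
 */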
// Atomic multiplication implementation.

ATOMIC_INTEGER_IMPL(Mul)
GPU_ATOMIC_INTEGER(Mul, a * b, uint8_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int8_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int16_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int32_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int64_t)

inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) {
  return AtomicFPOp<at::Half>()(address, val,
                                [](at::Half bsum, at::Half val) {
                                  return bsum * val;
                                });
}

inline __device__ at::BFloat16 gpuAtomicMul(at::BFloat16 * address, at::BFloat16 val) {
  return AtomicFPOp<at::BFloat16>()(address, val,
                                    [](at::BFloat16 bsum, at::BFloat16 val) {
                                      return bsum * val;
                                    });
}

inline __device__ double gpuAtomicMul(double * address, double val) {
  return AtomicFPOp<double>()(address, val,
                              [](double val, unsigned long long int assumed) {
                                return __double_as_longlong(val * __longlong_as_double(assumed));
                              });
}

// Don't use a templated function for this since the addition function defaults to the CUDA built-in.
inline __device__ float gpuAtomicMul(float * address, float val) {
  unsigned int* address_as_ull = (unsigned int*)address;
  unsigned int old = *address_as_ull;
  unsigned int assumed;

  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __float_as_int(val * __int_as_float(assumed)));
    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);

  return __int_as_float(old);
}
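
/* Example (illustrative): the float variants reinterpret the value's bit
 * pattern as a 32-bit integer so atomicCAS can compare it; comparing bits
 * rather than values is what keeps the loop from hanging on NaN. Hypothetical
 * usage, a running product per output slot (out must be pre-filled with 1.0f):
 *
 *   __global__ void prod_into(float* out, const float* vals, int n) {
 *     int i = blockIdx.x * blockDim.x + threadIdx.x;
 *     if (i < n) gpuAtomicMul(out, vals[i]);
 *   }
 */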
// Atomic maximum implementation.

template <typename T>
__host__ __device__ T safe_max(T a, T b) {
#if defined(__HIPCC__)
  // TODO: remove this special case for HIP when issue is fixed:
  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
  T max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max<T>(a, b));
#else
  T max = at::_isnan(b) ? b : std::max<T>(a, b);
#endif

  return max;
}
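
/* Illustrative behaviour of safe_max (NaN propagates, unlike plain std::max):
 *
 *   safe_max(1.0f, 2.0f);  // 2.0f
 *   safe_max(1.0f, NAN);   // NaN
 *   // std::max(1.0f, NAN) would return 1.0f, silently dropping the NaN.
 */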
ATOMIC_INTEGER_IMPL(Max)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int16_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int32_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int64_t)

inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) {
  return AtomicFPOp<at::Half>()(address, val,
                                [](at::Half bsum, at::Half val) {
                                  return safe_max(bsum, val);
                                });
}

inline __device__ at::BFloat16 gpuAtomicMax(at::BFloat16 * address, at::BFloat16 val) {
  return AtomicFPOp<at::BFloat16>()(address, val,
                                    [](at::BFloat16 bsum, at::BFloat16 val) {
                                      return safe_max(bsum, val);
                                    });
}

inline __device__ double gpuAtomicMax(double * address, double val) {
  return AtomicFPOp<double>()(address, val,
                              [](double val, unsigned long long int assumed) {
                                return __double_as_longlong(safe_max(val, __longlong_as_double(assumed)));
                              });
}

// Don't use a templated function for this since the addition function defaults to the CUDA built-in.
inline __device__ float gpuAtomicMax(float * address, float val) {
  unsigned int* address_as_ull = (unsigned int*)address;
  unsigned int old = *address_as_ull;
  unsigned int assumed;

  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __float_as_int(safe_max(val, __int_as_float(assumed))));
    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);

  return __int_as_float(old);
}
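
/* Example (illustrative sketch): reduce within a block first, then issue one
 * gpuAtomicMax per block to cut contention on the global slot. The kernel
 * name and the fixed 256-thread, power-of-two block size are assumptions of
 * this sketch; global_max must be pre-filled with -INFINITY.
 *
 *   __global__ void block_max(float* global_max, const float* vals, int n) {
 *     __shared__ float smem[256];
 *     int i = blockIdx.x * blockDim.x + threadIdx.x;
 *     smem[threadIdx.x] = (i < n) ? vals[i] : -INFINITY;
 *     __syncthreads();
 *     for (int s = blockDim.x / 2; s > 0; s >>= 1) {
 *       if (threadIdx.x < s)
 *         smem[threadIdx.x] = safe_max(smem[threadIdx.x], smem[threadIdx.x + s]);
 *       __syncthreads();
 *     }
 *     if (threadIdx.x == 0) gpuAtomicMax(global_max, smem[0]);
 *   }
 */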
// Atomic minimum implementation.

template <typename T>
__host__ __device__ T safe_min(T a, T b) {
#if defined(__HIPCC__)
  // TODO: remove this special case for HIP when issue is fixed:
  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
  T min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min<T>(a, b));
#else
  T min = at::_isnan(b) ? b : std::min<T>(a, b);
#endif

  return min;
}
ATOMIC_INTEGER_IMPL(Min)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int16_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int32_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int64_t)

inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) {
  return AtomicFPOp<at::Half>()(address, val,
                                [](at::Half bsum, at::Half val) {
                                  return safe_min(bsum, val);
                                });
}

inline __device__ at::BFloat16 gpuAtomicMin(at::BFloat16 * address, at::BFloat16 val) {
  return AtomicFPOp<at::BFloat16>()(address, val,
                                    [](at::BFloat16 bsum, at::BFloat16 val) {
                                      return safe_min(bsum, val);
                                    });
}

inline __device__ double gpuAtomicMin(double * address, double val) {
  return AtomicFPOp<double>()(address, val,
                              [](double val, unsigned long long int assumed) {
                                return __double_as_longlong(safe_min(val, __longlong_as_double(assumed)));
                              });
}

// Don't use a templated function for this since the addition function defaults to the CUDA built-in.
inline __device__ float gpuAtomicMin(float * address, float val) {
  unsigned int* address_as_ull = (unsigned int*)address;
  unsigned int old = *address_as_ull;
  unsigned int assumed;

  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __float_as_int(safe_min(val, __int_as_float(assumed))));
    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);

  return __int_as_float(old);
}
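
/* Example (illustrative sketch): the min variants follow the same pattern,
 * e.g. a hypothetical kernel tracking a per-bucket minimum; out must be
 * pre-filled with INFINITY for the reduction to be correct.
 *
 *   __global__ void scatter_min(float* out, const float* vals,
 *                               const int* idx, int n) {
 *     int i = blockIdx.x * blockDim.x + threadIdx.x;
 *     if (i < n) gpuAtomicMin(&out[idx[i]], vals[i]);
 *   }
 */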