cub.cuh
#pragma once
#include <ATen/cuda/cub.h>
#include <cstddef>
#include <type_traits>
#include <iterator>
#include <limits>

#include <c10/util/C++17.h>
#include <ATen/cuda/cub_definitions.cuh>

#if USE_GLOBAL_CUB_WRAPPED_NAMESPACE()

#include <cub/cub.cuh>

#else

// include cub in a safe manner, see:
// https://github.com/pytorch/pytorch/pull/55292
#undef CUB_NS_POSTFIX // undef to avoid redefinition warnings
#undef CUB_NS_PREFIX
#undef CUB_NS_QUALIFIER
#define CUB_NS_PREFIX namespace at_cuda_detail {
#define CUB_NS_POSTFIX }
#define CUB_NS_QUALIFIER ::at_cuda_detail::cub
#include <cub/cub.cuh>
#undef CUB_NS_POSTFIX
#undef CUB_NS_PREFIX
#undef CUB_NS_QUALIFIER

#endif

#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>

// Handle the temporary storage and the 'call twice' pattern of the cub API:
// every cub device-wide algorithm must be called once with a null temp-storage
// pointer to query the required size, and a second time with the actual buffer
// to do the work.
#define CUB_WRAPPER(func, ...) do {                                    \
  size_t temp_storage_bytes = 0;                                       \
  func(nullptr, temp_storage_bytes, __VA_ARGS__);                      \
  auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get(); \
  auto temp_storage = caching_allocator.allocate(temp_storage_bytes);  \
  func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__);           \
  AT_CUDA_CHECK(cudaGetLastError());                                   \
} while (false)
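
// Illustrative expansion (a sketch, not part of this header): a hypothetical
// call such as
//   CUB_WRAPPER(::at_cuda_detail::cub::DeviceReduce::Sum, d_in, d_out, n, stream);
// first invokes DeviceReduce::Sum with a nullptr buffer so cub reports the
// temp-storage size it needs, then allocates that many bytes from the CUDA
// caching allocator and invokes the same function again to run the reduction.
// Here d_in, d_out, n, and stream are assumed device pointers/values supplied
// by the caller.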
#ifdef USE_ROCM
#define NO_ROCM(x)
#define ROCM_HIPCUB(x) ::hipcub
#else
#define NO_ROCM(x) x
#define ROCM_HIPCUB(x) x
#endif

#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || \
    (defined(USE_ROCM) && ROCM_VERSION >= 40500)

#if !defined(USE_ROCM)
namespace at_cuda_detail {
#endif

// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16

template <>
struct ROCM_HIPCUB(cub)::FpLimits<c10::BFloat16>
{
  static __host__ __device__ __forceinline__ c10::BFloat16 Max() {
    unsigned short max_word = 0x7F7F;  // largest finite bfloat16
    return reinterpret_cast<c10::BFloat16&>(max_word);
  }

  static __host__ __device__ __forceinline__ c10::BFloat16 Lowest() {
    unsigned short lowest_word = 0xFF7F;  // most negative finite bfloat16
    return reinterpret_cast<c10::BFloat16&>(lowest_word);
  }
};

template <>
struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>:
    ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};

#if !defined(USE_ROCM)
} // namespace at_cuda_detail
#endif

#endif
#if !defined(USE_ROCM)
namespace at { namespace native {
namespace cub = ::at_cuda_detail::cub;
}}
#endif

namespace at {
namespace cuda {
namespace cub {

namespace detail {

// maps c10 element types to the corresponding native CUDA/HIP types
template<typename T>
struct cuda_type {
  using type = T;
};
template<>
struct cuda_type<c10::Half> {
  using type = __half;
};

#if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16()

template<>
struct cuda_type<c10::BFloat16> {
  using type = __nv_bfloat16;
};

#elif (defined(USE_ROCM) && ROCM_VERSION >= 40500)

template<>
struct cuda_type<c10::BFloat16> {
  using type = hip_bfloat16;
};

#endif

} // namespace detail

template<typename key_t, typename value_t, typename OffsetIteratorT>
inline void segmented_sort_pairs(
    const key_t *keys_in, key_t *keys_out,
    const value_t *values_in, value_t *values_out,
    int64_t num_elements, int64_t num_segments,
    OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
    bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
) {
  TORCH_CHECK(num_elements <= std::numeric_limits<int>::max(),
    "cub sort does not support sorting more than INT_MAX elements");
  TORCH_CHECK(num_segments <= std::numeric_limits<int>::max(),
    "cub sort does not support sorting more than INT_MAX segments");
  using key_t_ = typename detail::cuda_type<key_t>::type;

  auto allocator = c10::cuda::CUDACachingAllocator::get();
  c10::DataPtr keys_out_owner;

  // if the caller did not provide an output buffer for the keys,
  // sort into a scratch buffer owned by this function
  if (keys_out == nullptr) {
    keys_out_owner = allocator->allocate(num_elements * sizeof(key_t));
    keys_out = reinterpret_cast<key_t *>(keys_out_owner.get());
  }

  const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
  key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);

  if (descending) {
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairsDescending,
      keys_in_, keys_out_, values_in, values_out,
      num_elements, num_segments, begin_offsets, end_offsets,
      begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
  } else {
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairs,
      keys_in_, keys_out_, values_in, values_out,
      num_elements, num_segments, begin_offsets, end_offsets,
      begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
  }
}
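
// Illustrative usage (a sketch, not part of this header): sorting two
// segments of a key/value buffer, assuming hypothetical device pointers
// d_keys, d_vals, d_vals_out and an int offsets array {0, 3, 8} at d_offsets:
//   const int *begin = d_offsets;      // segment i starts at d_offsets[i]
//   const int *end   = d_offsets + 1;  // segment i ends at d_offsets[i + 1]
//   at::cuda::cub::segmented_sort_pairs(
//       d_keys, /*keys_out=*/nullptr,  // keys sorted into internal scratch
//       d_vals, d_vals_out,
//       /*num_elements=*/8, /*num_segments=*/2, begin, end);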
#if CUB_SUPPORTS_UNIQUE_BY_KEY()
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename KeysOutputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT>
inline void unique_by_key(
  KeysInputIteratorT keys_in, ValuesInputIteratorT values_in,
  KeysOutputIteratorT keys_out, ValuesOutputIteratorT values_out,
  NumSelectedIteratorT num_selected, int64_t num_input_items)
{
  // TODO: use thrust::discard_iterator to handle null keys_out when https://github.com/NVIDIA/cub/issues/406 is fixed.
  constexpr bool null_keys_out = std::is_same<KeysOutputIteratorT, std::nullptr_t>::value;
  using KeyT = typename std::iterator_traits<KeysInputIteratorT>::value_type;
  using RealKeysOutputIteratorT = typename std::conditional<null_keys_out, KeyT *, KeysOutputIteratorT>::type;
  RealKeysOutputIteratorT keys_out_;
  auto allocator = c10::cuda::CUDACachingAllocator::get();
  c10::DataPtr keys_out_owner;
  // if keys_out is nullptr, write the unique keys to a scratch buffer instead
  c10::guts::if_constexpr<null_keys_out>(
    [&](auto _) {
      keys_out_owner = allocator->allocate(num_input_items * sizeof(KeyT));
      keys_out_ = static_cast<KeyT *>(keys_out_owner.get());
    },
    [&](auto _) {
      keys_out_ = keys_out;
    }
  );
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey,
    keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream());
}
#endif
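
// Illustrative behavior (a sketch, not part of this header): runs of equal
// consecutive keys are collapsed, keeping the first value of each run. With
// hypothetical device pointers, for keys {1, 1, 2, 2, 2, 3} and values
// {10, 20, 30, 40, 50, 60}:
//   at::cuda::cub::unique_by_key(d_keys, d_vals, d_keys_out, d_vals_out,
//                                d_num_selected, /*num_input_items=*/6);
// leaves keys_out = {1, 2, 3}, values_out = {10, 30, 60}, *d_num_selected = 3.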
namespace impl {

template<typename InputIteratorT1, typename InputIteratorT2, typename OutputIteratorT, class ScanOpT>
C10_LAUNCH_BOUNDS_1(1)
__global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputIteratorT out, ScanOpT scan_op){
  // NOTE: `out` here is not the final scan output, but an intermediate
  // value of the accumulation type.
  using acc_t = typename std::iterator_traits<OutputIteratorT>::value_type;
  *out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b));
}

#if !CUB_SUPPORTS_FUTURE_VALUE()
// prepends the device-resident value `*first` to the sequence `iter`,
// emulating cub::FutureValue on cub versions that lack it
template<typename ValueT, typename InputIteratorT>
struct chained_iterator {
  using iterator_category = std::random_access_iterator_tag;
  using difference_type   = std::ptrdiff_t;
  using value_type        = ValueT;
  using pointer           = ValueT*;
  using reference         = ValueT&;

  InputIteratorT iter;
  ValueT *first;
  difference_type offset = 0;

  __device__ ValueT operator[](difference_type i) {
    i += offset;
    if (i == 0) {
      return *first;
    } else {
      return ValueT(iter[i - 1]);
    }
  }
  __device__ chained_iterator operator+(difference_type i) {
    return chained_iterator{iter, first, i};
  }
  __device__ ValueT operator*() {
    return (*this)[0];
  }
};
#endif
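
// For example (illustrative only): with `first` pointing at a value 10 and
// `iter` yielding {1, 2, 3}, a chained_iterator `it` satisfies
//   it[0] == 10, it[1] == 1, it[2] == 2, it[3] == 3,
// so an inclusive scan over `it` computes the continuation of a scan whose
// running total 10 is only available on device at scan launch time.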
// even though cub is supposed to support tensors with int_max elements,
// in reality it doesn't, so split at int_max/2
constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30

} // namespace impl

// non synchronizing cub call
// even though cub is supposed to support tensors with int_max elements,
// in reality it doesn't, so split at int_max/2
template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, int max_cub_size=impl::max_cub_size>
inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
#if defined(USE_ROCM) && (ROCM_VERSION >= 50000)
  // For ROCm, use hipCUB chained iterators
  CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::InclusiveScan,
      input,
      output,
      scan_op,
      num_items,
      at::cuda::getCurrentCUDAStream());
  C10_HIP_KERNEL_LAUNCH_CHECK();
#else
  // scan the first chunk of at most max_cub_size items
  int size_cub = std::min<int64_t>(num_items, max_cub_size);
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
      input,
      output,
      scan_op,
      size_cub,
      at::cuda::getCurrentCUDAStream());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  using input_t = typename std::iterator_traits<InputIteratorT>::value_type;
  // scan the remaining chunks, carrying the running total across chunk
  // boundaries on device, without synchronizing with the host
  for (int64_t i = max_cub_size; i < num_items; i += max_cub_size) {
    auto allocator = c10::cuda::CUDACachingAllocator::get();
    c10::DataPtr first_elem = allocator->allocate(sizeof(input_t));
    auto first_elem_ptr = reinterpret_cast<input_t *>(first_elem.get());

    size_cub = std::min<int64_t>(num_items - i, max_cub_size);
    // *first_elem_ptr = scan_op(output[i - 1], input[i]), i.e. the first
    // element of this chunk's inclusive scan
    impl::transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
        output + i - 1,
        input + i,
        first_elem_ptr,
        scan_op);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
#if !CUB_SUPPORTS_FUTURE_VALUE()
    using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>;
    using tuple = typename ArgIndexInputIterator::value_type;
    // substitute the precomputed carry for the first element of the chunk
    auto input_iter_transform = [=] __device__ (const tuple &x)->input_t {
      if (x.key == 0) {
        return *first_elem_ptr;
      } else {
        return x.value;
      }
    };
    auto input_ = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator<input_t, decltype(input_iter_transform), ArgIndexInputIterator>(
      ArgIndexInputIterator(input + i), input_iter_transform);
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
        input_,
        output + i,
        scan_op,
        size_cub,
        at::cuda::getCurrentCUDAStream());
#else
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
        input + i + 1,
        output + i,
        scan_op,
        ::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr),
        size_cub,
        at::cuda::getCurrentCUDAStream());
#endif
  }
#endif
}
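
// Illustrative usage (a sketch, not part of this header), assuming a
// hypothetical device pointer d_in of n int64_t values and d_out of the same
// size:
//   at::cuda::cub::inclusive_scan(d_in, d_out,
//       ::at_cuda_detail::cub::Sum(), n);
// For n > 2**30 the scan proceeds chunk by chunk; the carry between chunks is
// computed on device by transform_vals, so no host synchronization occurs.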
template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename InitValueT, int max_cub_size=impl::max_cub_size>
inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, InitValueT init_value, int64_t num_items) {
#if defined(USE_ROCM) && (ROCM_VERSION >= 50000)
  // For ROCm, use hipCUB chained iterators
  CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::ExclusiveScan,
      input,
      output,
      scan_op,
      init_value,
      num_items,
      at::cuda::getCurrentCUDAStream());
  C10_HIP_KERNEL_LAUNCH_CHECK();
#else
  // scan the first chunk of at most max_cub_size items
  int size_cub = std::min<int64_t>(num_items, max_cub_size);
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
      input,
      output,
      scan_op,
      init_value,
      size_cub,
      at::cuda::getCurrentCUDAStream());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  // scan the remaining chunks, carrying the running total across chunk
  // boundaries on device, without synchronizing with the host
  for (int64_t i = max_cub_size; i < num_items; i += max_cub_size) {
    auto allocator = c10::cuda::CUDACachingAllocator::get();
    c10::DataPtr first_elem = allocator->allocate(sizeof(InitValueT));
    auto first_elem_ptr = reinterpret_cast<InitValueT *>(first_elem.get());

    size_cub = std::min<int64_t>(num_items - i, max_cub_size);
    // *first_elem_ptr = scan_op(output[i - 1], input[i - 1]), i.e. the first
    // element of this chunk's exclusive scan
    impl::transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
        output + i - 1,
        input + i - 1,
        first_elem_ptr,
        scan_op);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
#if !CUB_SUPPORTS_FUTURE_VALUE()
    auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{
      input + i, first_elem_ptr};
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
        input_,
        output + i,
        scan_op,
        size_cub,
        at::cuda::getCurrentCUDAStream());
#else
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
        input + i,
        output + i,
        scan_op,
        ::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr),
        size_cub,
        at::cuda::getCurrentCUDAStream());
#endif
  }
#endif
}
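
// Design note (editorial summary of the code above): chunks after the first
// are seeded with a device-resident carry, either by prepending it to the
// chunk's input via impl::chained_iterator and running an inclusive scan, or,
// on newer cub, by passing it as a cub::FutureValue init to another exclusive
// scan. Both forms produce the same outputs for positions i and beyond.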
#if CUB_SUPPORTS_SCAN_BY_KEY()

template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) {
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
    "cub InclusiveSumByKey does not support more than INT_MAX elements");
  CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveSumByKey,
    keys, input, output, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream());
}

template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename ScanOpT>
inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
    "cub InclusiveScanByKey does not support more than INT_MAX elements");
  CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveScanByKey,
    keys, input, output, scan_op, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream());
}

#endif
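
// Illustrative behavior (not part of this header): segments are maximal runs
// of equal consecutive keys. For keys {0, 0, 1, 1, 1} and values
// {1, 2, 3, 4, 5}, inclusive_sum_by_key writes {1, 3, 3, 7, 12}: the running
// sum restarts whenever the key changes.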
template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT>
void unique(InputIteratorT input, OutputIteratorT output,
            NumSelectedIteratorT num_selected_out, int64_t num_items) {
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
    "cub unique does not support more than INT_MAX elements");
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique,
    input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream());
}
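
// Illustrative behavior (not part of this header): like std::unique, only
// adjacent duplicates are collapsed. For input {1, 1, 2, 1} the output is
// {1, 2, 1} and *num_selected_out is set to 3.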
template <typename InputIteratorT, typename OutputIteratorT, typename CountsOutputIteratorT,
          typename LengthOutputIteratorT>
void run_length_encode(InputIteratorT input, OutputIteratorT output, CountsOutputIteratorT counts_out,
                       LengthOutputIteratorT length_out, int64_t num_items) {
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
    "cub run_length_encode does not support more than INT_MAX elements");
  CUB_WRAPPER(
      NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode,
      input, output, counts_out, length_out, num_items,
      at::cuda::getCurrentCUDAStream());
}
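
// Illustrative behavior (not part of this header): for input {3, 3, 3, 7}
// the unique values {3, 7} go to output, their run lengths {3, 1} go to
// counts_out, and *length_out is set to 2 (the number of runs).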
template <typename InputIteratorT, typename OutputIteratorT, typename ReductionOpT, typename T>
void reduce(InputIteratorT input, OutputIteratorT output, int64_t num_items, ReductionOpT op, T init) {
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
    "cub reduce does not support more than INT_MAX elements");
  CUB_WRAPPER(
      NO_ROCM(at_cuda_detail)::cub::DeviceReduce::Reduce,
      input, output, num_items, op, init,
      at::cuda::getCurrentCUDAStream());
}
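
// Illustrative usage (a sketch, not part of this header), assuming
// hypothetical device pointers d_in (n floats) and d_out (one float):
//   at::cuda::cub::reduce(d_in, d_out, n,
//       ::at_cuda_detail::cub::Max(),
//       -std::numeric_limits<float>::infinity());
// writes the maximum of the n inputs (or init if n == 0) to *d_out.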
}}} // namespace at::cuda::cub