CUDAAlgorithm.h 1.0 KB

123456789101112131415161718192021222324252627282930313233
  1. #ifdef THRUST_DEVICE_LOWER_BOUND_WORKS
  2. #include <thrust/binary_search.h>
  3. #include <thrust/device_vector.h>
  4. #include <thrust/execution_policy.h>
  5. #include <thrust/functional.h>
  6. #endif
  7. namespace c10 {
  8. namespace cuda {
  9. #ifdef THRUST_DEVICE_LOWER_BOUND_WORKS
  10. template <typename Iter, typename Scalar>
  11. __forceinline__ __device__ Iter
  12. lower_bound(Iter start, Iter end, Scalar value) {
  13. return thrust::lower_bound(thrust::device, start, end, value);
  14. }
  15. #else
  // thrust::lower_bound is broken on device, see
  // https://github.com/NVIDIA/thrust/issues/1734
  //
  // Implementation inspired by
  // https://github.com/pytorch/pytorch/blob/805120ab572efef66425c9f595d9c6c464383336/aten/src/ATen/native/cuda/Bucketization.cu#L28
  //
  // Hand-written binary search over the half-open range [start, end).
  // Returns the first position whose element is not less than `value`
  // (compared with operator<), or the original `end` when every element is
  // less than `value`. An empty range returns `end` unchanged. The result is
  // only meaningful when the range is ordered consistently with operator<.
  template <typename Iter, typename Scalar>
  __device__ Iter lower_bound(Iter start, Iter end, Scalar value) {
    // Each pass halves the candidate window; the loop stops once
    // `start` has caught up with `end`.
    for (; start < end;) {
      // Midpoint computed as an offset from `start` rather than by adding
      // the two iterators together.
      const auto probe = start + ((end - start) >> 1);
      if (*probe < value) {
        // Everything up to and including `probe` is too small.
        start = probe + 1;
        continue;
      }
      // `probe` itself may be the answer, so keep it as the upper bound.
      end = probe;
    }
    return end;
  }
  31. #endif // THRUST_DEVICE_LOWER_BOUND_WORKS
  32. } // namespace cuda
  33. } // namespace c10