// ParallelOpenMP.h

#pragma once

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <exception>

#ifdef _OPENMP
#define INTRA_OP_PARALLEL
#include <omp.h>
#endif
  9. namespace at {
  10. #ifdef _OPENMP
  11. namespace internal {
  12. template <typename F>
  13. inline void invoke_parallel(
  14. int64_t begin,
  15. int64_t end,
  16. int64_t grain_size,
  17. const F& f) {
  18. std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
  19. std::exception_ptr eptr;
  20. #pragma omp parallel
  21. {
  22. // choose number of tasks based on grain size and number of threads
  23. // can't use num_threads clause due to bugs in GOMP's thread pool (See
  24. // #32008)
  25. int64_t num_threads = omp_get_num_threads();
  26. if (grain_size > 0) {
  27. num_threads = std::min(num_threads, divup((end - begin), grain_size));
  28. }
  29. int64_t tid = omp_get_thread_num();
  30. int64_t chunk_size = divup((end - begin), num_threads);
  31. int64_t begin_tid = begin + tid * chunk_size;
  32. if (begin_tid < end) {
  33. try {
  34. internal::ThreadIdGuard tid_guard(tid);
  35. f(begin_tid, std::min(end, chunk_size + begin_tid));
  36. } catch (...) {
  37. if (!err_flag.test_and_set()) {
  38. eptr = std::current_exception();
  39. }
  40. }
  41. }
  42. }
  43. if (eptr) {
  44. std::rethrow_exception(eptr);
  45. }
  46. }
  47. } // namespace internal
  48. #endif // _OPENMP
  49. } // namespace at