123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657 |
#pragma once

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <exception>

#ifdef _OPENMP
#define INTRA_OP_PARALLEL

#include <omp.h>
#endif
namespace at {
#ifdef _OPENMP
namespace internal {

/// Runs `f` over the half-open index range [begin, end) inside an OpenMP
/// parallel region, giving each participating thread at most one contiguous
/// chunk.
///
/// Each thread that receives work calls `f(chunk_begin, chunk_end)` with
/// `begin <= chunk_begin < chunk_end <= end`.
///
/// @param begin       First index of the range (inclusive).
/// @param end         One past the last index of the range (exclusive).
///                    An empty or inverted range (`begin >= end`) is a no-op.
/// @param grain_size  When > 0, the number of chunks is capped at
///                    `divup(end - begin, grain_size)`. This presumably keeps
///                    chunks from being much smaller than `grain_size`, but
///                    `divup` is defined elsewhere (TODO confirm it is
///                    ceiling division). When <= 0, the range is split across
///                    every thread of the region.
/// @param f           Callable invoked as `f(int64_t, int64_t)`. It may run
///                    concurrently on several threads.
/// @throws Whatever `f` throws. If one or more invocations throw, the
///         exception captured first is rethrown on the calling thread after
///         the parallel region has finished. Any others are discarded.
template <typename F>
inline void invoke_parallel(
    int64_t begin,
    int64_t end,
    int64_t grain_size,
    const F& f) {
  // No work to do. Returning here also prevents a division by zero below:
  // with grain_size > 0 and an empty range, num_threads would be capped at
  // divup(0, grain_size) (0 for ceiling division) and then used as a divisor.
  if (begin >= end) {
    return;
  }

  std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
  std::exception_ptr eptr;

#pragma omp parallel
  {
    // choose number of tasks based on grain size and number of threads
    // can't use num_threads clause due to bugs in GOMP's thread pool (See
    // #32008)
    int64_t num_threads = omp_get_num_threads();
    if (grain_size > 0) {
      num_threads = std::min(num_threads, divup((end - begin), grain_size));
    }

    int64_t tid = omp_get_thread_num();
    int64_t chunk_size = divup((end - begin), num_threads);
    int64_t begin_tid = begin + tid * chunk_size;
    // Threads beyond the (possibly capped) chunk count start at or past
    // `end` and therefore do nothing.
    if (begin_tid < end) {
      // Exceptions must not propagate out of an OpenMP structured block, so
      // they are captured here and rethrown after the region ends.
      try {
        // NOTE(review): ThreadIdGuard is defined elsewhere; presumably it
        // exposes `tid` as the current worker id while `f` runs.
        internal::ThreadIdGuard tid_guard(tid);
        f(begin_tid, std::min(end, chunk_size + begin_tid));
      } catch (...) {
        // test_and_set() returns the previous state, so only the first
        // failing thread stores its exception; eptr is never written
        // concurrently.
        if (!err_flag.test_and_set()) {
          eptr = std::current_exception();
        }
      }
    }
  }
  // The parallel region ends with an implicit barrier, so every write to
  // eptr has completed by this point.
  if (eptr) {
    std::rethrow_exception(eptr);
  }
}

} // namespace internal
#endif // _OPENMP
} // namespace at
|