ParallelNativeTBB.h 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. #pragma once
  #include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <exception>

#include <c10/util/Exception.h>

#ifdef _WIN32
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif
#endif
#include <tbb/tbb.h>
  12. #define INTRA_OP_PARALLEL
  13. namespace at {
  14. namespace internal {
  15. template <typename F>
  16. inline void invoke_parallel(
  17. const int64_t begin,
  18. const int64_t end,
  19. const int64_t grain_size,
  20. const F& f) {
  21. // Choose number of tasks based on grain size and number of threads.
  22. int64_t chunk_size = divup((end - begin), get_num_threads());
  23. // Make sure each task is at least grain_size size.
  24. chunk_size = std::max(grain_size, chunk_size);
  25. std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
  26. std::exception_ptr eptr;
  27. tbb::parallel_for(
  28. tbb::blocked_range<int64_t>(begin, end, chunk_size),
  29. [&eptr, &err_flag, f](const tbb::blocked_range<int64_t>& r) {
  30. try {
  31. internal::ThreadIdGuard tid_guard(
  32. tbb::this_task_arena::current_thread_index());
  33. f(r.begin(), r.end());
  34. } catch (...) {
  35. if (!err_flag.test_and_set()) {
  36. eptr = std::current_exception();
  37. }
  38. }
  39. },
  40. tbb::static_partitioner{});
  41. if (eptr) {
  42. std::rethrow_exception(eptr);
  43. }
  44. }
  45. } // namespace internal
  46. } // namespace at