DeviceThreadHandles.h

// Some stateful GPU libraries, such as cuDNN and cuBLAS, use handles to store state.
// These handles are tied to a device, and these libraries require (or recommend)
// that handles not be shared across host threads.
//
// These libraries recommend using one handle per host thread. Naively creating and
// destroying a handle for every thread is undesirable, though: threads are relatively
// light-weight and may be spawned often, but creating and destroying handles is
// expensive (destroying a handle causes device synchronization). DataParallel,
// for example, creates new threads for each forward pass.
//
// This file implements a handle pool mechanism. The handle pool returns handles on
// demand as threads request them. If all existing handles in the pool are in use,
// it creates a new one. As threads terminate, they release handles back into the pool.
// In this way, the handle pool never creates more handles than the high-water mark of
// active threads, so it's efficient with DataParallel.
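//
// A minimal usage sketch (hypothetical names: "myHandle_t", "myCreate", and "myDestroy"
// stand in for a real library's handle type and its create/destroy wrappers, e.g. the
// cuDNN or cuBLAS equivalents defined in their own translation units; the snippet is
// assumed to live inside namespace at::cuda):
//
//   using MyPoolType = DeviceThreadHandlePool<myHandle_t, myCreate, myDestroy>;
//
//   myHandle_t getMyHandle(int device) {
//     // One pool shared by all threads; one PoolWindow per thread.
//     static auto pool = std::make_shared<MyPoolType>();
//     thread_local std::unique_ptr<MyPoolType::PoolWindow> myPoolWindow(
//         pool->newPoolWindow());
//     // reserve() returns this thread's cached handle for `device`, checking one out
//     // of the pool (or creating a new one) only on first use. When the thread exits,
//     // the PoolWindow destructor releases its handles back into the pool.
//     return myPoolWindow->reserve(device);
//   }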
#pragma once

#include <unordered_map>
#include <vector>
#include <utility>
#include <mutex>
#include <memory>

#include <ATen/cuda/Exceptions.h>

namespace at { namespace cuda { namespace {

template <typename Handle_t, void Create(Handle_t *), void Destroy(Handle_t)>
struct DeviceThreadHandlePool : public std::enable_shared_from_this<DeviceThreadHandlePool<Handle_t, Create, Destroy>> {
    struct Handle {
        Handle_t handle;
        Handle(bool create = false) : handle(nullptr)
        {
            if(create) Create(&handle);
        }
        // std::vector.emplace() and push_back() may route through temporaries and call
        // copy/move constructors along the way. If this is the case, we don't want
        // the destructors of temporaries to call Destroy on the handle.
        // We can achieve safety (for the narrow case of stashing within std::vectors)
        // by making Handle movable but not copyable, and transferring handle ownership
        // to the latest constructed object. This is not a substitute for full-blown
        // reference counting, but reference counting may be overkill here.
        // Another alternative is to wrap the saved Handles in unique_ptrs, i.e.,
        // unordered_map<int, vector<unique_ptr<Handle>>> created_handles;
        // (See the short illustration just after this struct.)
        Handle(const Handle& rhs) = delete;
        // Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom
        Handle(Handle&& rhs) : Handle() { std::swap(handle, rhs.handle); }
        // operator= takes its argument by value
        Handle& operator=(Handle rhs) { std::swap(handle, rhs.handle); return *this; }
        ~Handle() {
            if(handle) Destroy(handle);
        }
    };
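
    // A hypothetical illustration of the move-only behavior above (not part of this
    // header): pushing a Handle into a vector moves it, so the moved-from temporary
    // ends up holding nullptr and its destructor does not call Destroy.
    //
    //   std::vector<Handle> v;
    //   v.push_back(Handle(true /*create*/));  // the vector's element now owns the raw
    //                                          // handle; the temporary destructs as a no-op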

    std::mutex mutex;

    // Handles are lazily created as different threads request them,
    // but are never destroyed until the end of the process.
    // The maximum number of handles this process will create for each device is equal
    // to the high-water mark of the number of concurrently active threads that request
    // handles for that device.
    // When threads terminate, they release their handles back into the pool for reuse.
    // Otherwise, new handles would be created every time new threads were spawned,
    // resulting in poor performance for Python modules that repeatedly or frequently
    // spawned new sets of threads (like DataParallel, which creates a new set of threads
    // for each forward pass).
    //
    // To prevent potential deadlocks, we explicitly choose not to cap the number
    // of handles that are created per device.
    // Example of danger: if we cap the max handles at 4, and 5 threads are sharing a device,
    // only 4 can make forward progress at any time. The other 4 will not release their
    // handles until they exit, so the fifth cannot make progress until then. This is
    // not a problem...UNLESS all 5 threads attempt some sort of synchronization at an
    // intermediate point (i.e., before any of them have exited). We have no way to anticipate
    // or enforce that user threads will not attempt such intermediate synchronization.
    // The only way to ensure safety is to avoid imposing a cap on the number of handles.
    std::unordered_map<int, std::vector<Handle>> created_handles;
    std::unordered_map<int, std::vector<Handle_t>> available_handles;
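
    // created_handles owns the Handle wrappers (and therefore the underlying handles)
    // for each device, keeping them alive until the pool itself is destroyed.
    // available_handles holds raw Handle_t values that exiting threads have returned
    // and that are ready to be checked out again by other threads.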

    // PoolWindow lazily creates and caches the handles that a particular thread is using,
    // so in the common case handle access doesn't incur either handle creation or a mutex lock.
    class PoolWindow
    {
    public:
        PoolWindow(std::shared_ptr<DeviceThreadHandlePool> parent): weak_parent(std::move(parent)) {}
        ~PoolWindow(){ release(); }

        Handle_t reserve(int device)
        {
            // If this thread already has a handle for this device, return it
            if(my_handles.find(device) != my_handles.end())
                return my_handles[device];

            // Otherwise, either grab a handle from the pool if one is available,
            // or if not, create a new one.
            auto parent = weak_parent.lock();
            TORCH_CHECK(parent, "Cannot create handle during program termination");
            std::lock_guard<std::mutex> guard(parent->mutex);

            if(parent->available_handles[device].size() > 0)
            {
                my_handles[device] = parent->available_handles[device].back();
                parent->available_handles[device].pop_back();
            }
            else
            {
                // In local testing, I do observe that emplace_back sometimes routes through temporaries
                // that incur move-constructor and destructor calls. See comments in Handle above.
                parent->created_handles[device].emplace_back(true /*create*/);
                my_handles[device] = parent->created_handles[device].back().handle;
            }

            return my_handles[device];
        }

    private:
        // Stores the per-device handles currently owned by this thread
        std::unordered_map<int, Handle_t> my_handles;

        std::weak_ptr<DeviceThreadHandlePool> weak_parent;

        // Called by the destructor. Releases this thread's handles back into the pool.
        void release() {
            if(my_handles.size() > 0) {
                auto parent = weak_parent.lock();
                if (!parent) {
                    // If this thread exits after atexit handlers have completed, the
                    // CUDA context itself may be invalid, so we must leak the handles.
                    return;
                }

                std::lock_guard<std::mutex> guard(parent->mutex);
                for(auto d_h : my_handles)
                    parent->available_handles[d_h.first].push_back(d_h.second);
            }
        }
    };

    // Warning:
    // If you want to change this function, be aware that it will be called by multiple
    // threads and there is no mutex guarding the call, so make sure your implementation
    // is thread-safe.
    PoolWindow *newPoolWindow() {
        // The returned pointer will be owned by a thread-local variable
        // so that different threads do not share the same PoolWindow.
        return new PoolWindow(this->shared_from_this());
    }
};

}}}  // namespace at::cuda::<anonymous>