123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151 |
// Some stateful GPU libraries, such as cuDNN and cuBLAS, use handles to store state.
// A handle is tied to a device, and these libraries require (or recommend) that a
// handle not be shared across host threads.
//
// These libraries suggest using one handle per host thread. We may not want to follow
// that advice literally: threads are relatively light-weight and may be created and
// destroyed often, while creating and destroying handles is expensive (destroying a
// handle causes a synchronization). DataParallel, for example, creates new threads
// for each forward pass.
//
// This file implements a handle pool mechanism. The handle pool returns handles on
// demand as threads request them. If all existing handles in the pool are in use,
// it creates a new one. As threads terminate, they release handles back into the
// pool. In this way, the handle pool never creates more handles than the high-water
// mark of active threads, so it's efficient with DataParallel.
- #pragma once
- #include <unordered_map>
- #include <vector>
- #include <utility>
- #include <mutex>
- #include <memory>
- #include <ATen/cuda/Exceptions.h>
- namespace at { namespace cuda { namespace {
template <typename Handle_t, void Create(Handle_t *), void Destroy(Handle_t)>
struct DeviceThreadHandlePool : public std::enable_shared_from_this<DeviceThreadHandlePool<Handle_t, Create, Destroy>> {

    // RAII owner of a single library handle (e.g. a cudnnHandle_t).
    struct Handle {
        Handle_t handle;
        Handle(bool create = false) : handle(nullptr)
        {
            if(create) Create(&handle);
        }
        // std::vector.emplace() and push_back() may route through temporaries and call
        // copy/move constructors along the way. If this is the case, we don't want
        // the destructors of temporaries to call Destroy (e.g. cudnnDestroy) on the handle.
        // We can achieve safety (for the narrow case of stashing within std::vectors)
        // by making Handle moveable but not copyable, and transferring handle ownership
        // to the latest constructed object. This is not a substitute for full-blown
        // reference counting, but reference counting may be overkill here.
        // Another alternative is to wrap the saved Handles in unique_ptrs, i.e.,
        // unordered_map<int, vector<unique_ptr<Handle>>> created_handles;
        Handle(const Handle& rhs) = delete;
        // Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom
        Handle(Handle&& rhs) : Handle() { std::swap(handle, rhs.handle); }
        // operator= takes argument by value
        Handle& operator=(Handle rhs) { std::swap(handle, rhs.handle); return *this; }
        ~Handle() {
            if(handle) Destroy(handle);
        }
    };

    // Guards created_handles and available_handles below.
    std::mutex mutex;

    // Handles are lazily created as different threads request them,
    // but are never destroyed until the end of the process.
    // The maximum number of handles this process will create for each device is equal
    // to the high-water mark of the number of concurrently active threads that request
    // handles for that device.
    // When threads terminate, they release their handles back into the pool for reuse.
    // Otherwise, new handles would be created every time new threads were spawned,
    // resulting in poor performance for Python modules that repeatedly or frequently
    // spawned new sets of threads (like DataParallel, which creates a new set of threads
    // for each forward pass).
    //
    // To prevent potential deadlocks, we explicitly choose not to cap the number
    // of handles that are created per device.
    // Example of danger: If we cap the max handles at 4, and 5 threads are sharing a device,
    // only 4 can make forward progress at any time. Those 4 will not release their
    // handles until they exit, so the fifth cannot make progress until then. This is
    // not a problem...UNLESS all 5 threads attempt some sort of synchronization at an
    // intermediate point (ie, before any of them have exited). We have no way to anticipate
    // or enforce that user threads will not attempt such intermediate synchronization.
    // The only way to ensure safety is to avoid imposing a cap on the number of handles.
    std::unordered_map<int, std::vector<Handle>> created_handles;
    std::unordered_map<int, std::vector<Handle_t>> available_handles;

    // PoolWindow lazily creates and caches the handles that a particular thread is using,
    // so in the common case handle access doesn't incur either handle creation or a mutex lock.
    class PoolWindow
    {
    public:
        PoolWindow(std::shared_ptr<DeviceThreadHandlePool> parent): weak_parent(std::move(parent)) {}
        ~PoolWindow(){ release(); }

        // Returns this thread's handle for `device`, creating one (or recycling one
        // from the pool) on first use. Only the slow path takes the pool mutex.
        Handle_t reserve(int device)
        {
            // If this thread already has a handle for this device, return it.
            // Cache the iterator so the fast path costs a single hash lookup
            // instead of a find() followed by operator[].
            auto my_it = my_handles.find(device);
            if(my_it != my_handles.end())
                return my_it->second;

            // otherwise, either grab a handle from the pool if one is available,
            // or if not, create a new one.
            auto parent = weak_parent.lock();
            TORCH_CHECK(parent, "Cannot create handle during program termination");
            std::lock_guard<std::mutex> guard(parent->mutex);
            Handle_t handle;
            auto& available = parent->available_handles[device];
            if(!available.empty())
            {
                handle = available.back();
                available.pop_back();
            }
            else
            {
                // In local testing, I do observe that emplace_back sometimes routes through temporaries
                // that incur move-constructor and destructor calls. See comments in Handle above.
                auto& created = parent->created_handles[device];
                created.emplace_back(true /*create*/);
                handle = created.back().handle;
            }
            my_handles[device] = handle;
            return handle;
        }

    private:
        // Stores the per-device handles currently owned by this thread
        std::unordered_map<int, Handle_t> my_handles;

        std::weak_ptr<DeviceThreadHandlePool> weak_parent;

        // Called by the destructor. Releases this thread's handles back into the pool.
        void release() {
            if(!my_handles.empty()) {
                auto parent = weak_parent.lock();
                if (!parent) {
                    // If this thread exits after atexit handlers have completed, the
                    // cuda context itself may be invalid, so we must leak the handles.
                    return;
                }
                std::lock_guard<std::mutex> guard(parent->mutex);
                // const ref: don't copy each (device, handle) pair out of the map.
                for(const auto& d_h : my_handles)
                    parent->available_handles[d_h.first].push_back(d_h.second);
            }
        }
    };

    // Warning:
    // If you want to change this function, be aware that this function will be called
    // by multiple threads and there is no mutex guarding the call of this function, so
    // make sure your implementation is thread-safe.
    PoolWindow *newPoolWindow() {
        // The returned pointer will be owned by a thread local variable
        // so that different threads do not share the same PoolWindow.
        return new PoolWindow(this->shared_from_this());
    }
};
}}} // namespace at::cuda::<anonymous>
|